first
pc-1
This commit is contained in:
196
three_body_problem/solver.py
Normal file
196
three_body_problem/solver.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
三体问题求解器主模块
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Optional, Callable
|
||||
from .particle import Particle
|
||||
from .integrator import RK4Integrator
|
||||
|
||||
|
||||
class ThreeBodySolver:
|
||||
"""三体问题求解器"""
|
||||
|
||||
# 万有引力常数 (AU^3 / (M_sun * year^2))
|
||||
G = 4 * np.pi**2 # 在天文单位制中
|
||||
|
||||
def __init__(self, particles: List[Particle], dt: float = 0.001):
|
||||
"""
|
||||
初始化三体问题求解器
|
||||
|
||||
参数:
|
||||
particles: 三个质点的列表
|
||||
dt: 时间步长(年)
|
||||
"""
|
||||
if len(particles) != 3:
|
||||
raise ValueError("三体问题需要恰好3个质点")
|
||||
|
||||
self.particles = particles
|
||||
self.dt = dt
|
||||
self.integrator = RK4Integrator(dt)
|
||||
self.time = 0.0
|
||||
|
||||
# 验证总动量和角动量守恒(可选)
|
||||
self.initial_total_momentum = self._calculate_total_momentum()
|
||||
self.initial_angular_momentum = self._calculate_total_angular_momentum()
|
||||
|
||||
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
|
||||
"""
|
||||
计算所有质点的加速度
|
||||
|
||||
参数:
|
||||
particles: 质点列表
|
||||
|
||||
返回:
|
||||
加速度列表
|
||||
"""
|
||||
accelerations = [np.zeros(3) for _ in range(3)]
|
||||
|
||||
# 计算每对质点之间的引力
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
if i != j:
|
||||
# 计算相对位置向量
|
||||
r_vec = particles[j].position - particles[i].position
|
||||
r = np.linalg.norm(r_vec)
|
||||
|
||||
# 避免除以零
|
||||
if r < 1e-10:
|
||||
r = 1e-10
|
||||
|
||||
# 计算加速度 (F = ma -> a = F/m)
|
||||
acceleration = self.G * particles[j].mass * r_vec / (r**3)
|
||||
accelerations[i] += acceleration
|
||||
|
||||
return accelerations
|
||||
|
||||
def _calculate_total_momentum(self) -> np.ndarray:
|
||||
"""计算系统总动量"""
|
||||
total_momentum = np.zeros(3)
|
||||
for p in self.particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
return total_momentum
|
||||
|
||||
def _calculate_total_angular_momentum(self) -> np.ndarray:
|
||||
"""计算系统总角动量"""
|
||||
total_angular_momentum = np.zeros(3)
|
||||
for p in self.particles:
|
||||
# 角动量 L = r × p = r × (m*v)
|
||||
angular_momentum = np.cross(p.position, p.mass * p.velocity)
|
||||
total_angular_momentum += angular_momentum
|
||||
return total_angular_momentum
|
||||
|
||||
def _calculate_total_energy(self) -> float:
|
||||
"""计算系统总能量(动能+势能)"""
|
||||
kinetic_energy = 0.0
|
||||
potential_energy = 0.0
|
||||
|
||||
# 计算动能
|
||||
for p in self.particles:
|
||||
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
|
||||
|
||||
# 计算势能
|
||||
for i in range(3):
|
||||
for j in range(i+1, 3):
|
||||
r_vec = self.particles[j].position - self.particles[i].position
|
||||
r = np.linalg.norm(r_vec)
|
||||
if r > 1e-10: # 避免除以零
|
||||
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
|
||||
|
||||
return kinetic_energy + potential_energy
|
||||
|
||||
def step(self) -> List[Particle]:
|
||||
"""
|
||||
执行一个时间步的积分
|
||||
|
||||
返回:
|
||||
更新后的质点列表
|
||||
"""
|
||||
# 计算加速度
|
||||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||||
|
||||
# 执行积分
|
||||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||||
self.time += self.dt
|
||||
|
||||
return self.particles
|
||||
|
||||
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
|
||||
"""
|
||||
模拟三体运动
|
||||
|
||||
参数:
|
||||
total_time: 总模拟时间(年)
|
||||
progress_interval: 进度打印间隔步数
|
||||
|
||||
返回:
|
||||
模拟结束后的质点列表
|
||||
"""
|
||||
steps = int(total_time / self.dt)
|
||||
|
||||
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}年")
|
||||
print(f"初始总能量: {self._calculate_total_energy():.6e}")
|
||||
|
||||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||||
|
||||
for step in range(steps):
|
||||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||||
self.time += self.dt
|
||||
|
||||
# 打印进度
|
||||
if step % progress_interval == 0:
|
||||
energy = self._calculate_total_energy()
|
||||
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
|
||||
|
||||
final_energy = self._calculate_total_energy()
|
||||
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
|
||||
|
||||
return self.particles
|
||||
|
||||
def get_trajectories(self) -> List[np.ndarray]:
|
||||
"""获取所有质点的轨迹"""
|
||||
return [p.get_trajectory() for p in self.particles]
|
||||
|
||||
def get_center_of_mass(self) -> np.ndarray:
|
||||
"""计算系统质心位置"""
|
||||
total_mass = sum(p.mass for p in self.particles)
|
||||
com = np.zeros(3)
|
||||
for p in self.particles:
|
||||
com += p.mass * p.position
|
||||
return com / total_mass if total_mass > 0 else np.zeros(3)
|
||||
|
||||
def get_conservation_errors(self) -> Tuple[float, float, float]:
|
||||
"""
|
||||
计算守恒定律的误差
|
||||
|
||||
返回:
|
||||
(动量误差, 角动量误差, 能量误差)
|
||||
"""
|
||||
# 动量误差
|
||||
current_momentum = self._calculate_total_momentum()
|
||||
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
|
||||
|
||||
# 角动量误差
|
||||
current_angular_momentum = self._calculate_total_angular_momentum()
|
||||
angular_momentum_error = np.linalg.norm(
|
||||
current_angular_momentum - self.initial_angular_momentum
|
||||
)
|
||||
|
||||
# 能量误差(相对误差)
|
||||
initial_energy = self._calculate_total_energy() # 重新计算初始能量
|
||||
current_energy = self._calculate_total_energy()
|
||||
if abs(initial_energy) > 1e-10:
|
||||
energy_error = abs((current_energy - initial_energy) / initial_energy)
|
||||
else:
|
||||
energy_error = abs(current_energy - initial_energy)
|
||||
|
||||
return momentum_error, angular_momentum_error, energy_error
|
||||
|
||||
def reset(self):
|
||||
"""重置求解器状态(清除历史记录)"""
|
||||
for p in self.particles:
|
||||
p.position_history.clear()
|
||||
p.velocity_history.clear()
|
||||
self.time = 0.0
|
||||
self.initial_total_momentum = self._calculate_total_momentum()
|
||||
self.initial_angular_momentum = self._calculate_total_angular_momentum()
|
||||
Reference in New Issue
Block a user