""" 三体问题求解器主模块 """ import numpy as np from typing import List, Tuple, Optional, Callable from .particle import Particle from .integrator import RK4Integrator class ThreeBodySolver: """三体问题求解器""" # 万有引力常数 (AU^3 / (M_sun * year^2)) G = 4 * np.pi**2 # 在天文单位制中 def __init__(self, particles: List[Particle], dt: float = 0.001): """ 初始化三体问题求解器 参数: particles: 三个质点的列表 dt: 时间步长(年) """ if len(particles) != 3: raise ValueError("三体问题需要恰好3个质点") self.particles = particles self.dt = dt self.integrator = RK4Integrator(dt) self.time = 0.0 # 验证总动量和角动量守恒(可选) self.initial_total_momentum = self._calculate_total_momentum() self.initial_angular_momentum = self._calculate_total_angular_momentum() def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]: """ 计算所有质点的加速度 参数: particles: 质点列表 返回: 加速度列表 """ accelerations = [np.zeros(3) for _ in range(3)] # 计算每对质点之间的引力 for i in range(3): for j in range(3): if i != j: # 计算相对位置向量 r_vec = particles[j].position - particles[i].position r = np.linalg.norm(r_vec) # 避免除以零 if r < 1e-10: r = 1e-10 # 计算加速度 (F = ma -> a = F/m) acceleration = self.G * particles[j].mass * r_vec / (r**3) accelerations[i] += acceleration return accelerations def _calculate_total_momentum(self) -> np.ndarray: """计算系统总动量""" total_momentum = np.zeros(3) for p in self.particles: total_momentum += p.mass * p.velocity return total_momentum def _calculate_total_angular_momentum(self) -> np.ndarray: """计算系统总角动量""" total_angular_momentum = np.zeros(3) for p in self.particles: # 角动量 L = r × p = r × (m*v) angular_momentum = np.cross(p.position, p.mass * p.velocity) total_angular_momentum += angular_momentum return total_angular_momentum def _calculate_total_energy(self) -> float: """计算系统总能量(动能+势能)""" kinetic_energy = 0.0 potential_energy = 0.0 # 计算动能 for p in self.particles: kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2) # 计算势能 for i in range(3): for j in range(i+1, 3): r_vec = self.particles[j].position - self.particles[i].position r = np.linalg.norm(r_vec) if r > 1e-10: # 避免除以零 potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r return kinetic_energy + potential_energy def step(self) -> List[Particle]: """ 执行一个时间步的积分 返回: 更新后的质点列表 """ # 计算加速度 acceleration_func = lambda p: self._calculate_accelerations(p) # 执行积分 self.particles = self.integrator.step(self.particles, acceleration_func) self.time += self.dt return self.particles def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]: """ 模拟三体运动 参数: total_time: 总模拟时间(年) progress_interval: 进度打印间隔步数 返回: 模拟结束后的质点列表 """ steps = int(total_time / self.dt) print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}年") print(f"初始总能量: {self._calculate_total_energy():.6e}") acceleration_func = lambda p: self._calculate_accelerations(p) for step in range(steps): self.particles = self.integrator.step(self.particles, acceleration_func) self.time += self.dt # 打印进度 if step % progress_interval == 0: energy = self._calculate_total_energy() print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}") final_energy = self._calculate_total_energy() print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}") return self.particles def get_trajectories(self) -> List[np.ndarray]: """获取所有质点的轨迹""" return [p.get_trajectory() for p in self.particles] def get_center_of_mass(self) -> np.ndarray: """计算系统质心位置""" total_mass = sum(p.mass for p in self.particles) com = np.zeros(3) for p in self.particles: com += p.mass * p.position return com / total_mass if total_mass > 0 else np.zeros(3) def get_conservation_errors(self) -> Tuple[float, float, float]: """ 计算守恒定律的误差 返回: (动量误差, 角动量误差, 能量误差) """ # 动量误差 current_momentum = self._calculate_total_momentum() momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum) # 角动量误差 current_angular_momentum = self._calculate_total_angular_momentum() angular_momentum_error = np.linalg.norm( current_angular_momentum - self.initial_angular_momentum ) # 能量误差(相对误差) initial_energy = self._calculate_total_energy() # 重新计算初始能量 current_energy = self._calculate_total_energy() if abs(initial_energy) > 1e-10: energy_error = abs((current_energy - initial_energy) / initial_energy) else: energy_error = abs(current_energy - initial_energy) return momentum_error, angular_momentum_error, energy_error def reset(self): """重置求解器状态(清除历史记录)""" for p in self.particles: p.position_history.clear() p.velocity_history.clear() self.time = 0.0 self.initial_total_momentum = self._calculate_total_momentum() self.initial_angular_momentum = self._calculate_total_angular_momentum()