196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
"""
|
||
三体问题求解器主模块
|
||
"""
|
||
|
||
import numpy as np
|
||
from typing import List, Tuple, Optional, Callable
|
||
from .particle import Particle
|
||
from .integrator import RK4Integrator
|
||
|
||
|
||
class ThreeBodySolver:
|
||
"""三体问题求解器"""
|
||
|
||
# 万有引力常数 (AU^3 / (M_sun * year^2))
|
||
G = 4 * np.pi**2 # 在天文单位制中
|
||
|
||
def __init__(self, particles: List[Particle], dt: float = 0.001):
|
||
"""
|
||
初始化三体问题求解器
|
||
|
||
参数:
|
||
particles: 三个质点的列表
|
||
dt: 时间步长(年)
|
||
"""
|
||
if len(particles) != 3:
|
||
raise ValueError("三体问题需要恰好3个质点")
|
||
|
||
self.particles = particles
|
||
self.dt = dt
|
||
self.integrator = RK4Integrator(dt)
|
||
self.time = 0.0
|
||
|
||
# 验证总动量和角动量守恒(可选)
|
||
self.initial_total_momentum = self._calculate_total_momentum()
|
||
self.initial_angular_momentum = self._calculate_total_angular_momentum()
|
||
|
||
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
|
||
"""
|
||
计算所有质点的加速度
|
||
|
||
参数:
|
||
particles: 质点列表
|
||
|
||
返回:
|
||
加速度列表
|
||
"""
|
||
accelerations = [np.zeros(3) for _ in range(3)]
|
||
|
||
# 计算每对质点之间的引力
|
||
for i in range(3):
|
||
for j in range(3):
|
||
if i != j:
|
||
# 计算相对位置向量
|
||
r_vec = particles[j].position - particles[i].position
|
||
r = np.linalg.norm(r_vec)
|
||
|
||
# 避免除以零
|
||
if r < 1e-10:
|
||
r = 1e-10
|
||
|
||
# 计算加速度 (F = ma -> a = F/m)
|
||
acceleration = self.G * particles[j].mass * r_vec / (r**3)
|
||
accelerations[i] += acceleration
|
||
|
||
return accelerations
|
||
|
||
def _calculate_total_momentum(self) -> np.ndarray:
|
||
"""计算系统总动量"""
|
||
total_momentum = np.zeros(3)
|
||
for p in self.particles:
|
||
total_momentum += p.mass * p.velocity
|
||
return total_momentum
|
||
|
||
def _calculate_total_angular_momentum(self) -> np.ndarray:
|
||
"""计算系统总角动量"""
|
||
total_angular_momentum = np.zeros(3)
|
||
for p in self.particles:
|
||
# 角动量 L = r × p = r × (m*v)
|
||
angular_momentum = np.cross(p.position, p.mass * p.velocity)
|
||
total_angular_momentum += angular_momentum
|
||
return total_angular_momentum
|
||
|
||
def _calculate_total_energy(self) -> float:
|
||
"""计算系统总能量(动能+势能)"""
|
||
kinetic_energy = 0.0
|
||
potential_energy = 0.0
|
||
|
||
# 计算动能
|
||
for p in self.particles:
|
||
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
|
||
|
||
# 计算势能
|
||
for i in range(3):
|
||
for j in range(i+1, 3):
|
||
r_vec = self.particles[j].position - self.particles[i].position
|
||
r = np.linalg.norm(r_vec)
|
||
if r > 1e-10: # 避免除以零
|
||
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
|
||
|
||
return kinetic_energy + potential_energy
|
||
|
||
def step(self) -> List[Particle]:
|
||
"""
|
||
执行一个时间步的积分
|
||
|
||
返回:
|
||
更新后的质点列表
|
||
"""
|
||
# 计算加速度
|
||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||
|
||
# 执行积分
|
||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||
self.time += self.dt
|
||
|
||
return self.particles
|
||
|
||
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
|
||
"""
|
||
模拟三体运动
|
||
|
||
参数:
|
||
total_time: 总模拟时间(年)
|
||
progress_interval: 进度打印间隔步数
|
||
|
||
返回:
|
||
模拟结束后的质点列表
|
||
"""
|
||
steps = int(total_time / self.dt)
|
||
|
||
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}年")
|
||
print(f"初始总能量: {self._calculate_total_energy():.6e}")
|
||
|
||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||
|
||
for step in range(steps):
|
||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||
self.time += self.dt
|
||
|
||
# 打印进度
|
||
if step % progress_interval == 0:
|
||
energy = self._calculate_total_energy()
|
||
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
|
||
|
||
final_energy = self._calculate_total_energy()
|
||
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
|
||
|
||
return self.particles
|
||
|
||
def get_trajectories(self) -> List[np.ndarray]:
|
||
"""获取所有质点的轨迹"""
|
||
return [p.get_trajectory() for p in self.particles]
|
||
|
||
def get_center_of_mass(self) -> np.ndarray:
|
||
"""计算系统质心位置"""
|
||
total_mass = sum(p.mass for p in self.particles)
|
||
com = np.zeros(3)
|
||
for p in self.particles:
|
||
com += p.mass * p.position
|
||
return com / total_mass if total_mass > 0 else np.zeros(3)
|
||
|
||
def get_conservation_errors(self) -> Tuple[float, float, float]:
|
||
"""
|
||
计算守恒定律的误差
|
||
|
||
返回:
|
||
(动量误差, 角动量误差, 能量误差)
|
||
"""
|
||
# 动量误差
|
||
current_momentum = self._calculate_total_momentum()
|
||
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
|
||
|
||
# 角动量误差
|
||
current_angular_momentum = self._calculate_total_angular_momentum()
|
||
angular_momentum_error = np.linalg.norm(
|
||
current_angular_momentum - self.initial_angular_momentum
|
||
)
|
||
|
||
# 能量误差(相对误差)
|
||
initial_energy = self._calculate_total_energy() # 重新计算初始能量
|
||
current_energy = self._calculate_total_energy()
|
||
if abs(initial_energy) > 1e-10:
|
||
energy_error = abs((current_energy - initial_energy) / initial_energy)
|
||
else:
|
||
energy_error = abs(current_energy - initial_energy)
|
||
|
||
return momentum_error, angular_momentum_error, energy_error
|
||
|
||
def reset(self):
|
||
"""重置求解器状态(清除历史记录)"""
|
||
for p in self.particles:
|
||
p.position_history.clear()
|
||
p.velocity_history.clear()
|
||
self.time = 0.0
|
||
self.initial_total_momentum = self._calculate_total_momentum()
|
||
self.initial_angular_momentum = self._calculate_total_angular_momentum() |