pc-1
This commit is contained in:
dison0331-ThinkPad
2026-03-11 21:32:58 +08:00
commit 8c8ad9fe07
29 changed files with 4005 additions and 0 deletions

View File

@@ -0,0 +1,196 @@
"""
三体问题求解器主模块
"""
import numpy as np
from typing import List, Tuple, Optional, Callable
from .particle import Particle
from .integrator import RK4Integrator
class ThreeBodySolver:
"""三体问题求解器"""
# 万有引力常数 (AU^3 / (M_sun * year^2))
G = 4 * np.pi**2 # 在天文单位制中
def __init__(self, particles: List[Particle], dt: float = 0.001):
"""
初始化三体问题求解器
参数:
particles: 三个质点的列表
dt: 时间步长(年)
"""
if len(particles) != 3:
raise ValueError("三体问题需要恰好3个质点")
self.particles = particles
self.dt = dt
self.integrator = RK4Integrator(dt)
self.time = 0.0
# 验证总动量和角动量守恒(可选)
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
"""
计算所有质点的加速度
参数:
particles: 质点列表
返回:
加速度列表
"""
accelerations = [np.zeros(3) for _ in range(3)]
# 计算每对质点之间的引力
for i in range(3):
for j in range(3):
if i != j:
# 计算相对位置向量
r_vec = particles[j].position - particles[i].position
r = np.linalg.norm(r_vec)
# 避免除以零
if r < 1e-10:
r = 1e-10
# 计算加速度 (F = ma -> a = F/m)
acceleration = self.G * particles[j].mass * r_vec / (r**3)
accelerations[i] += acceleration
return accelerations
def _calculate_total_momentum(self) -> np.ndarray:
"""计算系统总动量"""
total_momentum = np.zeros(3)
for p in self.particles:
total_momentum += p.mass * p.velocity
return total_momentum
def _calculate_total_angular_momentum(self) -> np.ndarray:
"""计算系统总角动量"""
total_angular_momentum = np.zeros(3)
for p in self.particles:
# 角动量 L = r × p = r × (m*v)
angular_momentum = np.cross(p.position, p.mass * p.velocity)
total_angular_momentum += angular_momentum
return total_angular_momentum
def _calculate_total_energy(self) -> float:
"""计算系统总能量(动能+势能)"""
kinetic_energy = 0.0
potential_energy = 0.0
# 计算动能
for p in self.particles:
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
# 计算势能
for i in range(3):
for j in range(i+1, 3):
r_vec = self.particles[j].position - self.particles[i].position
r = np.linalg.norm(r_vec)
if r > 1e-10: # 避免除以零
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
return kinetic_energy + potential_energy
def step(self) -> List[Particle]:
"""
执行一个时间步的积分
返回:
更新后的质点列表
"""
# 计算加速度
acceleration_func = lambda p: self._calculate_accelerations(p)
# 执行积分
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
return self.particles
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
"""
模拟三体运动
参数:
total_time: 总模拟时间(年)
progress_interval: 进度打印间隔步数
返回:
模拟结束后的质点列表
"""
steps = int(total_time / self.dt)
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}")
print(f"初始总能量: {self._calculate_total_energy():.6e}")
acceleration_func = lambda p: self._calculate_accelerations(p)
for step in range(steps):
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
# 打印进度
if step % progress_interval == 0:
energy = self._calculate_total_energy()
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
final_energy = self._calculate_total_energy()
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
return self.particles
def get_trajectories(self) -> List[np.ndarray]:
"""获取所有质点的轨迹"""
return [p.get_trajectory() for p in self.particles]
def get_center_of_mass(self) -> np.ndarray:
"""计算系统质心位置"""
total_mass = sum(p.mass for p in self.particles)
com = np.zeros(3)
for p in self.particles:
com += p.mass * p.position
return com / total_mass if total_mass > 0 else np.zeros(3)
def get_conservation_errors(self) -> Tuple[float, float, float]:
"""
计算守恒定律的误差
返回:
(动量误差, 角动量误差, 能量误差)
"""
# 动量误差
current_momentum = self._calculate_total_momentum()
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
# 角动量误差
current_angular_momentum = self._calculate_total_angular_momentum()
angular_momentum_error = np.linalg.norm(
current_angular_momentum - self.initial_angular_momentum
)
# 能量误差(相对误差)
initial_energy = self._calculate_total_energy() # 重新计算初始能量
current_energy = self._calculate_total_energy()
if abs(initial_energy) > 1e-10:
energy_error = abs((current_energy - initial_energy) / initial_energy)
else:
energy_error = abs(current_energy - initial_energy)
return momentum_error, angular_momentum_error, energy_error
def reset(self):
"""重置求解器状态(清除历史记录)"""
for p in self.particles:
p.position_history.clear()
p.velocity_history.clear()
self.time = 0.0
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()