Files
dison0331-ThinkPad 8c8ad9fe07 first
pc-1
2026-03-11 21:32:58 +08:00

196 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
三体问题求解器主模块
"""
import numpy as np
from typing import List, Tuple, Optional, Callable
from .particle import Particle
from .integrator import RK4Integrator
class ThreeBodySolver:
"""三体问题求解器"""
# 万有引力常数 (AU^3 / (M_sun * year^2))
G = 4 * np.pi**2 # 在天文单位制中
def __init__(self, particles: List[Particle], dt: float = 0.001):
"""
初始化三体问题求解器
参数:
particles: 三个质点的列表
dt: 时间步长(年)
"""
if len(particles) != 3:
raise ValueError("三体问题需要恰好3个质点")
self.particles = particles
self.dt = dt
self.integrator = RK4Integrator(dt)
self.time = 0.0
# 验证总动量和角动量守恒(可选)
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
"""
计算所有质点的加速度
参数:
particles: 质点列表
返回:
加速度列表
"""
accelerations = [np.zeros(3) for _ in range(3)]
# 计算每对质点之间的引力
for i in range(3):
for j in range(3):
if i != j:
# 计算相对位置向量
r_vec = particles[j].position - particles[i].position
r = np.linalg.norm(r_vec)
# 避免除以零
if r < 1e-10:
r = 1e-10
# 计算加速度 (F = ma -> a = F/m)
acceleration = self.G * particles[j].mass * r_vec / (r**3)
accelerations[i] += acceleration
return accelerations
def _calculate_total_momentum(self) -> np.ndarray:
"""计算系统总动量"""
total_momentum = np.zeros(3)
for p in self.particles:
total_momentum += p.mass * p.velocity
return total_momentum
def _calculate_total_angular_momentum(self) -> np.ndarray:
"""计算系统总角动量"""
total_angular_momentum = np.zeros(3)
for p in self.particles:
# 角动量 L = r × p = r × (m*v)
angular_momentum = np.cross(p.position, p.mass * p.velocity)
total_angular_momentum += angular_momentum
return total_angular_momentum
def _calculate_total_energy(self) -> float:
"""计算系统总能量(动能+势能)"""
kinetic_energy = 0.0
potential_energy = 0.0
# 计算动能
for p in self.particles:
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
# 计算势能
for i in range(3):
for j in range(i+1, 3):
r_vec = self.particles[j].position - self.particles[i].position
r = np.linalg.norm(r_vec)
if r > 1e-10: # 避免除以零
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
return kinetic_energy + potential_energy
def step(self) -> List[Particle]:
"""
执行一个时间步的积分
返回:
更新后的质点列表
"""
# 计算加速度
acceleration_func = lambda p: self._calculate_accelerations(p)
# 执行积分
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
return self.particles
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
"""
模拟三体运动
参数:
total_time: 总模拟时间(年)
progress_interval: 进度打印间隔步数
返回:
模拟结束后的质点列表
"""
steps = int(total_time / self.dt)
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}")
print(f"初始总能量: {self._calculate_total_energy():.6e}")
acceleration_func = lambda p: self._calculate_accelerations(p)
for step in range(steps):
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
# 打印进度
if step % progress_interval == 0:
energy = self._calculate_total_energy()
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
final_energy = self._calculate_total_energy()
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
return self.particles
def get_trajectories(self) -> List[np.ndarray]:
"""获取所有质点的轨迹"""
return [p.get_trajectory() for p in self.particles]
def get_center_of_mass(self) -> np.ndarray:
"""计算系统质心位置"""
total_mass = sum(p.mass for p in self.particles)
com = np.zeros(3)
for p in self.particles:
com += p.mass * p.position
return com / total_mass if total_mass > 0 else np.zeros(3)
def get_conservation_errors(self) -> Tuple[float, float, float]:
"""
计算守恒定律的误差
返回:
(动量误差, 角动量误差, 能量误差)
"""
# 动量误差
current_momentum = self._calculate_total_momentum()
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
# 角动量误差
current_angular_momentum = self._calculate_total_angular_momentum()
angular_momentum_error = np.linalg.norm(
current_angular_momentum - self.initial_angular_momentum
)
# 能量误差(相对误差)
initial_energy = self._calculate_total_energy() # 重新计算初始能量
current_energy = self._calculate_total_energy()
if abs(initial_energy) > 1e-10:
energy_error = abs((current_energy - initial_energy) / initial_energy)
else:
energy_error = abs(current_energy - initial_energy)
return momentum_error, angular_momentum_error, energy_error
def reset(self):
"""重置求解器状态(清除历史记录)"""
for p in self.particles:
p.position_history.clear()
p.velocity_history.clear()
self.time = 0.0
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()