pc-1
This commit is contained in:
dison0331-ThinkPad
2026-03-11 21:32:58 +08:00
commit 8c8ad9fe07
29 changed files with 4005 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
"""
三体问题求解器 - 纯Python方案
一个用于模拟和可视化三体问题的Python库。
使用四阶龙格-库塔法数值求解牛顿引力下的三体运动。
主要组件:
- ThreeBodySolver: 主求解器类
- Particle: 质点类
- RK4Integrator: 数值积分器
- ThreeBodyVisualizer: 可视化工具
- ThreeBodyConfig: 配置管理
示例用法:
>>> from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig
>>> particles = ThreeBodyConfig.create_figure8_config()
>>> solver = ThreeBodySolver(particles, dt=0.001)
>>> solver.simulate(total_time=10.0)
>>> from three_body_problem import ThreeBodyVisualizer
>>> visualizer = ThreeBodyVisualizer()
>>> visualizer.plot_trajectories(solver)
>>> visualizer.show()
"""
from .particle import Particle
from .integrator import RK4Integrator
from .solver import ThreeBodySolver
from .visualizer import ThreeBodyVisualizer
from .config import ThreeBodyConfig
__version__ = "1.0.0"
__author__ = "ThreeBodyProblem Team"
__email__ = "threebody@example.com"
__all__ = [
"Particle",
"RK4Integrator",
"ThreeBodySolver",
"ThreeBodyVisualizer",
"ThreeBodyConfig"
]

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,280 @@
"""
三体问题配置管理模块
"""
import numpy as np
from typing import List, Dict, Any, Optional
from .particle import Particle
class ThreeBodyConfig:
"""三体问题配置类"""
@staticmethod
def create_figure8_config() -> List[Particle]:
"""
创建8字形轨道配置著名的稳定三体轨道
返回:
三个质点的列表
"""
# 8字形轨道的初始条件等质量
m = 1.0 # 质量
# 位置 (Chenciner & Montgomery, 2000)
r1 = np.array([0.97000436, -0.24308753, 0.0])
r2 = np.array([-0.97000436, 0.24308753, 0.0])
r3 = np.array([0.0, 0.0, 0.0])
# 速度
v1 = np.array([0.466203685, 0.43236573, 0.0])
v2 = np.array([0.466203685, 0.43236573, 0.0])
v3 = np.array([-0.93240737, -0.86473146, 0.0])
particles = [
Particle(mass=m, position=r1, velocity=v1, name="Star A", color="red"),
Particle(mass=m, position=r2, velocity=v2, name="Star B", color="green"),
Particle(mass=m, position=r3, velocity=v3, name="Star C", color="blue")
]
return particles
@staticmethod
def create_lagrange_point_config(lagrange_point: int = 4) -> List[Particle]:
"""
创建拉格朗日点配置
参数:
lagrange_point: 拉格朗日点编号 (4=L4, 5=L5)
返回:
三个质点的列表
"""
if lagrange_point not in [4, 5]:
raise ValueError("lagrange_point 必须是 4 (L4) 或 5 (L5)")
# 大质量天体(类似太阳)
m_sun = 1.0
# 小质量天体(类似地球)
m_earth = 3e-6 # 地球质量/太阳质量
# 测试质点(类似小行星)
m_test = 1e-8
# 大质量天体在原点
r_sun = np.array([0.0, 0.0, 0.0])
v_sun = np.array([0.0, 0.0, 0.0])
# 小质量天体在圆形轨道上1 AU距离
r_earth = np.array([1.0, 0.0, 0.0])
# 圆形轨道速度v = sqrt(G*M/r)
v_earth = np.array([0.0, 2*np.pi, 0.0]) # 2π AU/年
# 拉格朗日点位置(等边三角形)
if lagrange_point == 4: # L4
r_test = np.array([0.5, np.sqrt(3)/2, 0.0])
else: # L5
r_test = np.array([0.5, -np.sqrt(3)/2, 0.0])
# 测试质点的速度(与地球相同角速度)
v_test = np.array([-np.sqrt(3)/2 * 2*np.pi, 0.5 * 2*np.pi, 0.0])
particles = [
Particle(mass=m_sun, position=r_sun, velocity=v_sun, name="Sun", color="yellow"),
Particle(mass=m_earth, position=r_earth, velocity=v_earth, name="Earth", color="blue"),
Particle(mass=m_test, position=r_test, velocity=v_test, name="Test", color="gray")
]
return particles
@staticmethod
def create_random_config(masses: Optional[List[float]] = None,
position_range: float = 2.0,
velocity_scale: float = 1.0) -> List[Particle]:
"""
创建随机初始条件配置
参数:
masses: 质量列表如果为None则使用随机质量
position_range: 位置范围±position_range
velocity_scale: 速度缩放因子
返回:
三个质点的列表
"""
if masses is None:
# 随机质量在0.5到2.0之间)
masses = np.random.uniform(0.5, 2.0, 3)
if len(masses) != 3:
raise ValueError("需要恰好3个质量值")
particles = []
colors = ['red', 'green', 'blue']
names = ['Star A', 'Star B', 'Star C']
for i in range(3):
# 随机位置
position = np.random.uniform(-position_range, position_range, 3)
# 随机速度(确保系统总动量接近零)
velocity = np.random.uniform(-velocity_scale, velocity_scale, 3)
particles.append(
Particle(mass=masses[i], position=position, velocity=velocity,
name=names[i], color=colors[i])
)
# 调整速度使系统总动量接近零
total_momentum = sum(p.mass * p.velocity for p in particles)
total_mass = sum(p.mass for p in particles)
for p in particles:
p.velocity -= total_momentum / total_mass
return particles
@staticmethod
def create_binary_star_config() -> List[Particle]:
"""
创建双星系统+测试质点配置
返回:
三个质点的列表
"""
# 双星质量
m1 = 1.0
m2 = 0.8
# 双星位置(在椭圆轨道上)
# 半长轴
a = 1.0
# 偏心率
e = 0.3
# 质心在原点
r1 = np.array([-m2/(m1+m2) * a * (1+e), 0.0, 0.0])
r2 = np.array([m1/(m1+m2) * a * (1+e), 0.0, 0.0])
# 计算轨道速度(简化圆形轨道近似)
# 对于椭圆轨道,速度更复杂,这里使用简化
orbital_speed = np.sqrt(4*np.pi**2 * (m1+m2) / (2*a))
v1 = np.array([0.0, orbital_speed * m2/(m1+m2), 0.0])
v2 = np.array([0.0, -orbital_speed * m1/(m1+m2), 0.0])
# 测试质点(小质量)
m_test = 0.01
r_test = np.array([0.0, 2.0, 0.0])
v_test = np.array([0.5, 0.0, 0.0])
particles = [
Particle(mass=m1, position=r1, velocity=v1, name="Primary", color="red"),
Particle(mass=m2, position=r2, velocity=v2, name="Secondary", color="green"),
Particle(mass=m_test, position=r_test, velocity=v_test, name="Test", color="blue")
]
return particles
@staticmethod
def create_custom_config(config_dict: Dict[str, Any]) -> List[Particle]:
"""
从字典创建自定义配置
参数:
config_dict: 包含配置信息的字典
返回:
三个质点的列表
"""
particles = []
for i in range(3):
key = f"particle_{i+1}"
if key not in config_dict:
raise ValueError(f"配置中缺少 {key}")
p_config = config_dict[key]
particle = Particle(
mass=p_config.get('mass', 1.0),
position=np.array(p_config.get('position', [0.0, 0.0, 0.0])),
velocity=np.array(p_config.get('velocity', [0.0, 0.0, 0.0])),
name=p_config.get('name', f"Particle {i+1}"),
color=p_config.get('color', None)
)
particles.append(particle)
return particles
@staticmethod
def save_config(particles: List[Particle], filename: str):
"""
保存配置到文件
参数:
particles: 质点列表
filename: 文件名
"""
config_dict = {}
for i, p in enumerate(particles):
config_dict[f"particle_{i+1}"] = {
'mass': float(p.mass),
'position': p.position.tolist(),
'velocity': p.velocity.tolist(),
'name': p.name,
'color': p.color
}
import json
with open(filename, 'w') as f:
json.dump(config_dict, f, indent=2)
print(f"配置已保存到: {filename}")
@staticmethod
def load_config(filename: str) -> List[Particle]:
"""
从文件加载配置
参数:
filename: 文件名
返回:
三个质点的列表
"""
import json
with open(filename, 'r') as f:
config_dict = json.load(f)
return ThreeBodyConfig.create_custom_config(config_dict)
@staticmethod
def print_config_summary(particles: List[Particle]):
"""打印配置摘要"""
print("=" * 60)
print("三体问题配置摘要")
print("=" * 60)
total_mass = 0.0
total_momentum = np.zeros(3)
total_angular_momentum = np.zeros(3)
for i, p in enumerate(particles):
print(f"\n质点 {i+1} ({p.name}):")
print(f" 质量: {p.mass:.6f} M_sun")
print(f" 位置: [{p.position[0]:.6f}, {p.position[1]:.6f}, {p.position[2]:.6f}] AU")
print(f" 速度: [{p.velocity[0]:.6f}, {p.velocity[1]:.6f}, {p.velocity[2]:.6f}] AU/yr")
total_mass += p.mass
total_momentum += p.mass * p.velocity
angular_momentum = np.cross(p.position, p.mass * p.velocity)
total_angular_momentum += angular_momentum
print("\n" + "=" * 60)
print("系统总质量: {:.6f} M_sun".format(total_mass))
print("系统总动量: [{:.6e}, {:.6e}, {:.6e}]".format(
total_momentum[0], total_momentum[1], total_momentum[2]))
print("系统总角动量: [{:.6e}, {:.6e}, {:.6e}]".format(
total_angular_momentum[0], total_angular_momentum[1], total_angular_momentum[2]))
print("=" * 60)

299
three_body_problem/demo.py Normal file
View File

@@ -0,0 +1,299 @@
#!/usr/bin/env python3
"""
三体问题求解器演示脚本
"""
import numpy as np
import matplotlib.pyplot as plt
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
def demo_basic_usage():
"""演示基本用法"""
print("=" * 60)
print("三体问题求解器 - 基本用法演示")
print("=" * 60)
# 1. 创建自定义三体系统
print("\n1. 创建自定义三体系统...")
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0], name="Star A", color="red"),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0], name="Star B", color="green"),
Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0], name="Star C", color="blue")
]
# 2. 创建求解器
print("2. 创建求解器...")
solver = ThreeBodySolver(particles, dt=0.001)
# 3. 模拟5年
print("3. 模拟5年运动...")
solver.simulate(total_time=5.0, progress_interval=1000)
# 4. 分析结果
print("\n4. 分析模拟结果...")
trajectories = solver.get_trajectories()
print(f" 轨迹点数: {len(trajectories[0])}")
com = solver.get_center_of_mass()
print(f" 系统质心: [{com[0]:.3f}, {com[1]:.3f}, {com[2]:.3f}] AU")
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f" 动量误差: {momentum_error:.2e}")
print(f" 角动量误差: {angular_momentum_error:.2e}")
print(f" 能量相对误差: {energy_error:.2e}")
# 5. 可视化
print("\n5. 生成可视化图形...")
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
# 创建3D轨迹图
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="自定义三体系统")
# 保存图形
visualizer.save_figure("demo_basic_3d.png", dpi=300)
# 创建2D投影图
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
fig.savefig("demo_basic_2d.png", dpi=300, bbox_inches='tight')
print("\n图形已保存:")
print(" - demo_basic_3d.png (3D轨迹)")
print(" - demo_basic_2d.png (2D投影)")
# 显示图形
visualizer.show()
return solver
def demo_figure8():
"""演示8字形轨道"""
print("\n" + "=" * 60)
print("8字形轨道演示")
print("=" * 60)
# 使用预置配置
particles = ThreeBodyConfig.create_figure8_config()
# 打印配置信息
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟10年
print("\n模拟10年8字形轨道...")
solver.simulate(total_time=10.0, progress_interval=2000)
# 可视化
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="8字形三体轨道")
visualizer.save_figure("demo_figure8.png", dpi=300)
print("\n图形已保存: demo_figure8.png")
visualizer.show()
return solver
def demo_lagrange_points():
"""演示拉格朗日点"""
print("\n" + "=" * 60)
print("拉格朗日点演示")
print("=" * 60)
# 创建L4点配置
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
# 打印配置信息
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟50年
print("\n模拟50年拉格朗日点L4稳定性...")
solver.simulate(total_time=50.0, progress_interval=10000)
# 分析测试质点稳定性
test_particle = solver.particles[2]
trajectory = test_particle.get_trajectory()
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
print(f"\n测试质点稳定性分析:")
print(f" 初始距离L4点: {distances[0]:.6f} AU")
print(f" 最终距离L4点: {distances[-1]:.6f} AU")
print(f" 最大距离偏差: {np.max(distances):.6f} AU")
print(f" 平均距离偏差: {np.mean(distances):.6f} AU")
# 可视化
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="拉格朗日点L4稳定性")
fig.savefig("demo_lagrange_l4.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_lagrange_l4.png")
plt.show()
return solver
def demo_random_systems():
"""演示随机系统"""
print("\n" + "=" * 60)
print("随机三体系统演示")
print("=" * 60)
# 创建3个不同的随机系统
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for i, seed in enumerate([42, 123, 456]):
np.random.seed(seed)
# 创建随机配置
particles = ThreeBodyConfig.create_random_config(
position_range=1.5,
velocity_scale=2.0
)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟5年
solver.simulate(total_time=5.0, progress_interval=2000)
# 绘制轨迹
trajectories = solver.get_trajectories()
colors = ['red', 'green', 'blue']
ax = axes[i]
for j, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[j % len(colors)]
ax.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5)
ax.scatter(traj[-1, 0], traj[-1, 1], color=color, s=50, edgecolors='black', linewidth=1)
ax.set_xlabel('X (AU)', fontsize=10)
ax.set_ylabel('Y (AU)', fontsize=10)
ax.set_title(f'随机系统 #{i+1} (种子: {seed})', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
plt.suptitle('随机三体系统示例', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("demo_random_systems.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_random_systems.png")
plt.show()
def demo_conservation_laws():
"""演示守恒定律"""
print("\n" + "=" * 60)
print("守恒定律验证演示")
print("=" * 60)
# 使用8字形轨道理论上应该守恒
particles = ThreeBodyConfig.create_figure8_config()
# 测试不同时间步长下的守恒性
time_steps = [0.01, 0.005, 0.001, 0.0005]
total_time = 1.0
results = []
for dt in time_steps:
print(f"\n测试时间步长: {dt}")
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=total_time, progress_interval=1000)
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
results.append({
'dt': dt,
'momentum_error': momentum_error,
'angular_momentum_error': angular_momentum_error,
'energy_error': energy_error
})
print(f" 动量误差: {momentum_error:.2e}")
print(f" 角动量误差: {angular_momentum_error:.2e}")
print(f" 能量相对误差: {energy_error:.2e}")
# 绘制误差图
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
dts = [r['dt'] for r in results]
# 动量误差
axes[0].loglog(dts, [r['momentum_error'] for r in results], 'o-', linewidth=2, markersize=8)
axes[0].set_xlabel('时间步长 (年)', fontsize=12)
axes[0].set_ylabel('动量误差', fontsize=12)
axes[0].set_title('动量守恒', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, which='both')
# 角动量误差
axes[1].loglog(dts, [r['angular_momentum_error'] for r in results], 's-', linewidth=2, markersize=8, color='green')
axes[1].set_xlabel('时间步长 (年)', fontsize=12)
axes[1].set_ylabel('角动量误差', fontsize=12)
axes[1].set_title('角动量守恒', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, which='both')
# 能量误差
axes[2].loglog(dts, [r['energy_error'] for r in results], '^-', linewidth=2, markersize=8, color='red')
axes[2].set_xlabel('时间步长 (年)', fontsize=12)
axes[2].set_ylabel('能量相对误差', fontsize=12)
axes[2].set_title('能量守恒', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3, which='both')
# 添加参考线(四阶精度)
ref_dt = np.array(dts)
ref_error = 1e-4 * (ref_dt / 0.001)**4
axes[2].loglog(ref_dt, ref_error, 'k--', alpha=0.7, label='四阶精度参考线')
axes[2].legend()
plt.suptitle('守恒定律数值验证 (RK4方法)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig("demo_conservation_laws.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_conservation_laws.png")
plt.show()
def main():
"""主函数"""
print("三体问题求解器演示")
print("=" * 60)
# 运行所有演示
try:
# 演示1: 基本用法
solver1 = demo_basic_usage()
# 演示2: 8字形轨道
solver2 = demo_figure8()
# 演示3: 拉格朗日点
solver3 = demo_lagrange_points()
# 演示4: 随机系统
demo_random_systems()
# 演示5: 守恒定律
demo_conservation_laws()
print("\n" + "=" * 60)
print("所有演示完成!")
print("=" * 60)
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,16 @@
"""
三体问题示例模块
"""
from .figure8 import run_figure8_example, analyze_figure8_stability
from .lagrange import run_lagrange_example, compare_lagrange_points
from .random import run_random_example, run_multiple_random_simulations
__all__ = [
"run_figure8_example",
"analyze_figure8_stability",
"run_lagrange_example",
"compare_lagrange_points",
"run_random_example",
"run_multiple_random_simulations"
]

View File

@@ -0,0 +1,203 @@
"""
8字形轨道示例 - 著名的稳定三体轨道
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
def run_figure8_example():
"""运行8字形轨道示例"""
print("=" * 60)
print("8字形轨道示例")
print("=" * 60)
# 创建8字形轨道配置
particles = ThreeBodyConfig.create_figure8_config()
# 打印配置摘要
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
dt = 0.001 # 时间步长(年)
solver = ThreeBodySolver(particles, dt=dt)
# 模拟10年
total_time = 10.0
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}")
solver.simulate(total_time=total_time, progress_interval=2000)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f"\n守恒定律误差:")
print(f" 动量误差: {momentum_error:.6e}")
print(f" 角动量误差: {angular_momentum_error:.6e}")
print(f" 能量相对误差: {energy_error:.6e}")
# 可视化
print("\n生成可视化图形...")
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
# 创建3D轨迹图
plt.figure(figsize=(14, 10))
# 3D轨迹
ax1 = plt.subplot(2, 2, 1, projection='3d')
trajectories = solver.get_trajectories()
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax1.plot(traj[:, 0], traj[:, 1], traj[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
ax1.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
com = solver.get_center_of_mass()
ax1.scatter(com[0], com[1], com[2],
color='black', marker='x', s=200, label='质心', linewidth=2)
ax1.set_xlabel('X (AU)')
ax1.set_ylabel('Y (AU)')
ax1.set_zlabel('Z (AU)')
ax1.set_title('8字形轨道 - 3D视图')
ax1.legend()
ax1.grid(True, alpha=0.3)
# XY平面投影
ax2 = plt.subplot(2, 2, 2)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5, label=label)
ax2.scatter(traj[-1, 0], traj[-1, 1], color=color, s=100, edgecolors='black', linewidth=1.5)
ax2.scatter(com[0], com[1], color='black', marker='x', s=200, label='质心', linewidth=2)
ax2.set_xlabel('X (AU)')
ax2.set_ylabel('Y (AU)')
ax2.set_title('XY平面投影')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal', adjustable='box')
# XZ平面投影
ax3 = plt.subplot(2, 2, 3)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax3.plot(traj[:, 0], traj[:, 2], color=color, alpha=0.7, linewidth=1.5, label=label)
ax3.scatter(traj[-1, 0], traj[-1, 2], color=color, s=100, edgecolors='black', linewidth=1.5)
ax3.scatter(com[0], com[2], color='black', marker='x', s=200, label='质心', linewidth=2)
ax3.set_xlabel('X (AU)')
ax3.set_ylabel('Z (AU)')
ax3.set_title('XZ平面投影')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_aspect('equal', adjustable='box')
# YZ平面投影
ax4 = plt.subplot(2, 2, 4)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax4.plot(traj[:, 1], traj[:, 2], color=color, alpha=0.7, linewidth=1.5, label=label)
ax4.scatter(traj[-1, 1], traj[-1, 2], color=color, s=100, edgecolors='black', linewidth=1.5)
ax4.scatter(com[1], com[2], color='black', marker='x', s=200, label='质心', linewidth=2)
ax4.set_xlabel('Y (AU)')
ax4.set_ylabel('Z (AU)')
ax4.set_title('YZ平面投影')
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_aspect('equal', adjustable='box')
plt.suptitle('8字形三体轨道', fontsize=16, fontweight='bold')
plt.tight_layout()
# 保存图形
output_file = "figure8_orbit.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n图形已保存到: {output_file}")
# 显示图形
plt.show()
return solver
def analyze_figure8_stability():
"""分析8字形轨道的稳定性"""
print("\n" + "=" * 60)
print("8字形轨道稳定性分析")
print("=" * 60)
# 创建8字形轨道配置
particles = ThreeBodyConfig.create_figure8_config()
# 测试不同时间步长
time_steps = [0.01, 0.005, 0.001, 0.0005]
total_time = 5.0
results = []
for dt in time_steps:
print(f"\n测试时间步长: {dt}")
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=total_time, progress_interval=10000)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
results.append({
'dt': dt,
'momentum_error': momentum_error,
'angular_momentum_error': angular_momentum_error,
'energy_error': energy_error
})
print(f" 能量相对误差: {energy_error:.6e}")
# 绘制误差随步长变化
plt.figure(figsize=(10, 6))
dts = [r['dt'] for r in results]
energy_errors = [r['energy_error'] for r in results]
plt.loglog(dts, energy_errors, 'o-', linewidth=2, markersize=8)
plt.xlabel('时间步长 (年)', fontsize=12)
plt.ylabel('能量相对误差', fontsize=12)
plt.title('8字形轨道数值误差分析', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, which='both')
# 添加参考线(四阶精度)
ref_dt = np.array(dts)
ref_error = 1e-4 * (ref_dt / 0.001)**4
plt.loglog(ref_dt, ref_error, 'r--', alpha=0.7, label='四阶精度参考线')
plt.legend()
plt.tight_layout()
output_file = "figure8_stability_analysis.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n稳定性分析图形已保存到: {output_file}")
plt.show()
if __name__ == "__main__":
# 运行主示例
solver = run_figure8_example()
# 运行稳定性分析(可选)
# analyze_figure8_stability()

View File

@@ -0,0 +1,343 @@
"""
拉格朗日点示例 - 三体问题的稳定点
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
def run_lagrange_example(lagrange_point: int = 4, total_time: float = 100.0):
"""
运行拉格朗日点示例
参数:
lagrange_point: 拉格朗日点编号 (4=L4, 5=L5)
total_time: 总模拟时间(年)
"""
point_name = "L4" if lagrange_point == 4 else "L5"
print("=" * 60)
print(f"拉格朗日点 {point_name} 示例")
print("=" * 60)
# 创建拉格朗日点配置
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=lagrange_point)
# 打印配置摘要
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
dt = 0.001 # 时间步长(年)
solver = ThreeBodySolver(particles, dt=dt)
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}")
print(f"模拟拉格朗日点 {point_name} 的稳定性")
solver.simulate(total_time=total_time, progress_interval=20000)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f"\n守恒定律误差:")
print(f" 动量误差: {momentum_error:.6e}")
print(f" 角动量误差: {angular_momentum_error:.6e}")
print(f" 能量相对误差: {energy_error:.6e}")
# 分析测试质点的轨道稳定性
test_particle = solver.particles[2] # 测试质点是第三个
trajectory = test_particle.get_trajectory()
# 计算与L4/L5点的距离变化
if lagrange_point == 4:
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
else: # L5
lagrange_position = np.array([0.5, -np.sqrt(3)/2, 0.0])
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
time_points = np.arange(len(distances)) * dt
print(f"\n测试质点稳定性分析:")
print(f" 初始距离L{lagrange_point}点: {distances[0]:.6e} AU")
print(f" 最终距离L{lagrange_point}点: {distances[-1]:.6e} AU")
print(f" 最大距离偏差: {np.max(distances):.6e} AU")
print(f" 平均距离偏差: {np.mean(distances):.6e} AU")
# 可视化
print("\n生成可视化图形...")
# 创建图形
fig = plt.figure(figsize=(16, 10))
# 1. XY平面轨迹
ax1 = plt.subplot(2, 3, 1)
# 绘制所有质点的轨迹
trajectories = solver.get_trajectories()
colors = ['gold', 'blue', 'gray']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 只绘制最后一部分轨迹(更清晰)
if len(traj) > 1000:
traj_to_plot = traj[-1000:]
else:
traj_to_plot = traj
ax1.plot(traj_to_plot[:, 0], traj_to_plot[:, 1],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制最终位置
ax1.scatter(traj[-1, 0], traj[-1, 1],
color=color, s=100, edgecolors='black', linewidth=1.5, zorder=5)
# 绘制拉格朗日点位置
ax1.scatter(lagrange_position[0], lagrange_position[1],
color='red', marker='*', s=300, label=f'L{lagrange_point}', zorder=10)
# 绘制等边三角形
triangle_points = [
[0, 0], # 太阳
[1, 0], # 地球
lagrange_position[:2] # L4或L5点
]
triangle_points.append(triangle_points[0]) # 闭合三角形
triangle_points = np.array(triangle_points)
ax1.plot(triangle_points[:, 0], triangle_points[:, 1],
'k--', alpha=0.5, linewidth=1, label='等边三角形')
ax1.set_xlabel('X (AU)', fontsize=12)
ax1.set_ylabel('Y (AU)', fontsize=12)
ax1.set_title(f'拉格朗日点 {point_name} - XY平面', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.set_aspect('equal', adjustable='box')
# 2. 距离随时间变化
ax2 = plt.subplot(2, 3, 2)
ax2.plot(time_points, distances, 'b-', linewidth=2, alpha=0.8)
ax2.set_xlabel('时间 (年)', fontsize=12)
ax2.set_ylabel(f'距离L{lagrange_point}点 (AU)', fontsize=12)
ax2.set_title('测试质点轨道稳定性', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
# 添加平均距离线
mean_distance = np.mean(distances)
ax2.axhline(y=mean_distance, color='r', linestyle='--', alpha=0.7,
label=f'平均距离: {mean_distance:.3e}')
ax2.legend(fontsize=10)
# 3. 相空间图 (x vs vx)
ax3 = plt.subplot(2, 3, 3)
# 计算速度(使用位置差分)
if len(trajectory) > 1:
dt = solver.dt
velocities = np.gradient(trajectory, dt, axis=0)
x_positions = trajectory[:, 0]
x_velocities = velocities[:, 0]
# 使用颜色表示时间
scatter = ax3.scatter(x_positions, x_velocities, c=time_points,
cmap='viridis', alpha=0.7, s=20)
plt.colorbar(scatter, ax=ax3, label='时间 (年)')
ax3.set_xlabel('X 位置 (AU)', fontsize=12)
ax3.set_ylabel('X 速度 (AU/年)', fontsize=12)
ax3.set_title('测试质点相空间 (X维度)', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3)
# 4. 相对位置图(以地球为参考系)
ax4 = plt.subplot(2, 3, 4)
# 计算相对于地球的位置
earth_trajectory = trajectories[1] # 地球是第二个质点
sun_trajectory = trajectories[0] # 太阳是第一个质点
test_trajectory = trajectories[2] # 测试质点是第三个
# 转换为以地球为中心的坐标系
earth_centered_sun = sun_trajectory - earth_trajectory
earth_centered_test = test_trajectory - earth_trajectory
# 只绘制最后一部分
if len(earth_centered_test) > 1000:
earth_centered_test = earth_centered_test[-1000:]
ax4.plot(earth_centered_test[:, 0], earth_centered_test[:, 1],
'gray', alpha=0.7, linewidth=1.5, label='测试质点')
ax4.scatter(0, 0, color='blue', s=200, label='地球', edgecolors='black', linewidth=1.5)
ax4.scatter(earth_centered_sun[-1, 0], earth_centered_sun[-1, 1],
color='gold', s=200, label='太阳', edgecolors='black', linewidth=1.5)
# 绘制理论L4/L5点位置
if lagrange_point == 4:
l_point_relative = np.array([-0.5, np.sqrt(3)/2])
else: # L5
l_point_relative = np.array([-0.5, -np.sqrt(3)/2])
ax4.scatter(l_point_relative[0], l_point_relative[1],
color='red', marker='*', s=300, label=f'L{lagrange_point}', zorder=10)
ax4.set_xlabel('相对X位置 (AU)', fontsize=12)
ax4.set_ylabel('相对Y位置 (AU)', fontsize=12)
ax4.set_title('以地球为参考系', fontsize=14, fontweight='bold')
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)
ax4.set_aspect('equal', adjustable='box')
# 5. 能量随时间变化(简化)
ax5 = plt.subplot(2, 3, 5)
# 计算相对能量变化(简化)
# 在实际实现中,需要记录能量历史
time_array = np.linspace(0, total_time, len(distances))
# 使用距离变化作为能量变化的代理
energy_proxy = distances / distances[0]
ax5.plot(time_array, energy_proxy, 'g-', linewidth=2, alpha=0.8)
ax5.set_xlabel('时间 (年)', fontsize=12)
ax5.set_ylabel('相对能量变化', fontsize=12)
ax5.set_title('轨道能量变化', fontsize=14, fontweight='bold')
ax5.grid(True, alpha=0.3)
ax5.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='初始能量')
ax5.legend(fontsize=10)
# 6. 3D视图
ax6 = plt.subplot(2, 3, 6, projection='3d')
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 只绘制最后一部分轨迹
if len(traj) > 1000:
traj_to_plot = traj[-1000:]
else:
traj_to_plot = traj
ax6.plot(traj_to_plot[:, 0], traj_to_plot[:, 1], traj_to_plot[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制最终位置
ax6.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
color=color, s=100, edgecolors='black', linewidth=1.5, zorder=5)
ax6.scatter(lagrange_position[0], lagrange_position[1], lagrange_position[2],
color='red', marker='*', s=300, label=f'L{lagrange_point}', zorder=10)
ax6.set_xlabel('X (AU)', fontsize=10)
ax6.set_ylabel('Y (AU)', fontsize=10)
ax6.set_zlabel('Z (AU)', fontsize=10)
ax6.set_title('3D视图', fontsize=14, fontweight='bold')
ax6.legend(fontsize=9, loc='upper left')
ax6.grid(True, alpha=0.3)
plt.suptitle(f'拉格朗日点 {point_name} 稳定性分析', fontsize=16, fontweight='bold')
plt.tight_layout()
# 保存图形
output_file = f"lagrange_point_{point_name}.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n图形已保存到: {output_file}")
# 显示图形
plt.show()
return solver
def compare_lagrange_points():
"""比较L4和L5点的稳定性"""
print("\n" + "=" * 60)
print("拉格朗日点L4和L5稳定性比较")
print("=" * 60)
total_time = 50.0
dt = 0.001
results = []
for lagrange_point in [4, 5]:
point_name = f"L{lagrange_point}"
print(f"\n模拟 {point_name} 点...")
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=lagrange_point)
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=total_time, progress_interval=25000)
# 分析测试质点的轨道稳定性
test_particle = solver.particles[2]
trajectory = test_particle.get_trajectory()
if lagrange_point == 4:
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
else: # L5
lagrange_position = np.array([0.5, -np.sqrt(3)/2, 0.0])
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
results.append({
'point': point_name,
'max_distance': np.max(distances),
'mean_distance': np.mean(distances),
'std_distance': np.std(distances),
'final_distance': distances[-1]
})
print(f" {point_name} 最大距离偏差: {np.max(distances):.6e} AU")
print(f" {point_name} 平均距离偏差: {np.mean(distances):.6e} AU")
# 绘制比较图
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
points = [r['point'] for r in results]
max_distances = [r['max_distance'] for r in results]
mean_distances = [r['mean_distance'] for r in results]
x = np.arange(len(points))
width = 0.35
axes[0].bar(x - width/2, max_distances, width, label='最大偏差', color='lightcoral')
axes[0].bar(x + width/2, mean_distances, width, label='平均偏差', color='lightblue')
axes[0].set_xlabel('拉格朗日点', fontsize=12)
axes[0].set_ylabel('距离偏差 (AU)', fontsize=12)
axes[0].set_title('L4和L5点稳定性比较', fontsize=14, fontweight='bold')
axes[0].set_xticks(x)
axes[0].set_xticklabels(points)
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')
# 最终位置偏差
final_distances = [r['final_distance'] for r in results]
axes[1].bar(points, final_distances, color=['lightgreen', 'lightblue'])
axes[1].set_xlabel('拉格朗日点', fontsize=12)
axes[1].set_ylabel('最终距离偏差 (AU)', fontsize=12)
axes[1].set_title('最终位置稳定性', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')
plt.tight_layout()
output_file = "lagrange_points_comparison.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n比较图形已保存到: {output_file}")
plt.show()
if __name__ == "__main__":
# 运行L4点示例
print("运行拉格朗日点L4示例...")
solver_l4 = run_lagrange_example(lagrange_point=4, total_time=50.0)
# 运行L5点示例
print("\n" + "="*60)
print("运行拉格朗日点L5示例...")
solver_l5 = run_lagrange_example(lagrange_point=5, total_time=50.0)
# 比较L4和L5
compare_lagrange_points()

View File

@@ -0,0 +1,426 @@
"""
随机初始条件示例 - 探索不同的三体系统
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
def run_random_example(seed: int = 42, total_time: float = 20.0):
"""
运行随机初始条件示例
参数:
seed: 随机种子
total_time: 总模拟时间(年)
"""
np.random.seed(seed)
print("=" * 60)
print(f"随机初始条件示例 (种子: {seed})")
print("=" * 60)
# 创建随机配置
particles = ThreeBodyConfig.create_random_config(
masses=None, # 使用随机质量
position_range=2.0, # 位置范围 ±2 AU
velocity_scale=3.0 # 速度缩放因子
)
# 打印配置摘要
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
dt = 0.001 # 时间步长(年)
solver = ThreeBodySolver(particles, dt=dt)
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}")
solver.simulate(total_time=total_time, progress_interval=2000)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f"\n守恒定律误差:")
print(f" 动量误差: {momentum_error:.6e}")
print(f" 角动量误差: {angular_momentum_error:.6e}")
print(f" 能量相对误差: {energy_error:.6e}")
# 分析系统行为
analyze_system_behavior(solver)
# 可视化
print("\n生成可视化图形...")
visualize_random_system(solver, seed)
return solver
def analyze_system_behavior(solver: ThreeBodySolver):
"""分析三体系统的行为"""
print("\n" + "-" * 40)
print("系统行为分析")
print("-" * 40)
trajectories = solver.get_trajectories()
# 计算每个质点的运动范围
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
pos_range = np.ptp(traj, axis=0) # 位置范围 (max - min)
avg_speed = np.mean(np.linalg.norm(np.gradient(traj, solver.dt, axis=0), axis=1))
print(f"\n质点 {i+1} ({particle.name}):")
print(f" 质量: {particle.mass:.4f} M_sun")
print(f" 位置范围: X={pos_range[0]:.3f}, Y={pos_range[1]:.3f}, Z={pos_range[2]:.3f} AU")
print(f" 平均速度: {avg_speed:.3f} AU/年")
# 计算质点之间的最小距离
min_distances = []
for i in range(3):
for j in range(i+1, 3):
traj_i = trajectories[i]
traj_j = trajectories[j]
distances = np.linalg.norm(traj_i - traj_j, axis=1)
min_dist = np.min(distances)
min_distances.append((i, j, min_dist))
print(f"\n质点间最小距离:")
for i, j, min_dist in min_distances:
print(f" 质点{i+1}-质点{j+1}: {min_dist:.4f} AU")
# 检查是否有碰撞或近距离接近
collision_threshold = 0.1 # AU
close_encounters = [(i, j, d) for i, j, d in min_distances if d < collision_threshold]
if close_encounters:
print(f"\n警告: 检测到近距离接近!")
for i, j, d in close_encounters:
print(f" 质点{i+1}和质点{j+1}的最小距离: {d:.4f} AU < {collision_threshold} AU")
else:
print(f"\n系统稳定: 所有质点间距离都大于 {collision_threshold} AU")
# 计算系统质心运动
com_trajectory = []
for t in range(len(trajectories[0])):
com = np.zeros(3)
total_mass = 0.0
for i, traj in enumerate(trajectories):
com += solver.particles[i].mass * traj[t]
total_mass += solver.particles[i].mass
com_trajectory.append(com / total_mass)
com_trajectory = np.array(com_trajectory)
com_range = np.ptp(com_trajectory, axis=0)
print(f"\n系统质心运动范围: X={com_range[0]:.4f}, Y={com_range[1]:.4f}, Z={com_range[2]:.4f} AU")
def visualize_random_system(solver: ThreeBodySolver, seed: int):
"""可视化随机三体系统"""
trajectories = solver.get_trajectories()
# 创建图形
fig = plt.figure(figsize=(16, 12))
# 1. 3D轨迹图
ax1 = plt.subplot(2, 3, 1, projection='3d')
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax1.plot(traj[:, 0], traj[:, 1], traj[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
ax1.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心轨迹
com_trajectory = []
for t in range(len(trajectories[0])):
com = np.zeros(3)
total_mass = 0.0
for i, traj in enumerate(trajectories):
com += solver.particles[i].mass * traj[t]
total_mass += solver.particles[i].mass
com_trajectory.append(com / total_mass)
com_trajectory = np.array(com_trajectory)
ax1.plot(com_trajectory[:, 0], com_trajectory[:, 1], com_trajectory[:, 2],
'k--', alpha=0.5, linewidth=1, label='质心轨迹')
ax1.scatter(com_trajectory[-1, 0], com_trajectory[-1, 1], com_trajectory[-1, 2],
color='black', marker='x', s=200, label='质心', linewidth=2)
ax1.set_xlabel('X (AU)', fontsize=12)
ax1.set_ylabel('Y (AU)', fontsize=12)
ax1.set_zlabel('Z (AU)', fontsize=12)
ax1.set_title('3D轨迹图', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
# 2. XY平面投影
ax2 = plt.subplot(2, 3, 2)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5, label=label)
ax2.scatter(traj[-1, 0], traj[-1, 1], color=color, s=100, edgecolors='black', linewidth=1.5)
ax2.plot(com_trajectory[:, 0], com_trajectory[:, 1], 'k--', alpha=0.5, linewidth=1)
ax2.scatter(com_trajectory[-1, 0], com_trajectory[-1, 1],
color='black', marker='x', s=200, linewidth=2)
ax2.set_xlabel('X (AU)', fontsize=12)
ax2.set_ylabel('Y (AU)', fontsize=12)
ax2.set_title('XY平面投影', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal', adjustable='box')
# 3. 距离随时间变化
ax3 = plt.subplot(2, 3, 3)
time_points = np.arange(len(trajectories[0])) * solver.dt
# 计算所有质点对之间的距离
distances = []
labels = []
for i in range(3):
for j in range(i+1, 3):
dist = np.linalg.norm(trajectories[i] - trajectories[j], axis=1)
distances.append(dist)
labels.append(f"质点{i+1}-质点{j+1}")
for dist, label in zip(distances, labels):
ax3.plot(time_points, dist, linewidth=1.5, alpha=0.8, label=label)
ax3.set_xlabel('时间 (年)', fontsize=12)
ax3.set_ylabel('距离 (AU)', fontsize=12)
ax3.set_title('质点间距离变化', fontsize=14, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
# 4. 速度大小随时间变化
ax4 = plt.subplot(2, 3, 4)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 计算速度大小(使用位置差分)
if len(traj) > 1:
velocities = np.gradient(traj, solver.dt, axis=0)
speed = np.linalg.norm(velocities, axis=1)
ax4.plot(time_points, speed, color=color, linewidth=1.5, alpha=0.8, label=label)
ax4.set_xlabel('时间 (年)', fontsize=12)
ax4.set_ylabel('速度大小 (AU/年)', fontsize=12)
ax4.set_title('质点速度变化', fontsize=14, fontweight='bold')
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)
# 5. 能量分布饼图
ax5 = plt.subplot(2, 3, 5)
# 计算每个质点的动能和势能
kinetic_energies = []
potential_energies = []
for i, particle in enumerate(solver.particles):
# 动能
v_squared = np.sum(particle.velocity**2)
kinetic_energy = 0.5 * particle.mass * v_squared
kinetic_energies.append(kinetic_energy)
# 势能(与其他质点的相互作用)
potential_energy = 0.0
for j, other in enumerate(solver.particles):
if i != j:
r_vec = other.position - particle.position
r = np.linalg.norm(r_vec)
if r > 1e-10:
potential_energy -= ThreeBodySolver.G * particle.mass * other.mass / r
potential_energies.append(potential_energy)
# 只考虑势能的一半(每对质点计算了两次)
potential_energies = [pe/2 for pe in potential_energies]
labels = [f"质点{i+1}" for i in range(3)]
colors_pie = ['lightcoral', 'lightgreen', 'lightblue']
ax5.pie(kinetic_energies, labels=labels, autopct='%1.1f%%',
colors=colors_pie, startangle=90)
ax5.set_title('动能分布', fontsize=14, fontweight='bold')
# 6. 相空间图(所有质点)
ax6 = plt.subplot(2, 3, 6)
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
if len(traj) > 1:
velocities = np.gradient(traj, solver.dt, axis=0)
x_positions = traj[:, 0]
x_velocities = velocities[:, 0]
# 使用颜色表示时间
scatter = ax6.scatter(x_positions, x_velocities, c=time_points,
cmap='viridis', alpha=0.6, s=10, label=label)
plt.colorbar(scatter, ax=ax6, label='时间 (年)')
ax6.set_xlabel('X 位置 (AU)', fontsize=12)
ax6.set_ylabel('X 速度 (AU/年)', fontsize=12)
ax6.set_title('相空间图 (X维度)', fontsize=14, fontweight='bold')
ax6.legend(fontsize=10)
ax6.grid(True, alpha=0.3)
plt.suptitle(f'随机三体系统 (种子: {seed})', fontsize=16, fontweight='bold')
plt.tight_layout()
# 保存图形
output_file = f"random_system_seed_{seed}.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n图形已保存到: {output_file}")
# 显示图形
plt.show()
def run_multiple_random_simulations(n_simulations: int = 5, total_time: float = 10.0):
"""运行多个随机模拟并比较结果"""
print("=" * 60)
print(f"运行 {n_simulations} 个随机三体系统模拟")
print("=" * 60)
results = []
for sim_idx in range(n_simulations):
seed = 100 + sim_idx # 不同的随机种子
np.random.seed(seed)
print(f"\n模拟 {sim_idx+1}/{n_simulations} (种子: {seed})")
# 创建随机配置
particles = ThreeBodyConfig.create_random_config(
masses=None,
position_range=2.0,
velocity_scale=2.0 + np.random.random() * 2.0 # 随机速度缩放
)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
solver.simulate(total_time=total_time, progress_interval=5000)
# 分析结果
trajectories = solver.get_trajectories()
# 计算系统特性
final_distances = []
for i in range(3):
for j in range(i+1, 3):
final_dist = np.linalg.norm(trajectories[i][-1] - trajectories[j][-1])
final_distances.append(final_dist)
avg_final_distance = np.mean(final_distances)
std_final_distance = np.std(final_distances)
# 计算质心移动距离
initial_com = solver.get_center_of_mass()
# 需要重新计算初始质心
total_mass = sum(p.mass for p in particles)
initial_com = np.zeros(3)
for p in particles:
initial_com += p.mass * p.position
initial_com /= total_mass
final_com = solver.get_center_of_mass()
com_movement = np.linalg.norm(final_com - initial_com)
results.append({
'seed': seed,
'avg_final_distance': avg_final_distance,
'std_final_distance': std_final_distance,
'com_movement': com_movement,
'energy_error': solver.get_conservation_errors()[2]
})
print(f" 平均最终距离: {avg_final_distance:.3f} AU")
print(f" 质心移动: {com_movement:.3f} AU")
print(f" 能量相对误差: {solver.get_conservation_errors()[2]:.6e}")
# 绘制比较图
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
seeds = [r['seed'] for r in results]
avg_distances = [r['avg_final_distance'] for r in results]
com_movements = [r['com_movement'] for r in results]
energy_errors = [r['energy_error'] for r in results]
# 平均最终距离
axes[0, 0].bar(range(n_simulations), avg_distances, color='skyblue', edgecolor='black')
axes[0, 0].set_xlabel('模拟编号', fontsize=12)
axes[0, 0].set_ylabel('平均最终距离 (AU)', fontsize=12)
axes[0, 0].set_title('质点间平均距离', fontsize=14, fontweight='bold')
axes[0, 0].set_xticks(range(n_simulations))
axes[0, 0].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
axes[0, 0].grid(True, alpha=0.3, axis='y')
# 质心移动
axes[0, 1].bar(range(n_simulations), com_movements, color='lightgreen', edgecolor='black')
axes[0, 1].set_xlabel('模拟编号', fontsize=12)
axes[0, 1].set_ylabel('质心移动距离 (AU)', fontsize=12)
axes[0, 1].set_title('系统质心移动', fontsize=14, fontweight='bold')
axes[0, 1].set_xticks(range(n_simulations))
axes[0, 1].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
axes[0, 1].grid(True, alpha=0.3, axis='y')
# 能量误差
axes[1, 0].bar(range(n_simulations), energy_errors, color='lightcoral', edgecolor='black')
axes[1, 0].set_xlabel('模拟编号', fontsize=12)
axes[1, 0].set_ylabel('能量相对误差', fontsize=12)
axes[1, 0].set_title('数值积分误差', fontsize=14, fontweight='bold')
axes[1, 0].set_xticks(range(n_simulations))
axes[1, 0].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3, axis='y')
# 散点图:质心移动 vs 平均距离
axes[1, 1].scatter(avg_distances, com_movements, s=100, c=range(n_simulations),
cmap='viridis', edgecolors='black', alpha=0.8)
axes[1, 1].set_xlabel('平均最终距离 (AU)', fontsize=12)
axes[1, 1].set_ylabel('质心移动距离 (AU)', fontsize=12)
axes[1, 1].set_title('系统稳定性关系', fontsize=14, fontweight='bold')
# 添加标签
for i, (x, y) in enumerate(zip(avg_distances, com_movements)):
axes[1, 1].annotate(f"#{i+1}", (x, y), textcoords="offset points",
xytext=(0, 10), ha='center', fontsize=9)
axes[1, 1].grid(True, alpha=0.3)
plt.suptitle(f'{n_simulations}个随机三体系统模拟比较', fontsize=16, fontweight='bold')
plt.tight_layout()
output_file = "multiple_random_simulations.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n比较图形已保存到: {output_file}")
plt.show()
return results
if __name__ == "__main__":
# 运行单个随机示例
print("运行单个随机三体系统示例...")
solver = run_random_example(seed=42, total_time=15.0)
# 运行多个随机模拟(可选)
# print("\n" + "="*60)
# print("运行多个随机模拟比较...")
# results = run_multiple_random_simulations(n_simulations=5, total_time=5.0)

View File

@@ -0,0 +1,119 @@
"""
数值积分器模块提供RK4方法求解微分方程
"""
import numpy as np
from typing import Callable, Tuple, List
from .particle import Particle
class RK4Integrator:
"""四阶龙格-库塔法积分器"""
def __init__(self, dt: float = 0.001):
"""
初始化积分器
参数:
dt: 时间步长(年)
"""
self.dt = dt
def step(self, particles: List[Particle],
acceleration_func: Callable[[List[Particle]], List[np.ndarray]]) -> List[Particle]:
"""
执行一个时间步的积分
参数:
particles: 质点列表
acceleration_func: 计算加速度的函数
返回:
更新后的质点列表
"""
# 保存当前状态
current_positions = [p.position.copy() for p in particles]
current_velocities = [p.velocity.copy() for p in particles]
# RK4步骤1: k1
k1_v = acceleration_func(particles)
k1_r = current_velocities
# 临时更新位置和速度用于k2计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + 0.5 * self.dt * k1_r[i]
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k1_v[i]
temp_particles.append(temp_p)
# RK4步骤2: k2
k2_v = acceleration_func(temp_particles)
k2_r = [p.velocity for p in temp_particles]
# 临时更新位置和速度用于k3计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + 0.5 * self.dt * k2_r[i]
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k2_v[i]
temp_particles.append(temp_p)
# RK4步骤3: k3
k3_v = acceleration_func(temp_particles)
k3_r = [p.velocity for p in temp_particles]
# 临时更新位置和速度用于k4计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + self.dt * k3_r[i]
temp_p.velocity = current_velocities[i] + self.dt * k3_v[i]
temp_particles.append(temp_p)
# RK4步骤4: k4
k4_v = acceleration_func(temp_particles)
k4_r = [p.velocity for p in temp_particles]
# 组合所有k值计算最终更新
new_particles = []
for i, p in enumerate(particles):
# 计算新的速度
new_velocity = (current_velocities[i] +
self.dt / 6.0 * (k1_v[i] + 2*k2_v[i] + 2*k3_v[i] + k4_v[i]))
# 计算新的位置
new_position = (current_positions[i] +
self.dt / 6.0 * (k1_r[i] + 2*k2_r[i] + 2*k3_r[i] + k4_r[i]))
# 创建新的质点对象
new_p = p.copy()
new_p.update(new_position, new_velocity)
new_particles.append(new_p)
return new_particles
def integrate(self, particles: List[Particle],
acceleration_func: Callable[[List[Particle]], List[np.ndarray]],
steps: int) -> List[Particle]:
"""
执行多步积分
参数:
particles: 初始质点列表
acceleration_func: 计算加速度的函数
steps: 积分步数
返回:
积分结束后的质点列表
"""
current_particles = particles
for step in range(steps):
current_particles = self.step(current_particles, acceleration_func)
# 每1000步打印进度
if step % 1000 == 0:
print(f"积分进度: {step}/{steps}")
return current_particles

View File

@@ -0,0 +1,64 @@
"""
质点类,表示三体问题中的一个天体
"""
import numpy as np
from typing import Tuple, List
class Particle:
"""表示一个质点的类"""
def __init__(self, mass: float, position: np.ndarray, velocity: np.ndarray,
name: str = "", color: str = None):
"""
初始化质点
参数:
mass: 质量(太阳质量)
position: 位置向量 (x, y, z) [AU]
velocity: 速度向量 (vx, vy, vz) [AU/年]
name: 质点名称
color: 可视化颜色
"""
self.mass = mass
self.position = np.array(position, dtype=np.float64)
self.velocity = np.array(velocity, dtype=np.float64)
self.name = name
self.color = color
# 存储轨迹历史
self.position_history = []
self.velocity_history = []
def update(self, new_position: np.ndarray, new_velocity: np.ndarray):
"""更新质点的状态并记录历史"""
self.position_history.append(self.position.copy())
self.velocity_history.append(self.velocity.copy())
self.position = new_position
self.velocity = new_velocity
def get_trajectory(self) -> np.ndarray:
"""获取轨迹历史"""
if not self.position_history:
return np.array([self.position])
return np.array(self.position_history + [self.position])
def get_energy(self) -> float:
"""计算质点的动能"""
v_squared = np.sum(self.velocity**2)
return 0.5 * self.mass * v_squared
def __repr__(self) -> str:
return (f"Particle(name='{self.name}', mass={self.mass:.3f}, "
f"position={self.position}, velocity={self.velocity})")
def copy(self) -> 'Particle':
"""创建质点的副本"""
return Particle(
mass=self.mass,
position=self.position.copy(),
velocity=self.velocity.copy(),
name=self.name,
color=self.color
)

View File

@@ -0,0 +1,6 @@
# 三体问题求解器依赖
numpy>=1.21.0
matplotlib>=3.5.0
# 测试依赖(可选)
pytest>=7.0.0

View File

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
三体问题求解器 - 快速示例
运行此文件以查看三体问题的基本模拟
"""
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
# 添加当前目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
def simple_example():
"""简单示例:自定义三体系统"""
print("=" * 60)
print("三体问题求解器 - 简单示例")
print("=" * 60)
# 创建三个质点
print("\n1. 创建三个质点...")
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0], name="Star A", color="red"),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0], name="Star B", color="green"),
Particle(mass=0.5, position=[0.0, 1.0, 0.0], velocity=[-0.3, 0.0, 0.0], name="Star C", color="blue")
]
for i, p in enumerate(particles):
print(f" 质点 {i+1}: {p.name}, 质量: {p.mass:.2f}, 位置: [{p.position[0]:.2f}, {p.position[1]:.2f}, {p.position[2]:.2f}]")
# 创建求解器
print("\n2. 创建求解器...")
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟2年
print("\n3. 模拟2年运动...")
print(" 开始积分...")
solver.simulate(total_time=2.0, progress_interval=500)
# 分析结果
print("\n4. 分析结果...")
trajectories = solver.get_trajectories()
print(f" 模拟完成! 生成了 {len(trajectories[0])} 个轨迹点")
com = solver.get_center_of_mass()
print(f" 系统质心位置: [{com[0]:.4f}, {com[1]:.4f}, {com[2]:.4f}] AU")
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f" 守恒定律误差:")
print(f" - 动量误差: {momentum_error:.2e}")
print(f" - 角动量误差: {angular_momentum_error:.2e}")
print(f" - 能量相对误差: {energy_error:.2e}")
# 可视化
print("\n5. 生成可视化图形...")
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
# 创建3D轨迹图
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="三体系统运动轨迹")
# 保存图形
output_file = "simple_example_3d.png"
visualizer.save_figure(output_file, dpi=200)
print(f" 3D轨迹图已保存到: {output_file}")
# 创建2D投影图
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
output_file_2d = "simple_example_2d.png"
fig.savefig(output_file_2d, dpi=200, bbox_inches='tight')
print(f" 2D投影图已保存到: {output_file_2d}")
print("\n6. 显示图形...")
plt.show()
return solver
def figure8_example():
"""8字形轨道示例"""
print("\n" + "=" * 60)
print("8字形轨道示例")
print("=" * 60)
# 使用预置的8字形轨道配置
particles = ThreeBodyConfig.create_figure8_config()
# 打印配置信息
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟5年
print("\n模拟5年8字形轨道...")
solver.simulate(total_time=5.0, progress_interval=1000)
# 可视化
visualizer = ThreeBodyVisualizer(figsize=(10, 8))
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="8字形三体轨道")
output_file = "figure8_example.png"
visualizer.save_figure(output_file, dpi=200)
print(f"\n图形已保存到: {output_file}")
plt.show()
return solver
def quick_test():
"""快速测试所有模块"""
print("=" * 60)
print("快速测试所有模块")
print("=" * 60)
# 测试1: 创建质点
print("\n1. 测试质点创建...")
p1 = Particle(mass=1.0, position=[1, 0, 0], velocity=[0, 1, 0], name="Test Star")
print(f" 创建质点: {p1.name}, 质量: {p1.mass}, 位置: {p1.position}")
# 测试2: 创建配置
print("\n2. 测试配置创建...")
particles = ThreeBodyConfig.create_figure8_config()
print(f" 创建了 {len(particles)} 个质点")
# 测试3: 创建求解器
print("\n3. 测试求解器创建...")
solver = ThreeBodySolver(particles, dt=0.01)
print(f" 求解器创建成功,时间步长: {solver.dt}")
# 测试4: 单步积分
print("\n4. 测试单步积分...")
initial_positions = [p.position.copy() for p in particles]
solver.step()
for i, (old_pos, particle) in enumerate(zip(initial_positions, solver.particles)):
moved = not np.allclose(old_pos, particle.position)
print(f" 质点 {i+1} 位置变化: {'' if moved else ''}")
# 测试5: 质心计算
print("\n5. 测试质心计算...")
com = solver.get_center_of_mass()
print(f" 系统质心: [{com[0]:.4f}, {com[1]:.4f}, {com[2]:.4f}]")
# 测试6: 能量计算
print("\n6. 测试能量计算...")
energy = solver._calculate_total_energy()
print(f" 系统总能量: {energy:.6e}")
print("\n所有测试通过!")
return True
def main():
"""主函数"""
print("三体问题求解器示例")
print("=" * 60)
print("选择要运行的示例:")
print(" 1. 简单示例 (自定义三体系统)")
print(" 2. 8字形轨道示例")
print(" 3. 快速测试所有模块")
print(" 4. 全部运行")
try:
choice = input("\n请输入选择 (1-4): ").strip()
if choice == "1":
simple_example()
elif choice == "2":
figure8_example()
elif choice == "3":
quick_test()
elif choice == "4":
quick_test()
simple_example()
figure8_example()
else:
print("无效选择,运行简单示例...")
simple_example()
except KeyboardInterrupt:
print("\n\n用户中断")
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print("示例运行完成!")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,53 @@
from setuptools import setup, find_packages
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
with open("requirements.txt", "r", encoding="utf-8") as fh:
requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
setup(
name="three-body-problem",
version="1.0.0",
author="ThreeBodyProblem Team",
author_email="threebody@example.com",
description="A pure Python solver for the three-body problem",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/three-body-problem",
packages=find_packages(),
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Astronomy",
"Topic :: Scientific/Engineering :: Physics",
],
python_requires=">=3.7",
install_requires=requirements,
extras_require={
"dev": [
"pytest>=7.0.0",
"black>=22.0.0",
"flake8>=5.0.0",
],
"docs": [
"sphinx>=5.0.0",
"sphinx-rtd-theme>=1.0.0",
],
},
entry_points={
"console_scripts": [
"three-body-demo=three_body_problem.demo:main",
],
},
)

View File

@@ -0,0 +1,196 @@
"""
三体问题求解器主模块
"""
import numpy as np
from typing import List, Tuple, Optional, Callable
from .particle import Particle
from .integrator import RK4Integrator
class ThreeBodySolver:
"""三体问题求解器"""
# 万有引力常数 (AU^3 / (M_sun * year^2))
G = 4 * np.pi**2 # 在天文单位制中
def __init__(self, particles: List[Particle], dt: float = 0.001):
"""
初始化三体问题求解器
参数:
particles: 三个质点的列表
dt: 时间步长(年)
"""
if len(particles) != 3:
raise ValueError("三体问题需要恰好3个质点")
self.particles = particles
self.dt = dt
self.integrator = RK4Integrator(dt)
self.time = 0.0
# 验证总动量和角动量守恒(可选)
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
"""
计算所有质点的加速度
参数:
particles: 质点列表
返回:
加速度列表
"""
accelerations = [np.zeros(3) for _ in range(3)]
# 计算每对质点之间的引力
for i in range(3):
for j in range(3):
if i != j:
# 计算相对位置向量
r_vec = particles[j].position - particles[i].position
r = np.linalg.norm(r_vec)
# 避免除以零
if r < 1e-10:
r = 1e-10
# 计算加速度 (F = ma -> a = F/m)
acceleration = self.G * particles[j].mass * r_vec / (r**3)
accelerations[i] += acceleration
return accelerations
def _calculate_total_momentum(self) -> np.ndarray:
"""计算系统总动量"""
total_momentum = np.zeros(3)
for p in self.particles:
total_momentum += p.mass * p.velocity
return total_momentum
def _calculate_total_angular_momentum(self) -> np.ndarray:
"""计算系统总角动量"""
total_angular_momentum = np.zeros(3)
for p in self.particles:
# 角动量 L = r × p = r × (m*v)
angular_momentum = np.cross(p.position, p.mass * p.velocity)
total_angular_momentum += angular_momentum
return total_angular_momentum
def _calculate_total_energy(self) -> float:
"""计算系统总能量(动能+势能)"""
kinetic_energy = 0.0
potential_energy = 0.0
# 计算动能
for p in self.particles:
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
# 计算势能
for i in range(3):
for j in range(i+1, 3):
r_vec = self.particles[j].position - self.particles[i].position
r = np.linalg.norm(r_vec)
if r > 1e-10: # 避免除以零
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
return kinetic_energy + potential_energy
def step(self) -> List[Particle]:
"""
执行一个时间步的积分
返回:
更新后的质点列表
"""
# 计算加速度
acceleration_func = lambda p: self._calculate_accelerations(p)
# 执行积分
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
return self.particles
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
"""
模拟三体运动
参数:
total_time: 总模拟时间(年)
progress_interval: 进度打印间隔步数
返回:
模拟结束后的质点列表
"""
steps = int(total_time / self.dt)
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}")
print(f"初始总能量: {self._calculate_total_energy():.6e}")
acceleration_func = lambda p: self._calculate_accelerations(p)
for step in range(steps):
self.particles = self.integrator.step(self.particles, acceleration_func)
self.time += self.dt
# 打印进度
if step % progress_interval == 0:
energy = self._calculate_total_energy()
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
final_energy = self._calculate_total_energy()
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
return self.particles
def get_trajectories(self) -> List[np.ndarray]:
"""获取所有质点的轨迹"""
return [p.get_trajectory() for p in self.particles]
def get_center_of_mass(self) -> np.ndarray:
"""计算系统质心位置"""
total_mass = sum(p.mass for p in self.particles)
com = np.zeros(3)
for p in self.particles:
com += p.mass * p.position
return com / total_mass if total_mass > 0 else np.zeros(3)
def get_conservation_errors(self) -> Tuple[float, float, float]:
"""
计算守恒定律的误差
返回:
(动量误差, 角动量误差, 能量误差)
"""
# 动量误差
current_momentum = self._calculate_total_momentum()
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
# 角动量误差
current_angular_momentum = self._calculate_total_angular_momentum()
angular_momentum_error = np.linalg.norm(
current_angular_momentum - self.initial_angular_momentum
)
# 能量误差(相对误差)
initial_energy = self._calculate_total_energy() # 重新计算初始能量
current_energy = self._calculate_total_energy()
if abs(initial_energy) > 1e-10:
energy_error = abs((current_energy - initial_energy) / initial_energy)
else:
energy_error = abs(current_energy - initial_energy)
return momentum_error, angular_momentum_error, energy_error
def reset(self):
"""重置求解器状态(清除历史记录)"""
for p in self.particles:
p.position_history.clear()
p.velocity_history.clear()
self.time = 0.0
self.initial_total_momentum = self._calculate_total_momentum()
self.initial_angular_momentum = self._calculate_total_angular_momentum()

View File

@@ -0,0 +1,3 @@
"""
三体问题测试模块
"""

View File

@@ -0,0 +1,385 @@
"""
三体问题求解器测试
"""
import numpy as np
import pytest
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig
class TestParticle:
"""测试质点类"""
def test_particle_creation(self):
"""测试质点创建"""
p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test")
assert p.mass == 1.0
assert np.allclose(p.position, [1.0, 2.0, 3.0])
assert np.allclose(p.velocity, [0.1, 0.2, 0.3])
assert p.name == "Test"
assert len(p.position_history) == 0
assert len(p.velocity_history) == 0
def test_particle_update(self):
"""测试质点状态更新"""
p = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
new_position = np.array([1.0, 2.0, 3.0])
new_velocity = np.array([0.1, 0.2, 0.3])
p.update(new_position, new_velocity)
assert np.allclose(p.position, new_position)
assert np.allclose(p.velocity, new_velocity)
assert len(p.position_history) == 1
assert len(p.velocity_history) == 1
assert np.allclose(p.position_history[0], [0.0, 0.0, 0.0])
assert np.allclose(p.velocity_history[0], [0.0, 0.0, 0.0])
def test_particle_energy(self):
"""测试质点动能计算"""
p = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0])
# 动能 = 0.5 * m * v^2
expected_energy = 0.5 * 2.0 * (1.0**2 + 2.0**2 + 3.0**2)
calculated_energy = p.get_energy()
assert np.isclose(calculated_energy, expected_energy)
def test_particle_copy(self):
"""测试质点复制"""
p1 = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Original")
p2 = p1.copy()
# 检查值相等
assert p2.mass == p1.mass
assert np.allclose(p2.position, p1.position)
assert np.allclose(p2.velocity, p1.velocity)
assert p2.name == p1.name
# 检查是深拷贝
p2.position[0] = 100.0
assert p1.position[0] == 1.0 # 原始对象不应改变
class TestThreeBodySolver:
"""测试三体问题求解器"""
def test_solver_creation(self):
"""测试求解器创建"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
assert len(solver.particles) == 3
assert solver.dt == 0.001
assert solver.time == 0.0
def test_solver_wrong_number_of_particles(self):
"""测试错误数量的质点"""
particles = [
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
]
with pytest.raises(ValueError, match="三体问题需要恰好3个质点"):
ThreeBodySolver(particles, dt=0.001)
def test_acceleration_calculation(self):
"""测试加速度计算"""
# 创建简单的测试配置:三个质点在等边三角形顶点
particles = [
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
# 计算加速度
accelerations = solver._calculate_accelerations(particles)
# 检查加速度形状
assert len(accelerations) == 3
for acc in accelerations:
assert acc.shape == (3,)
# 由于对称性,第一个质点的加速度应该指向其他两个质点的方向
# 这里只检查计算是否成功,不检查具体值
def test_center_of_mass(self):
"""测试质心计算"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
com = solver.get_center_of_mass()
# 计算期望的质心
total_mass = 1.0 + 2.0 + 3.0
expected_com = np.array([
(1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass,
(1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass,
(1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass
])
assert np.allclose(com, expected_com)
def test_energy_calculation(self):
"""测试能量计算"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
energy = solver._calculate_total_energy()
# 能量应该是有限的实数
assert np.isfinite(energy)
assert isinstance(energy, float)
def test_single_step(self):
"""测试单步积分"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]),
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
# 保存初始位置
initial_positions = [p.position.copy() for p in particles]
# 执行单步积分
new_particles = solver.step()
# 检查位置是否改变
for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)):
assert not np.allclose(old_pos, new_particle.position)
# 检查时间是否增加
assert solver.time == 0.001
def test_conservation_laws(self):
"""测试守恒定律(动量、角动量、能量)"""
# 使用8字形轨道配置应该是稳定的
particles = ThreeBodyConfig.create_figure8_config()
solver = ThreeBodySolver(particles, dt=0.0001)
# 模拟很短时间
solver.simulate(total_time=0.1, progress_interval=100)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
# 误差应该很小
assert momentum_error < 1e-10
assert angular_momentum_error < 1e-10
assert energy_error < 1e-5 # 能量守恒要求较低,因为数值误差
class TestThreeBodyConfig:
"""测试配置管理"""
def test_figure8_config(self):
"""测试8字形轨道配置"""
particles = ThreeBodyConfig.create_figure8_config()
assert len(particles) == 3
assert all(p.mass == 1.0 for p in particles)
# 检查总动量接近零8字形轨道应该满足
total_momentum = np.zeros(3)
for p in particles:
total_momentum += p.mass * p.velocity
assert np.linalg.norm(total_momentum) < 1e-10
def test_lagrange_config(self):
"""测试拉格朗日点配置"""
# 测试L4点
particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
assert len(particles_l4) == 3
# 检查质量
assert particles_l4[0].mass == 1.0 # 太阳
assert particles_l4[1].mass == 3e-6 # 地球
assert particles_l4[2].mass == 1e-8 # 测试质点
# 测试L5点
particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
assert len(particles_l5) == 3
# 检查位置L4和L5应该在等边三角形顶点
r_l4 = particles_l4[2].position
r_l5 = particles_l5[2].position
assert np.allclose(r_l4, [0.5, np.sqrt(3)/2, 0.0])
assert np.allclose(r_l5, [0.5, -np.sqrt(3)/2, 0.0])
def test_random_config(self):
"""测试随机配置"""
particles = ThreeBodyConfig.create_random_config()
assert len(particles) == 3
assert all(0.5 <= p.mass <= 2.0 for p in particles)
# 检查总动量接近零(随机配置会调整速度使总动量为零)
total_momentum = np.zeros(3)
for p in particles:
total_momentum += p.mass * p.velocity
assert np.linalg.norm(total_momentum) < 1e-10
def test_custom_config(self):
"""测试自定义配置"""
config_dict = {
'particle_1': {
'mass': 1.0,
'position': [1.0, 0.0, 0.0],
'velocity': [0.0, 1.0, 0.0],
'name': 'Star A',
'color': 'red'
},
'particle_2': {
'mass': 2.0,
'position': [-1.0, 0.0, 0.0],
'velocity': [0.0, -0.5, 0.0],
'name': 'Star B',
'color': 'green'
},
'particle_3': {
'mass': 0.5,
'position': [0.0, 1.0, 0.0],
'velocity': [0.5, 0.0, 0.0],
'name': 'Star C',
'color': 'blue'
}
}
particles = ThreeBodyConfig.create_custom_config(config_dict)
assert len(particles) == 3
assert particles[0].name == 'Star A'
assert particles[1].mass == 2.0
assert np.allclose(particles[2].position, [0.0, 1.0, 0.0])
def test_integration_accuracy():
"""测试数值积分精度"""
# 使用简单的二体问题测试(第三个体质量很小)
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]),
Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) # 很小的测试质点
]
# 测试不同时间步长
dts = [0.01, 0.005, 0.001, 0.0005]
energy_errors = []
for dt in dts:
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=1.0, progress_interval=10000)
_, _, energy_error = solver.get_conservation_errors()
energy_errors.append(energy_error)
# 检查误差随步长减小而减小(四阶方法)
for i in range(len(energy_errors)-1):
# 误差应该大致按 dt^4 减小
error_ratio = energy_errors[i] / energy_errors[i+1]
dt_ratio = (dts[i] / dts[i+1])**4
# 允许一定的误差范围
assert 0.1 < error_ratio / dt_ratio < 10.0
if __name__ == "__main__":
# 运行所有测试
print("运行三体问题求解器测试...")
# 创建测试实例
test_particle = TestParticle()
test_solver = TestThreeBodySolver()
test_config = TestThreeBodyConfig()
# 运行粒子测试
print("\n1. 测试质点类:")
test_particle.test_particle_creation()
print(" ✓ test_particle_creation 通过")
test_particle.test_particle_update()
print(" ✓ test_particle_update 通过")
test_particle.test_particle_energy()
print(" ✓ test_particle_energy 通过")
test_particle.test_particle_copy()
print(" ✓ test_particle_copy 通过")
# 运行求解器测试
print("\n2. 测试求解器类:")
test_solver.test_solver_creation()
print(" ✓ test_solver_creation 通过")
try:
test_solver.test_solver_wrong_number_of_particles()
print(" ✗ test_solver_wrong_number_of_particles 应该抛出异常")
except ValueError:
print(" ✓ test_solver_wrong_number_of_particles 通过")
test_solver.test_acceleration_calculation()
print(" ✓ test_acceleration_calculation 通过")
test_solver.test_center_of_mass()
print(" ✓ test_center_of_mass 通过")
test_solver.test_energy_calculation()
print(" ✓ test_energy_calculation 通过")
test_solver.test_single_step()
print(" ✓ test_single_step 通过")
test_solver.test_conservation_laws()
print(" ✓ test_conservation_laws 通过")
# 运行配置测试
print("\n3. 测试配置管理:")
test_config.test_figure8_config()
print(" ✓ test_figure8_config 通过")
test_config.test_lagrange_config()
print(" ✓ test_lagrange_config 通过")
test_config.test_random_config()
print(" ✓ test_random_config 通过")
test_config.test_custom_config()
print(" ✓ test_custom_config 通过")
# 运行积分精度测试
print("\n4. 测试数值积分精度:")
test_integration_accuracy()
print(" ✓ test_integration_accuracy 通过")
print("\n" + "="*60)
print("所有测试通过!")
print("="*60)

View File

@@ -0,0 +1,321 @@
"""
三体问题可视化模块
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Optional, Tuple
from .particle import Particle
from .solver import ThreeBodySolver
class ThreeBodyVisualizer:
"""三体问题可视化类"""
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
"""
初始化可视化器
参数:
figsize: 图形大小
"""
self.figsize = figsize
self.fig = None
self.ax = None
def create_3d_plot(self):
"""创建3D图形"""
self.fig = plt.figure(figsize=self.figsize)
self.ax = self.fig.add_subplot(111, projection='3d')
def plot_trajectories(self, solver: ThreeBodySolver,
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹"):
"""
绘制三体运动轨迹
参数:
solver: 三体问题求解器
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if self.ax is None:
self.create_3d_plot()
trajectories = solver.get_trajectories()
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
self.ax.scatter(com[0], com[1], com[2],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
self.ax.set_xlabel('X (AU)', fontsize=12)
self.ax.set_ylabel('Y (AU)', fontsize=12)
self.ax.set_zlabel('Z (AU)', fontsize=12)
self.ax.set_title(title, fontsize=14, fontweight='bold')
self.ax.legend(fontsize=10)
self.ax.grid(True, alpha=0.3)
# 设置等比例坐标轴
self.ax.set_aspect('auto')
def plot_2d_projection(self, solver: ThreeBodySolver,
projection: str = 'xy',
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹 (2D投影)"):
"""
绘制2D投影
参数:
solver: 三体问题求解器
projection: 投影平面 ('xy', 'xz', 'yz')
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if projection not in ['xy', 'xz', 'yz']:
raise ValueError("projection 必须是 'xy', 'xz''yz'")
fig, ax = plt.subplots(figsize=(10, 8))
trajectories = solver.get_trajectories()
# 确定坐标轴索引
if projection == 'xy':
x_idx, y_idx = 0, 1
x_label, y_label = 'X (AU)', 'Y (AU)'
elif projection == 'xz':
x_idx, y_idx = 0, 2
x_label, y_label = 'X (AU)', 'Z (AU)'
else: # 'yz'
x_idx, y_idx = 1, 2
x_label, y_label = 'Y (AU)', 'Z (AU)'
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
ax.plot(traj[:, x_idx], traj[:, y_idx],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
ax.scatter(current_pos[x_idx], current_pos[y_idx],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
ax.scatter(com[x_idx], com[y_idx],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
ax.set_xlabel(x_label, fontsize=12)
ax.set_ylabel(y_label, fontsize=12)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
return fig, ax
def plot_energy_conservation(self, solver: ThreeBodySolver,
time_points: Optional[np.ndarray] = None):
"""
绘制能量守恒图
参数:
solver: 三体问题求解器
time_points: 时间点数组如果为None则使用模拟时间
"""
# 从历史记录中提取能量信息
# 注意:这里需要修改求解器以记录能量历史
# 暂时返回一个简单的图形
fig, ax = plt.subplots(figsize=(10, 6))
# 这里可以添加能量历史记录功能
ax.set_xlabel('时间 (年)', fontsize=12)
ax.set_ylabel('总能量', fontsize=12)
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def plot_phase_space(self, solver: ThreeBodySolver,
particle_index: int = 0,
dimension: str = 'x'):
"""
绘制相空间图(位置 vs 速度)
参数:
solver: 三体问题求解器
particle_index: 质点索引 (0, 1, 2)
dimension: 维度 ('x', 'y', 'z')
"""
if particle_index not in [0, 1, 2]:
raise ValueError("particle_index 必须是 0, 1, 或 2")
if dimension not in ['x', 'y', 'z']:
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
particle = solver.particles[particle_index]
trajectories = particle.get_trajectory()
if len(trajectories) < 2:
print("警告:轨迹数据不足,无法绘制相空间图")
return None, None
# 获取位置和速度历史
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
# 注意:速度历史需要从求解器中获取
# 暂时使用位置差分近似速度
if len(positions) > 1:
dt = solver.dt
velocities = np.gradient(positions, dt)
else:
velocities = np.zeros_like(positions)
fig, ax = plt.subplots(figsize=(10, 6))
# 绘制相空间轨迹
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
cmap='viridis', alpha=0.7, s=20)
# 添加颜色条表示时间
plt.colorbar(scatter, ax=ax, label='时间步')
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def animate_trajectories(self, solver: ThreeBodySolver,
save_path: Optional[str] = None,
fps: int = 30,
dpi: int = 100):
"""
创建轨迹动画(需要额外安装 matplotlib.animation
参数:
solver: 三体问题求解器
save_path: 保存路径如果为None则显示动画
fps: 帧率
dpi: 分辨率
"""
try:
from matplotlib.animation import FuncAnimation
except ImportError:
print("错误:需要安装 matplotlib.animation 模块")
return
trajectories = solver.get_trajectories()
if any(len(traj) < 2 for traj in trajectories):
print("错误:轨迹数据不足,无法创建动画")
return
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# 初始化图形元素
lines = []
points = []
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 创建轨迹线
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
lines.append(line)
# 创建当前位置点
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
markeredgecolor='black', markeredgewidth=1)
points.append(point)
# 设置坐标轴范围
all_points = np.vstack(trajectories)
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
ax.set_xlabel('X (AU)', fontsize=12)
ax.set_ylabel('Y (AU)', fontsize=12)
ax.set_zlabel('Z (AU)', fontsize=12)
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
# 动画更新函数
def update(frame):
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
# 更新轨迹线(显示到当前帧)
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
line.set_3d_properties(traj[:frame+1, 2])
# 更新当前位置点
if frame < len(traj):
point.set_data([traj[frame, 0]], [traj[frame, 1]])
point.set_3d_properties([traj[frame, 2]])
return lines + points
# 创建动画
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
interval=1000/fps, blit=True)
if save_path:
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
print(f"动画已保存到: {save_path}")
else:
plt.show()
return anim
def show(self):
"""显示所有图形"""
plt.tight_layout()
plt.show()
def save_figure(self, filename: str, dpi: int = 300):
"""保存当前图形"""
if self.fig is not None:
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
print(f"图形已保存到: {filename}")
else:
print("警告:没有活动的图形可以保存")