Files
three-body-problem/three_body_problem/demo.py
dison0331-ThinkPad 8c8ad9fe07 first
pc-1
2026-03-11 21:32:58 +08:00

299 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
三体问题求解器演示脚本
"""
import numpy as np
import matplotlib.pyplot as plt
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
def demo_basic_usage():
"""演示基本用法"""
print("=" * 60)
print("三体问题求解器 - 基本用法演示")
print("=" * 60)
# 1. 创建自定义三体系统
print("\n1. 创建自定义三体系统...")
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0], name="Star A", color="red"),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0], name="Star B", color="green"),
Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0], name="Star C", color="blue")
]
# 2. 创建求解器
print("2. 创建求解器...")
solver = ThreeBodySolver(particles, dt=0.001)
# 3. 模拟5年
print("3. 模拟5年运动...")
solver.simulate(total_time=5.0, progress_interval=1000)
# 4. 分析结果
print("\n4. 分析模拟结果...")
trajectories = solver.get_trajectories()
print(f" 轨迹点数: {len(trajectories[0])}")
com = solver.get_center_of_mass()
print(f" 系统质心: [{com[0]:.3f}, {com[1]:.3f}, {com[2]:.3f}] AU")
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
print(f" 动量误差: {momentum_error:.2e}")
print(f" 角动量误差: {angular_momentum_error:.2e}")
print(f" 能量相对误差: {energy_error:.2e}")
# 5. 可视化
print("\n5. 生成可视化图形...")
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
# 创建3D轨迹图
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="自定义三体系统")
# 保存图形
visualizer.save_figure("demo_basic_3d.png", dpi=300)
# 创建2D投影图
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
fig.savefig("demo_basic_2d.png", dpi=300, bbox_inches='tight')
print("\n图形已保存:")
print(" - demo_basic_3d.png (3D轨迹)")
print(" - demo_basic_2d.png (2D投影)")
# 显示图形
visualizer.show()
return solver
def demo_figure8():
"""演示8字形轨道"""
print("\n" + "=" * 60)
print("8字形轨道演示")
print("=" * 60)
# 使用预置配置
particles = ThreeBodyConfig.create_figure8_config()
# 打印配置信息
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟10年
print("\n模拟10年8字形轨道...")
solver.simulate(total_time=10.0, progress_interval=2000)
# 可视化
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
visualizer.create_3d_plot()
visualizer.plot_trajectories(solver, title="8字形三体轨道")
visualizer.save_figure("demo_figure8.png", dpi=300)
print("\n图形已保存: demo_figure8.png")
visualizer.show()
return solver
def demo_lagrange_points():
"""演示拉格朗日点"""
print("\n" + "=" * 60)
print("拉格朗日点演示")
print("=" * 60)
# 创建L4点配置
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
# 打印配置信息
ThreeBodyConfig.print_config_summary(particles)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟50年
print("\n模拟50年拉格朗日点L4稳定性...")
solver.simulate(total_time=50.0, progress_interval=10000)
# 分析测试质点稳定性
test_particle = solver.particles[2]
trajectory = test_particle.get_trajectory()
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
print(f"\n测试质点稳定性分析:")
print(f" 初始距离L4点: {distances[0]:.6f} AU")
print(f" 最终距离L4点: {distances[-1]:.6f} AU")
print(f" 最大距离偏差: {np.max(distances):.6f} AU")
print(f" 平均距离偏差: {np.mean(distances):.6f} AU")
# 可视化
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="拉格朗日点L4稳定性")
fig.savefig("demo_lagrange_l4.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_lagrange_l4.png")
plt.show()
return solver
def demo_random_systems():
"""演示随机系统"""
print("\n" + "=" * 60)
print("随机三体系统演示")
print("=" * 60)
# 创建3个不同的随机系统
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for i, seed in enumerate([42, 123, 456]):
np.random.seed(seed)
# 创建随机配置
particles = ThreeBodyConfig.create_random_config(
position_range=1.5,
velocity_scale=2.0
)
# 创建求解器
solver = ThreeBodySolver(particles, dt=0.001)
# 模拟5年
solver.simulate(total_time=5.0, progress_interval=2000)
# 绘制轨迹
trajectories = solver.get_trajectories()
colors = ['red', 'green', 'blue']
ax = axes[i]
for j, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[j % len(colors)]
ax.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5)
ax.scatter(traj[-1, 0], traj[-1, 1], color=color, s=50, edgecolors='black', linewidth=1)
ax.set_xlabel('X (AU)', fontsize=10)
ax.set_ylabel('Y (AU)', fontsize=10)
ax.set_title(f'随机系统 #{i+1} (种子: {seed})', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
plt.suptitle('随机三体系统示例', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("demo_random_systems.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_random_systems.png")
plt.show()
def demo_conservation_laws():
"""演示守恒定律"""
print("\n" + "=" * 60)
print("守恒定律验证演示")
print("=" * 60)
# 使用8字形轨道理论上应该守恒
particles = ThreeBodyConfig.create_figure8_config()
# 测试不同时间步长下的守恒性
time_steps = [0.01, 0.005, 0.001, 0.0005]
total_time = 1.0
results = []
for dt in time_steps:
print(f"\n测试时间步长: {dt}")
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=total_time, progress_interval=1000)
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
results.append({
'dt': dt,
'momentum_error': momentum_error,
'angular_momentum_error': angular_momentum_error,
'energy_error': energy_error
})
print(f" 动量误差: {momentum_error:.2e}")
print(f" 角动量误差: {angular_momentum_error:.2e}")
print(f" 能量相对误差: {energy_error:.2e}")
# 绘制误差图
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
dts = [r['dt'] for r in results]
# 动量误差
axes[0].loglog(dts, [r['momentum_error'] for r in results], 'o-', linewidth=2, markersize=8)
axes[0].set_xlabel('时间步长 (年)', fontsize=12)
axes[0].set_ylabel('动量误差', fontsize=12)
axes[0].set_title('动量守恒', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, which='both')
# 角动量误差
axes[1].loglog(dts, [r['angular_momentum_error'] for r in results], 's-', linewidth=2, markersize=8, color='green')
axes[1].set_xlabel('时间步长 (年)', fontsize=12)
axes[1].set_ylabel('角动量误差', fontsize=12)
axes[1].set_title('角动量守恒', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, which='both')
# 能量误差
axes[2].loglog(dts, [r['energy_error'] for r in results], '^-', linewidth=2, markersize=8, color='red')
axes[2].set_xlabel('时间步长 (年)', fontsize=12)
axes[2].set_ylabel('能量相对误差', fontsize=12)
axes[2].set_title('能量守恒', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3, which='both')
# 添加参考线(四阶精度)
ref_dt = np.array(dts)
ref_error = 1e-4 * (ref_dt / 0.001)**4
axes[2].loglog(ref_dt, ref_error, 'k--', alpha=0.7, label='四阶精度参考线')
axes[2].legend()
plt.suptitle('守恒定律数值验证 (RK4方法)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig("demo_conservation_laws.png", dpi=300, bbox_inches='tight')
print("\n图形已保存: demo_conservation_laws.png")
plt.show()
def main():
"""主函数"""
print("三体问题求解器演示")
print("=" * 60)
# 运行所有演示
try:
# 演示1: 基本用法
solver1 = demo_basic_usage()
# 演示2: 8字形轨道
solver2 = demo_figure8()
# 演示3: 拉格朗日点
solver3 = demo_lagrange_points()
# 演示4: 随机系统
demo_random_systems()
# 演示5: 守恒定律
demo_conservation_laws()
print("\n" + "=" * 60)
print("所有演示完成!")
print("=" * 60)
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()