299 lines
9.7 KiB
Python
299 lines
9.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
三体问题求解器演示脚本
|
||
"""
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
|
||
|
||
|
||
def demo_basic_usage():
|
||
"""演示基本用法"""
|
||
print("=" * 60)
|
||
print("三体问题求解器 - 基本用法演示")
|
||
print("=" * 60)
|
||
|
||
# 1. 创建自定义三体系统
|
||
print("\n1. 创建自定义三体系统...")
|
||
particles = [
|
||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0], name="Star A", color="red"),
|
||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0], name="Star B", color="green"),
|
||
Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0], name="Star C", color="blue")
|
||
]
|
||
|
||
# 2. 创建求解器
|
||
print("2. 创建求解器...")
|
||
solver = ThreeBodySolver(particles, dt=0.001)
|
||
|
||
# 3. 模拟5年
|
||
print("3. 模拟5年运动...")
|
||
solver.simulate(total_time=5.0, progress_interval=1000)
|
||
|
||
# 4. 分析结果
|
||
print("\n4. 分析模拟结果...")
|
||
trajectories = solver.get_trajectories()
|
||
print(f" 轨迹点数: {len(trajectories[0])}")
|
||
|
||
com = solver.get_center_of_mass()
|
||
print(f" 系统质心: [{com[0]:.3f}, {com[1]:.3f}, {com[2]:.3f}] AU")
|
||
|
||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||
print(f" 动量误差: {momentum_error:.2e}")
|
||
print(f" 角动量误差: {angular_momentum_error:.2e}")
|
||
print(f" 能量相对误差: {energy_error:.2e}")
|
||
|
||
# 5. 可视化
|
||
print("\n5. 生成可视化图形...")
|
||
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
|
||
|
||
# 创建3D轨迹图
|
||
visualizer.create_3d_plot()
|
||
visualizer.plot_trajectories(solver, title="自定义三体系统")
|
||
|
||
# 保存图形
|
||
visualizer.save_figure("demo_basic_3d.png", dpi=300)
|
||
|
||
# 创建2D投影图
|
||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
|
||
fig.savefig("demo_basic_2d.png", dpi=300, bbox_inches='tight')
|
||
|
||
print("\n图形已保存:")
|
||
print(" - demo_basic_3d.png (3D轨迹)")
|
||
print(" - demo_basic_2d.png (2D投影)")
|
||
|
||
# 显示图形
|
||
visualizer.show()
|
||
|
||
return solver
|
||
|
||
|
||
def demo_figure8():
|
||
"""演示8字形轨道"""
|
||
print("\n" + "=" * 60)
|
||
print("8字形轨道演示")
|
||
print("=" * 60)
|
||
|
||
# 使用预置配置
|
||
particles = ThreeBodyConfig.create_figure8_config()
|
||
|
||
# 打印配置信息
|
||
ThreeBodyConfig.print_config_summary(particles)
|
||
|
||
# 创建求解器
|
||
solver = ThreeBodySolver(particles, dt=0.001)
|
||
|
||
# 模拟10年
|
||
print("\n模拟10年8字形轨道...")
|
||
solver.simulate(total_time=10.0, progress_interval=2000)
|
||
|
||
# 可视化
|
||
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
|
||
visualizer.create_3d_plot()
|
||
visualizer.plot_trajectories(solver, title="8字形三体轨道")
|
||
visualizer.save_figure("demo_figure8.png", dpi=300)
|
||
|
||
print("\n图形已保存: demo_figure8.png")
|
||
visualizer.show()
|
||
|
||
return solver
|
||
|
||
|
||
def demo_lagrange_points():
|
||
"""演示拉格朗日点"""
|
||
print("\n" + "=" * 60)
|
||
print("拉格朗日点演示")
|
||
print("=" * 60)
|
||
|
||
# 创建L4点配置
|
||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||
|
||
# 打印配置信息
|
||
ThreeBodyConfig.print_config_summary(particles)
|
||
|
||
# 创建求解器
|
||
solver = ThreeBodySolver(particles, dt=0.001)
|
||
|
||
# 模拟50年
|
||
print("\n模拟50年拉格朗日点L4稳定性...")
|
||
solver.simulate(total_time=50.0, progress_interval=10000)
|
||
|
||
# 分析测试质点稳定性
|
||
test_particle = solver.particles[2]
|
||
trajectory = test_particle.get_trajectory()
|
||
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
|
||
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
|
||
|
||
print(f"\n测试质点稳定性分析:")
|
||
print(f" 初始距离L4点: {distances[0]:.6f} AU")
|
||
print(f" 最终距离L4点: {distances[-1]:.6f} AU")
|
||
print(f" 最大距离偏差: {np.max(distances):.6f} AU")
|
||
print(f" 平均距离偏差: {np.mean(distances):.6f} AU")
|
||
|
||
# 可视化
|
||
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
|
||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="拉格朗日点L4稳定性")
|
||
fig.savefig("demo_lagrange_l4.png", dpi=300, bbox_inches='tight')
|
||
|
||
print("\n图形已保存: demo_lagrange_l4.png")
|
||
plt.show()
|
||
|
||
return solver
|
||
|
||
|
||
def demo_random_systems():
|
||
"""演示随机系统"""
|
||
print("\n" + "=" * 60)
|
||
print("随机三体系统演示")
|
||
print("=" * 60)
|
||
|
||
# 创建3个不同的随机系统
|
||
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
||
|
||
for i, seed in enumerate([42, 123, 456]):
|
||
np.random.seed(seed)
|
||
|
||
# 创建随机配置
|
||
particles = ThreeBodyConfig.create_random_config(
|
||
position_range=1.5,
|
||
velocity_scale=2.0
|
||
)
|
||
|
||
# 创建求解器
|
||
solver = ThreeBodySolver(particles, dt=0.001)
|
||
|
||
# 模拟5年
|
||
solver.simulate(total_time=5.0, progress_interval=2000)
|
||
|
||
# 绘制轨迹
|
||
trajectories = solver.get_trajectories()
|
||
colors = ['red', 'green', 'blue']
|
||
|
||
ax = axes[i]
|
||
for j, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||
color = particle.color if particle.color else colors[j % len(colors)]
|
||
ax.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5)
|
||
ax.scatter(traj[-1, 0], traj[-1, 1], color=color, s=50, edgecolors='black', linewidth=1)
|
||
|
||
ax.set_xlabel('X (AU)', fontsize=10)
|
||
ax.set_ylabel('Y (AU)', fontsize=10)
|
||
ax.set_title(f'随机系统 #{i+1} (种子: {seed})', fontsize=12, fontweight='bold')
|
||
ax.grid(True, alpha=0.3)
|
||
ax.set_aspect('equal', adjustable='box')
|
||
|
||
plt.suptitle('随机三体系统示例', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
plt.savefig("demo_random_systems.png", dpi=300, bbox_inches='tight')
|
||
|
||
print("\n图形已保存: demo_random_systems.png")
|
||
plt.show()
|
||
|
||
|
||
def demo_conservation_laws():
|
||
"""演示守恒定律"""
|
||
print("\n" + "=" * 60)
|
||
print("守恒定律验证演示")
|
||
print("=" * 60)
|
||
|
||
# 使用8字形轨道(理论上应该守恒)
|
||
particles = ThreeBodyConfig.create_figure8_config()
|
||
|
||
# 测试不同时间步长下的守恒性
|
||
time_steps = [0.01, 0.005, 0.001, 0.0005]
|
||
total_time = 1.0
|
||
|
||
results = []
|
||
|
||
for dt in time_steps:
|
||
print(f"\n测试时间步长: {dt}")
|
||
|
||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||
solver.simulate(total_time=total_time, progress_interval=1000)
|
||
|
||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||
|
||
results.append({
|
||
'dt': dt,
|
||
'momentum_error': momentum_error,
|
||
'angular_momentum_error': angular_momentum_error,
|
||
'energy_error': energy_error
|
||
})
|
||
|
||
print(f" 动量误差: {momentum_error:.2e}")
|
||
print(f" 角动量误差: {angular_momentum_error:.2e}")
|
||
print(f" 能量相对误差: {energy_error:.2e}")
|
||
|
||
# 绘制误差图
|
||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||
|
||
dts = [r['dt'] for r in results]
|
||
|
||
# 动量误差
|
||
axes[0].loglog(dts, [r['momentum_error'] for r in results], 'o-', linewidth=2, markersize=8)
|
||
axes[0].set_xlabel('时间步长 (年)', fontsize=12)
|
||
axes[0].set_ylabel('动量误差', fontsize=12)
|
||
axes[0].set_title('动量守恒', fontsize=14, fontweight='bold')
|
||
axes[0].grid(True, alpha=0.3, which='both')
|
||
|
||
# 角动量误差
|
||
axes[1].loglog(dts, [r['angular_momentum_error'] for r in results], 's-', linewidth=2, markersize=8, color='green')
|
||
axes[1].set_xlabel('时间步长 (年)', fontsize=12)
|
||
axes[1].set_ylabel('角动量误差', fontsize=12)
|
||
axes[1].set_title('角动量守恒', fontsize=14, fontweight='bold')
|
||
axes[1].grid(True, alpha=0.3, which='both')
|
||
|
||
# 能量误差
|
||
axes[2].loglog(dts, [r['energy_error'] for r in results], '^-', linewidth=2, markersize=8, color='red')
|
||
axes[2].set_xlabel('时间步长 (年)', fontsize=12)
|
||
axes[2].set_ylabel('能量相对误差', fontsize=12)
|
||
axes[2].set_title('能量守恒', fontsize=14, fontweight='bold')
|
||
axes[2].grid(True, alpha=0.3, which='both')
|
||
|
||
# 添加参考线(四阶精度)
|
||
ref_dt = np.array(dts)
|
||
ref_error = 1e-4 * (ref_dt / 0.001)**4
|
||
axes[2].loglog(ref_dt, ref_error, 'k--', alpha=0.7, label='四阶精度参考线')
|
||
axes[2].legend()
|
||
|
||
plt.suptitle('守恒定律数值验证 (RK4方法)', fontsize=16, fontweight='bold')
|
||
plt.tight_layout()
|
||
plt.savefig("demo_conservation_laws.png", dpi=300, bbox_inches='tight')
|
||
|
||
print("\n图形已保存: demo_conservation_laws.png")
|
||
plt.show()
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("三体问题求解器演示")
|
||
print("=" * 60)
|
||
|
||
# 运行所有演示
|
||
try:
|
||
# 演示1: 基本用法
|
||
solver1 = demo_basic_usage()
|
||
|
||
# 演示2: 8字形轨道
|
||
solver2 = demo_figure8()
|
||
|
||
# 演示3: 拉格朗日点
|
||
solver3 = demo_lagrange_points()
|
||
|
||
# 演示4: 随机系统
|
||
demo_random_systems()
|
||
|
||
# 演示5: 守恒定律
|
||
demo_conservation_laws()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("所有演示完成!")
|
||
print("=" * 60)
|
||
|
||
except Exception as e:
|
||
print(f"\n错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |