#!/usr/bin/env python3 """ 三体问题求解器演示脚本 """ import numpy as np import matplotlib.pyplot as plt from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer def demo_basic_usage(): """演示基本用法""" print("=" * 60) print("三体问题求解器 - 基本用法演示") print("=" * 60) # 1. 创建自定义三体系统 print("\n1. 创建自定义三体系统...") particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0], name="Star A", color="red"), Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0], name="Star B", color="green"), Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0], name="Star C", color="blue") ] # 2. 创建求解器 print("2. 创建求解器...") solver = ThreeBodySolver(particles, dt=0.001) # 3. 模拟5年 print("3. 模拟5年运动...") solver.simulate(total_time=5.0, progress_interval=1000) # 4. 分析结果 print("\n4. 分析模拟结果...") trajectories = solver.get_trajectories() print(f" 轨迹点数: {len(trajectories[0])}") com = solver.get_center_of_mass() print(f" 系统质心: [{com[0]:.3f}, {com[1]:.3f}, {com[2]:.3f}] AU") momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors() print(f" 动量误差: {momentum_error:.2e}") print(f" 角动量误差: {angular_momentum_error:.2e}") print(f" 能量相对误差: {energy_error:.2e}") # 5. 可视化 print("\n5. 生成可视化图形...") visualizer = ThreeBodyVisualizer(figsize=(14, 10)) # 创建3D轨迹图 visualizer.create_3d_plot() visualizer.plot_trajectories(solver, title="自定义三体系统") # 保存图形 visualizer.save_figure("demo_basic_3d.png", dpi=300) # 创建2D投影图 fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影") fig.savefig("demo_basic_2d.png", dpi=300, bbox_inches='tight') print("\n图形已保存:") print(" - demo_basic_3d.png (3D轨迹)") print(" - demo_basic_2d.png (2D投影)") # 显示图形 visualizer.show() return solver def demo_figure8(): """演示8字形轨道""" print("\n" + "=" * 60) print("8字形轨道演示") print("=" * 60) # 使用预置配置 particles = ThreeBodyConfig.create_figure8_config() # 打印配置信息 ThreeBodyConfig.print_config_summary(particles) # 创建求解器 solver = ThreeBodySolver(particles, dt=0.001) # 模拟10年 print("\n模拟10年8字形轨道...") solver.simulate(total_time=10.0, progress_interval=2000) # 可视化 visualizer = ThreeBodyVisualizer(figsize=(12, 8)) visualizer.create_3d_plot() visualizer.plot_trajectories(solver, title="8字形三体轨道") visualizer.save_figure("demo_figure8.png", dpi=300) print("\n图形已保存: demo_figure8.png") visualizer.show() return solver def demo_lagrange_points(): """演示拉格朗日点""" print("\n" + "=" * 60) print("拉格朗日点演示") print("=" * 60) # 创建L4点配置 particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4) # 打印配置信息 ThreeBodyConfig.print_config_summary(particles) # 创建求解器 solver = ThreeBodySolver(particles, dt=0.001) # 模拟50年 print("\n模拟50年拉格朗日点L4稳定性...") solver.simulate(total_time=50.0, progress_interval=10000) # 分析测试质点稳定性 test_particle = solver.particles[2] trajectory = test_particle.get_trajectory() lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0]) distances = np.linalg.norm(trajectory - lagrange_position, axis=1) print(f"\n测试质点稳定性分析:") print(f" 初始距离L4点: {distances[0]:.6f} AU") print(f" 最终距离L4点: {distances[-1]:.6f} AU") print(f" 最大距离偏差: {np.max(distances):.6f} AU") print(f" 平均距离偏差: {np.mean(distances):.6f} AU") # 可视化 visualizer = ThreeBodyVisualizer(figsize=(12, 8)) fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="拉格朗日点L4稳定性") fig.savefig("demo_lagrange_l4.png", dpi=300, bbox_inches='tight') print("\n图形已保存: demo_lagrange_l4.png") plt.show() return solver def demo_random_systems(): """演示随机系统""" print("\n" + "=" * 60) print("随机三体系统演示") print("=" * 60) # 创建3个不同的随机系统 fig, axes = plt.subplots(1, 3, figsize=(18, 6)) for i, seed in enumerate([42, 123, 456]): np.random.seed(seed) # 创建随机配置 particles = ThreeBodyConfig.create_random_config( position_range=1.5, velocity_scale=2.0 ) # 创建求解器 solver = ThreeBodySolver(particles, dt=0.001) # 模拟5年 solver.simulate(total_time=5.0, progress_interval=2000) # 绘制轨迹 trajectories = solver.get_trajectories() colors = ['red', 'green', 'blue'] ax = axes[i] for j, (traj, particle) in enumerate(zip(trajectories, solver.particles)): color = particle.color if particle.color else colors[j % len(colors)] ax.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5) ax.scatter(traj[-1, 0], traj[-1, 1], color=color, s=50, edgecolors='black', linewidth=1) ax.set_xlabel('X (AU)', fontsize=10) ax.set_ylabel('Y (AU)', fontsize=10) ax.set_title(f'随机系统 #{i+1} (种子: {seed})', fontsize=12, fontweight='bold') ax.grid(True, alpha=0.3) ax.set_aspect('equal', adjustable='box') plt.suptitle('随机三体系统示例', fontsize=14, fontweight='bold') plt.tight_layout() plt.savefig("demo_random_systems.png", dpi=300, bbox_inches='tight') print("\n图形已保存: demo_random_systems.png") plt.show() def demo_conservation_laws(): """演示守恒定律""" print("\n" + "=" * 60) print("守恒定律验证演示") print("=" * 60) # 使用8字形轨道(理论上应该守恒) particles = ThreeBodyConfig.create_figure8_config() # 测试不同时间步长下的守恒性 time_steps = [0.01, 0.005, 0.001, 0.0005] total_time = 1.0 results = [] for dt in time_steps: print(f"\n测试时间步长: {dt}") solver = ThreeBodySolver([p.copy() for p in particles], dt=dt) solver.simulate(total_time=total_time, progress_interval=1000) momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors() results.append({ 'dt': dt, 'momentum_error': momentum_error, 'angular_momentum_error': angular_momentum_error, 'energy_error': energy_error }) print(f" 动量误差: {momentum_error:.2e}") print(f" 角动量误差: {angular_momentum_error:.2e}") print(f" 能量相对误差: {energy_error:.2e}") # 绘制误差图 fig, axes = plt.subplots(1, 3, figsize=(15, 5)) dts = [r['dt'] for r in results] # 动量误差 axes[0].loglog(dts, [r['momentum_error'] for r in results], 'o-', linewidth=2, markersize=8) axes[0].set_xlabel('时间步长 (年)', fontsize=12) axes[0].set_ylabel('动量误差', fontsize=12) axes[0].set_title('动量守恒', fontsize=14, fontweight='bold') axes[0].grid(True, alpha=0.3, which='both') # 角动量误差 axes[1].loglog(dts, [r['angular_momentum_error'] for r in results], 's-', linewidth=2, markersize=8, color='green') axes[1].set_xlabel('时间步长 (年)', fontsize=12) axes[1].set_ylabel('角动量误差', fontsize=12) axes[1].set_title('角动量守恒', fontsize=14, fontweight='bold') axes[1].grid(True, alpha=0.3, which='both') # 能量误差 axes[2].loglog(dts, [r['energy_error'] for r in results], '^-', linewidth=2, markersize=8, color='red') axes[2].set_xlabel('时间步长 (年)', fontsize=12) axes[2].set_ylabel('能量相对误差', fontsize=12) axes[2].set_title('能量守恒', fontsize=14, fontweight='bold') axes[2].grid(True, alpha=0.3, which='both') # 添加参考线(四阶精度) ref_dt = np.array(dts) ref_error = 1e-4 * (ref_dt / 0.001)**4 axes[2].loglog(ref_dt, ref_error, 'k--', alpha=0.7, label='四阶精度参考线') axes[2].legend() plt.suptitle('守恒定律数值验证 (RK4方法)', fontsize=16, fontweight='bold') plt.tight_layout() plt.savefig("demo_conservation_laws.png", dpi=300, bbox_inches='tight') print("\n图形已保存: demo_conservation_laws.png") plt.show() def main(): """主函数""" print("三体问题求解器演示") print("=" * 60) # 运行所有演示 try: # 演示1: 基本用法 solver1 = demo_basic_usage() # 演示2: 8字形轨道 solver2 = demo_figure8() # 演示3: 拉格朗日点 solver3 = demo_lagrange_points() # 演示4: 随机系统 demo_random_systems() # 演示5: 守恒定律 demo_conservation_laws() print("\n" + "=" * 60) print("所有演示完成!") print("=" * 60) except Exception as e: print(f"\n错误: {e}") import traceback traceback.print_exc() if __name__ == "__main__": main()