first
pc-1
This commit is contained in:
309
INSTALL.md
Normal file
309
INSTALL.md
Normal file
@@ -0,0 +1,309 @@
|
||||
# 三体问题求解器 - 安装和使用指南
|
||||
|
||||
## 项目简介
|
||||
|
||||
这是一个纯Python实现的三体问题求解器,使用四阶龙格-库塔法(RK4)数值求解牛顿引力下的三体运动。项目提供了完整的模拟、可视化和分析功能。
|
||||
|
||||
## 安装步骤
|
||||
|
||||
### 1. 克隆或下载项目
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone <repository-url>
|
||||
cd three_body_problem
|
||||
|
||||
# 或直接下载zip文件并解压
|
||||
```
|
||||
|
||||
### 2. 安装依赖
|
||||
```bash
|
||||
# 使用pip安装
|
||||
pip install numpy matplotlib
|
||||
|
||||
# 或使用requirements.txt
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 3. 验证安装
|
||||
```bash
|
||||
# 运行简单测试
|
||||
python simple_test.py
|
||||
|
||||
# 或运行完整测试
|
||||
python three_body_problem/tests/test_solver.py
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 基本示例
|
||||
```python
|
||||
import numpy as np
|
||||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
# 创建三个质点
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0], name="Star A"),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0], name="Star B"),
|
||||
Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.3, 0.0, 0.0], name="Star C")
|
||||
]
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟5年
|
||||
solver.simulate(total_time=5.0)
|
||||
|
||||
# 可视化
|
||||
visualizer = ThreeBodyVisualizer()
|
||||
visualizer.plot_trajectories(solver, title="三体系统运动轨迹")
|
||||
visualizer.show()
|
||||
```
|
||||
|
||||
### 使用预置配置
|
||||
```python
|
||||
# 8字形轨道(著名的稳定解)
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 拉格朗日点L4
|
||||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
|
||||
# 随机系统
|
||||
particles = ThreeBodyConfig.create_random_config(
|
||||
masses=None, # 随机质量
|
||||
position_range=2.0, # 位置范围 ±2 AU
|
||||
velocity_scale=1.0 # 速度缩放因子
|
||||
)
|
||||
```
|
||||
|
||||
## 运行示例
|
||||
|
||||
### 1. 运行简单示例
|
||||
```bash
|
||||
python three_body_problem/run_example.py
|
||||
```
|
||||
|
||||
### 2. 运行8字形轨道示例
|
||||
```bash
|
||||
python three_body_problem/examples/figure8.py
|
||||
```
|
||||
|
||||
### 3. 运行拉格朗日点示例
|
||||
```bash
|
||||
python three_body_problem/examples/lagrange.py
|
||||
```
|
||||
|
||||
### 4. 运行随机系统示例
|
||||
```bash
|
||||
python three_body_problem/examples/random.py
|
||||
```
|
||||
|
||||
### 5. 运行演示脚本
|
||||
```bash
|
||||
python three_body_problem/demo.py
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
three_body_problem/
|
||||
├── __init__.py # 包初始化
|
||||
├── particle.py # 质点类
|
||||
├── integrator.py # 数值积分器(RK4)
|
||||
├── solver.py # 三体问题求解器
|
||||
├── visualizer.py # 可视化工具
|
||||
├── config.py # 配置管理
|
||||
├── README.md # 详细文档
|
||||
├── SUMMARY.md # 项目总结
|
||||
├── demo.py # 演示脚本
|
||||
├── run_example.py # 快速示例
|
||||
├── requirements.txt # 依赖列表
|
||||
├── setup.py # 安装配置
|
||||
├── examples/ # 示例
|
||||
│ ├── figure8.py # 8字形轨道
|
||||
│ ├── lagrange.py # 拉格朗日点
|
||||
│ └── random.py # 随机系统
|
||||
└── tests/ # 测试
|
||||
└── test_solver.py # 单元测试
|
||||
```
|
||||
|
||||
## 主要功能
|
||||
|
||||
### 1. 物理模拟
|
||||
- **牛顿引力计算**:精确计算三个质点间的万有引力
|
||||
- **数值积分**:四阶龙格-库塔法(RK4)
|
||||
- **守恒定律**:动量、角动量、能量守恒验证
|
||||
|
||||
### 2. 初始条件
|
||||
- **8字形轨道**:著名的稳定周期解
|
||||
- **拉格朗日点**:L4和L5点稳定性测试
|
||||
- **随机系统**:随机初始条件生成
|
||||
- **自定义配置**:灵活的用户定义
|
||||
|
||||
### 3. 可视化
|
||||
- **3D轨迹图**:完整的三维运动轨迹
|
||||
- **2D投影**:XY、XZ、YZ平面投影
|
||||
- **相空间图**:位置-速度关系
|
||||
- **能量分析**:守恒定律验证
|
||||
|
||||
### 4. 分析工具
|
||||
- **质心计算**:系统质心位置和轨迹
|
||||
- **距离分析**:质点间距离变化
|
||||
- **稳定性分析**:轨道稳定性评估
|
||||
- **误差分析**:数值积分精度评估
|
||||
|
||||
## 配置参数
|
||||
|
||||
### 时间步长选择
|
||||
```python
|
||||
# 高精度模拟(推荐)
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 快速模拟
|
||||
solver = ThreeBodySolver(particles, dt=0.01)
|
||||
|
||||
# 超高精度模拟
|
||||
solver = ThreeBodySolver(particles, dt=0.0001)
|
||||
```
|
||||
|
||||
### 模拟时间
|
||||
```python
|
||||
# 短期模拟(几到几十年)
|
||||
solver.simulate(total_time=10.0) # 10年
|
||||
|
||||
# 长期模拟(几百年)
|
||||
solver.simulate(total_time=100.0) # 100年
|
||||
|
||||
# 超长期模拟(几千年)
|
||||
solver.simulate(total_time=1000.0) # 1000年
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### 1. 导入错误
|
||||
**问题**:`ModuleNotFoundError: No module named 'numpy'`
|
||||
**解决**:安装依赖 `pip install numpy matplotlib`
|
||||
|
||||
### 2. 模拟速度慢
|
||||
**问题**:模拟时间太长或时间步长太小
|
||||
**解决**:
|
||||
- 减少模拟时间 `total_time`
|
||||
- 增大时间步长 `dt`
|
||||
- 减少进度打印频率 `progress_interval`
|
||||
|
||||
### 3. 数值不稳定
|
||||
**问题**:质点间距离过小导致计算溢出
|
||||
**解决**:
|
||||
- 增大时间步长 `dt`
|
||||
- 调整初始条件避免近距离接近
|
||||
- 使用更小的质量差异
|
||||
|
||||
### 4. 内存不足
|
||||
**问题**:长时间模拟产生大量轨迹数据
|
||||
**解决**:
|
||||
- 减少模拟时间
|
||||
- 增大时间步长
|
||||
- 修改代码只保存关键时间点
|
||||
|
||||
## 性能优化
|
||||
|
||||
### 1. 减少输出
|
||||
```python
|
||||
# 减少进度打印频率
|
||||
solver.simulate(total_time=100.0, progress_interval=10000)
|
||||
```
|
||||
|
||||
### 2. 调整精度
|
||||
```python
|
||||
# 平衡精度和速度
|
||||
dt = 0.001 # 标准精度
|
||||
dt = 0.01 # 较低精度,更快
|
||||
dt = 0.0001 # 高精度,较慢
|
||||
```
|
||||
|
||||
### 3. 内存管理
|
||||
```python
|
||||
# 定期清理历史记录
|
||||
solver.reset() # 清除所有历史记录
|
||||
```
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的初始条件
|
||||
```python
|
||||
from three_body_problem.config import ThreeBodyConfig
|
||||
|
||||
class MyConfig(ThreeBodyConfig):
|
||||
@staticmethod
|
||||
def create_my_config():
|
||||
# 实现自定义配置
|
||||
particles = [...]
|
||||
return particles
|
||||
```
|
||||
|
||||
### 自定义可视化
|
||||
```python
|
||||
from three_body_problem.visualizer import ThreeBodyVisualizer
|
||||
|
||||
class MyVisualizer(ThreeBodyVisualizer):
|
||||
def plot_custom_view(self, solver):
|
||||
# 实现自定义可视化
|
||||
pass
|
||||
```
|
||||
|
||||
### 实现新的积分器
|
||||
```python
|
||||
from three_body_problem.integrator import RK4Integrator
|
||||
|
||||
class MyIntegrator(RK4Integrator):
|
||||
def step(self, particles, acceleration_func):
|
||||
# 实现新的积分方法
|
||||
pass
|
||||
```
|
||||
|
||||
## 学习资源
|
||||
|
||||
### 三体问题理论
|
||||
1. **经典文献**:Poincaré, H. (1890). "Sur le problème des trois corps"
|
||||
2. **现代研究**:Chenciner, A., & Montgomery, R. (2000). "A remarkable periodic solution"
|
||||
3. **教科书**:Murray, C. D., & Dermott, S. F. (1999). "Solar System Dynamics"
|
||||
|
||||
### 数值方法
|
||||
1. **数值分析**:Hairer, E., et al. (1993). "Solving Ordinary Differential Equations I"
|
||||
2. **科学计算**:Press, W. H., et al. (2007). "Numerical Recipes"
|
||||
|
||||
### Python科学计算
|
||||
1. **NumPy教程**:https://numpy.org/doc/stable/user/quickstart.html
|
||||
2. **Matplotlib教程**:https://matplotlib.org/stable/tutorials/index.html
|
||||
3. **SciPy教程**:https://docs.scipy.org/doc/scipy/tutorial/index.html
|
||||
|
||||
## 贡献指南
|
||||
|
||||
1. Fork项目仓库
|
||||
2. 创建功能分支
|
||||
3. 提交更改
|
||||
4. 推送到分支
|
||||
5. 创建Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License - 详见LICENSE文件
|
||||
|
||||
## 支持
|
||||
|
||||
如有问题或建议,请:
|
||||
1. 查看文档和示例
|
||||
2. 提交Issue
|
||||
3. 联系作者
|
||||
|
||||
## 版本历史
|
||||
|
||||
### v1.0.0 (2024)
|
||||
- 初始版本发布
|
||||
- 实现RK4数值积分器
|
||||
- 提供多种初始条件配置
|
||||
- 完整的可视化功能
|
||||
- 包含测试和示例
|
||||
|
||||
---
|
||||
|
||||
**开始探索三体问题的奇妙世界吧!** 🚀
|
||||
312
README.md
Normal file
312
README.md
Normal file
@@ -0,0 +1,312 @@
|
||||
# 三体问题求解器 - 纯Python方案
|
||||
|
||||
一个用于模拟和可视化三体问题的Python库。使用四阶龙格-库塔法数值求解牛顿引力下的三体运动。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **纯Python实现**:无需外部依赖(除可视化外)
|
||||
- **高精度数值积分**:使用四阶龙格-库塔法(RK4)
|
||||
- **多种初始条件**:预置8字形轨道、拉格朗日点等经典配置
|
||||
- **完整可视化**:3D轨迹、2D投影、相空间图、能量分析
|
||||
- **物理守恒验证**:动量、角动量、能量守恒检查
|
||||
- **模块化设计**:易于扩展和自定义
|
||||
|
||||
## 安装要求
|
||||
|
||||
### 必需依赖
|
||||
- Python 3.7+
|
||||
- NumPy
|
||||
- Matplotlib
|
||||
|
||||
### 安装方法
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone <repository-url>
|
||||
cd three_body_problem
|
||||
|
||||
# 安装依赖
|
||||
pip install numpy matplotlib
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 基本使用
|
||||
```python
|
||||
import numpy as np
|
||||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
# 创建8字形轨道配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟10年
|
||||
solver.simulate(total_time=10.0)
|
||||
|
||||
# 可视化
|
||||
visualizer = ThreeBodyVisualizer()
|
||||
visualizer.plot_trajectories(solver, title="8字形三体轨道")
|
||||
visualizer.show()
|
||||
```
|
||||
|
||||
### 运行示例
|
||||
```python
|
||||
# 运行8字形轨道示例
|
||||
from three_body_problem.examples import figure8
|
||||
solver = figure8.run_figure8_example()
|
||||
|
||||
# 运行拉格朗日点示例
|
||||
from three_body_problem.examples import lagrange
|
||||
solver_l4 = lagrange.run_lagrange_example(lagrange_point=4, total_time=50.0)
|
||||
|
||||
# 运行随机系统示例
|
||||
from three_body_problem.examples import random
|
||||
solver_random = random.run_random_example(seed=42, total_time=15.0)
|
||||
```
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 1. Particle(质点类)
|
||||
表示三体问题中的一个天体。
|
||||
|
||||
```python
|
||||
# 创建质点
|
||||
particle = Particle(
|
||||
mass=1.0, # 质量(太阳质量)
|
||||
position=[1.0, 0.0, 0.0], # 位置向量 [AU]
|
||||
velocity=[0.0, 6.28, 0.0], # 速度向量 [AU/年]
|
||||
name="Earth", # 名称
|
||||
color="blue" # 可视化颜色
|
||||
)
|
||||
```
|
||||
|
||||
### 2. ThreeBodySolver(求解器类)
|
||||
主求解器,使用RK4方法积分运动方程。
|
||||
|
||||
```python
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟一段时间
|
||||
solver.simulate(total_time=10.0, progress_interval=1000)
|
||||
|
||||
# 获取轨迹
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 计算质心
|
||||
center_of_mass = solver.get_center_of_mass()
|
||||
|
||||
# 检查守恒定律
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
```
|
||||
|
||||
### 3. ThreeBodyConfig(配置管理)
|
||||
提供多种预置初始条件。
|
||||
|
||||
```python
|
||||
# 8字形轨道(著名的稳定解)
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 拉格朗日点配置(L4或L5)
|
||||
particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
|
||||
|
||||
# 随机配置
|
||||
particles_random = ThreeBodyConfig.create_random_config(
|
||||
masses=None, # 随机质量
|
||||
position_range=2.0, # 位置范围 ±2 AU
|
||||
velocity_scale=1.0 # 速度缩放因子
|
||||
)
|
||||
|
||||
# 自定义配置
|
||||
config_dict = {
|
||||
'particle_1': {'mass': 1.0, 'position': [1,0,0], 'velocity': [0,1,0], 'name': 'Star A'},
|
||||
'particle_2': {'mass': 1.0, 'position': [-1,0,0], 'velocity': [0,-1,0], 'name': 'Star B'},
|
||||
'particle_3': {'mass': 0.1, 'position': [0,1,0], 'velocity': [-1,0,0], 'name': 'Star C'}
|
||||
}
|
||||
particles_custom = ThreeBodyConfig.create_custom_config(config_dict)
|
||||
```
|
||||
|
||||
### 4. ThreeBodyVisualizer(可视化类)
|
||||
提供多种可视化选项。
|
||||
|
||||
```python
|
||||
visualizer = ThreeBodyVisualizer(figsize=(12, 10))
|
||||
|
||||
# 3D轨迹图
|
||||
visualizer.plot_trajectories(solver, show_current_positions=True, show_com=True)
|
||||
|
||||
# 2D投影
|
||||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy')
|
||||
|
||||
# 相空间图
|
||||
fig, ax = visualizer.plot_phase_space(solver, particle_index=0, dimension='x')
|
||||
|
||||
# 显示图形
|
||||
visualizer.show()
|
||||
|
||||
# 保存图形
|
||||
visualizer.save_figure("trajectory.png", dpi=300)
|
||||
```
|
||||
|
||||
## 物理模型
|
||||
|
||||
### 运动方程
|
||||
对于三个质点 $i=1,2,3$,每个质点的加速度为:
|
||||
|
||||
$$
|
||||
\vec{a}_i = G \sum_{j \neq i} \frac{m_j}{|\vec{r}_{ij}|^3} \vec{r}_{ij}
|
||||
$$
|
||||
|
||||
其中:
|
||||
- $G = 4\pi^2$ (天文单位制)
|
||||
- $\vec{r}_{ij} = \vec{r}_j - \vec{r}_i$
|
||||
- $m_j$ 是质点 $j$ 的质量
|
||||
|
||||
### 单位系统
|
||||
- **距离**:天文单位(AU)
|
||||
- **质量**:太阳质量($M_\odot$)
|
||||
- **时间**:年(yr)
|
||||
- **速度**:AU/yr
|
||||
- **引力常数**:$G = 4\pi^2$ AU³/(M⊙·yr²)
|
||||
|
||||
### 数值方法
|
||||
使用四阶龙格-库塔法(RK4)积分运动方程:
|
||||
- 时间步长:`dt`(默认0.001年)
|
||||
- 状态向量:18维(3个质点 × 6个自由度)
|
||||
|
||||
## 示例
|
||||
|
||||
### 示例1:8字形轨道
|
||||
```python
|
||||
from three_body_problem.examples import figure8
|
||||
|
||||
# 运行8字形轨道示例
|
||||
solver = figure8.run_figure8_example()
|
||||
|
||||
# 分析稳定性
|
||||
figure8.analyze_figure8_stability()
|
||||
```
|
||||
|
||||
### 示例2:拉格朗日点
|
||||
```python
|
||||
from three_body_problem.examples import lagrange
|
||||
|
||||
# 运行L4点示例
|
||||
solver_l4 = lagrange.run_lagrange_example(lagrange_point=4, total_time=100.0)
|
||||
|
||||
# 运行L5点示例
|
||||
solver_l5 = lagrange.run_lagrange_example(lagrange_point=5, total_time=100.0)
|
||||
|
||||
# 比较L4和L5稳定性
|
||||
lagrange.compare_lagrange_points()
|
||||
```
|
||||
|
||||
### 示例3:随机系统
|
||||
```python
|
||||
from three_body_problem.examples import random
|
||||
|
||||
# 运行单个随机系统
|
||||
solver = random.run_random_example(seed=42, total_time=20.0)
|
||||
|
||||
# 运行多个随机系统比较
|
||||
results = random.run_multiple_random_simulations(n_simulations=5, total_time=10.0)
|
||||
```
|
||||
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
cd three_body_problem
|
||||
python -m pytest tests/ -v
|
||||
|
||||
# 或直接运行测试脚本
|
||||
python tests/test_solver.py
|
||||
```
|
||||
|
||||
## 性能优化建议
|
||||
|
||||
1. **时间步长选择**:
|
||||
- 对于稳定轨道:`dt = 0.001`(默认)
|
||||
- 对于快速运动系统:`dt = 0.0001`
|
||||
- 对于长期模拟:`dt = 0.01`
|
||||
|
||||
2. **内存管理**:
|
||||
- 长时间模拟时考虑定期清理轨迹历史
|
||||
- 使用`reset()`方法清除历史记录
|
||||
|
||||
3. **精度控制**:
|
||||
- 检查能量守恒误差:应小于1e-5
|
||||
- 检查动量守恒误差:应小于1e-10
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的初始条件
|
||||
```python
|
||||
from three_body_problem.config import ThreeBodyConfig
|
||||
|
||||
class MyCustomConfig(ThreeBodyConfig):
|
||||
@staticmethod
|
||||
def create_my_config():
|
||||
# 实现自定义配置
|
||||
particles = [...]
|
||||
return particles
|
||||
```
|
||||
|
||||
### 自定义可视化
|
||||
```python
|
||||
from three_body_problem.visualizer import ThreeBodyVisualizer
|
||||
|
||||
class MyVisualizer(ThreeBodyVisualizer):
|
||||
def plot_custom_view(self, solver):
|
||||
# 实现自定义可视化
|
||||
pass
|
||||
```
|
||||
|
||||
### 实现新的积分器
|
||||
```python
|
||||
from three_body_problem.integrator import RK4Integrator
|
||||
|
||||
class MyIntegrator(RK4Integrator):
|
||||
def step(self, particles, acceleration_func):
|
||||
# 实现新的积分方法
|
||||
pass
|
||||
```
|
||||
|
||||
## 已知问题与限制
|
||||
|
||||
1. **数值稳定性**:
|
||||
- 近距离接近可能导致数值不稳定
|
||||
- 建议使用较小时间步长`dt`
|
||||
|
||||
2. **能量漂移**:
|
||||
- 长期模拟可能出现能量漂移
|
||||
- 使用辛积分器可改善(未来版本)
|
||||
|
||||
3. **性能**:
|
||||
- 纯Python实现,性能有限
|
||||
- 对于大规模模拟考虑使用Numba或Cython加速
|
||||
|
||||
## 参考文献
|
||||
|
||||
1. Chenciner, A., & Montgomery, R. (2000). A remarkable periodic solution of the three-body problem in the case of equal masses.
|
||||
2. Murray, C. D., & Dermott, S. F. (1999). Solar System Dynamics.
|
||||
3. Hairer, E., Nørsett, S. P., & Wanner, G. (1993). Solving Ordinary Differential Equations I.
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎提交Issue和Pull Request!
|
||||
|
||||
|
||||
## 版本历史
|
||||
|
||||
- v1.0.0 (2024) - 初始版本
|
||||
- 实现RK4数值积分器
|
||||
- 提供多种初始条件配置
|
||||
- 完整的可视化功能
|
||||
- 包含测试和示例
|
||||
251
SUMMARY.md
Normal file
251
SUMMARY.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# 三体问题求解器 - 项目总结
|
||||
|
||||
## 项目概述
|
||||
|
||||
这是一个纯Python实现的三体问题求解器,使用四阶龙格-库塔法(RK4)数值求解牛顿引力下的三体运动。项目提供了完整的模拟、可视化和分析功能。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
three_body_problem/
|
||||
├── __init__.py # 包初始化文件
|
||||
├── particle.py # 质点类定义
|
||||
├── integrator.py # 数值积分器(RK4方法)
|
||||
├── solver.py # 三体问题求解器主类
|
||||
├── visualizer.py # 可视化工具
|
||||
├── config.py # 配置管理
|
||||
├── README.md # 使用说明文档
|
||||
├── SUMMARY.md # 项目总结文档
|
||||
├── demo.py # 演示脚本
|
||||
├── run_example.py # 快速示例脚本
|
||||
├── requirements.txt # 依赖列表
|
||||
├── setup.py # 安装配置
|
||||
├── examples/ # 示例配置
|
||||
│ ├── __init__.py
|
||||
│ ├── figure8.py # 8字形轨道示例
|
||||
│ ├── lagrange.py # 拉格朗日点示例
|
||||
│ └── random.py # 随机初始条件示例
|
||||
└── tests/ # 测试文件
|
||||
├── __init__.py
|
||||
└── test_solver.py # 单元测试
|
||||
```
|
||||
|
||||
## 核心功能
|
||||
|
||||
### 1. 物理模型
|
||||
- **牛顿万有引力定律**:$F = G \frac{m_1 m_2}{r^2}$
|
||||
- **运动方程**:$m_i \frac{d^2 \vec{r}_i}{dt^2} = \sum_{j \neq i} G \frac{m_i m_j}{|\vec{r}_j - \vec{r}_i|^3} (\vec{r}_j - \vec{r}_i)$
|
||||
- **单位系统**:天文单位(AU)、太阳质量(M⊙)、年(yr)
|
||||
|
||||
### 2. 数值方法
|
||||
- **四阶龙格-库塔法(RK4)**:高精度数值积分
|
||||
- **自适应时间步长**:支持不同精度需求
|
||||
- **守恒定律验证**:动量、角动量、能量守恒检查
|
||||
|
||||
### 3. 预置配置
|
||||
- **8字形轨道**:著名的稳定三体轨道
|
||||
- **拉格朗日点**:L4和L5点稳定性测试
|
||||
- **随机系统**:随机初始条件生成
|
||||
- **双星系统**:双星+测试质点配置
|
||||
- **自定义配置**:灵活的用户定义
|
||||
|
||||
### 4. 可视化功能
|
||||
- **3D轨迹图**:完整的三维运动轨迹
|
||||
- **2D投影图**:XY、XZ、YZ平面投影
|
||||
- **相空间图**:位置-速度关系分析
|
||||
- **能量分析**:守恒定律验证
|
||||
- **动画支持**:运动轨迹动画(需matplotlib.animation)
|
||||
|
||||
### 5. 分析工具
|
||||
- **质心计算**:系统质心位置和轨迹
|
||||
- **守恒误差**:动量、角动量、能量守恒误差
|
||||
- **稳定性分析**:轨道稳定性评估
|
||||
- **距离分析**:质点间距离变化
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 基本使用
|
||||
```python
|
||||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig
|
||||
|
||||
# 创建质点
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1,0,0], velocity=[0,1,0]),
|
||||
Particle(mass=1.0, position=[-1,0,0], velocity=[0,-1,0]),
|
||||
Particle(mass=0.1, position=[0,1,0], velocity=[-1,0,0])
|
||||
]
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟运动
|
||||
solver.simulate(total_time=10.0)
|
||||
|
||||
# 获取结果
|
||||
trajectories = solver.get_trajectories()
|
||||
```
|
||||
|
||||
### 使用预置配置
|
||||
```python
|
||||
# 8字形轨道
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 拉格朗日点L4
|
||||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
|
||||
# 随机系统
|
||||
particles = ThreeBodyConfig.create_random_config()
|
||||
```
|
||||
|
||||
### 可视化
|
||||
```python
|
||||
from three_body_problem import ThreeBodyVisualizer
|
||||
|
||||
visualizer = ThreeBodyVisualizer()
|
||||
visualizer.plot_trajectories(solver)
|
||||
visualizer.show()
|
||||
```
|
||||
|
||||
## 物理特性
|
||||
|
||||
### 守恒定律
|
||||
1. **动量守恒**:系统总动量保持不变
|
||||
2. **角动量守恒**:系统总角动量保持不变
|
||||
3. **能量守恒**:系统总能量(动能+势能)保持不变
|
||||
|
||||
### 数值精度
|
||||
- **时间步长**:默认0.001年,可根据需要调整
|
||||
- **积分方法**:四阶龙格-库塔法,局部截断误差O(h⁵)
|
||||
- **能量误差**:典型值小于1e-5(相对误差)
|
||||
|
||||
### 稳定性条件
|
||||
1. **时间步长选择**:$\Delta t < \frac{0.01}{\sqrt{G\rho}}$,其中$\rho$为密度
|
||||
2. **近距离处理**:避免质点间距离过小(<1e-10 AU)
|
||||
3. **数值稳定性**:使用双精度浮点数计算
|
||||
|
||||
## 示例应用
|
||||
|
||||
### 1. 8字形轨道研究
|
||||
- 验证著名的稳定三体解
|
||||
- 分析轨道周期性和对称性
|
||||
- 测试数值方法的长期稳定性
|
||||
|
||||
### 2. 拉格朗日点稳定性
|
||||
- 验证L4和L5点的稳定性
|
||||
- 分析小质量质点在拉格朗日点的运动
|
||||
- 研究扰动对稳定性的影响
|
||||
|
||||
### 3. 混沌系统研究
|
||||
- 探索三体问题的混沌特性
|
||||
- 分析对初始条件的敏感性
|
||||
- 研究轨道长期演化
|
||||
|
||||
### 4. 教学演示
|
||||
- 天体力学教学工具
|
||||
- 数值方法教学示例
|
||||
- 物理守恒定律验证
|
||||
|
||||
## 性能优化
|
||||
|
||||
### 计算复杂度
|
||||
- **每步计算**:O(9)次距离计算(3个质点×3对相互作用)
|
||||
- **内存使用**:O(6N)存储轨迹,N为步数
|
||||
- **时间消耗**:与模拟时间和时间步长成线性关系
|
||||
|
||||
### 优化建议
|
||||
1. **减少输出频率**:仅保存关键时间点的轨迹
|
||||
2. **使用较小时间步长**:提高精度但增加计算量
|
||||
3. **并行计算**:可扩展为多线程计算
|
||||
4. **GPU加速**:使用CUDA或OpenCL加速计算
|
||||
|
||||
## 扩展方向
|
||||
|
||||
### 1. 算法改进
|
||||
- 实现辛积分器(Symplectic Integrator)
|
||||
- 添加自适应时间步长
|
||||
- 实现更高阶积分方法
|
||||
|
||||
### 2. 物理扩展
|
||||
- 添加相对论修正
|
||||
- 考虑潮汐效应
|
||||
- 加入辐射阻尼
|
||||
|
||||
### 3. 功能增强
|
||||
- 支持N体问题(N>3)
|
||||
- 添加碰撞检测和处理
|
||||
- 实现轨道参数计算(半长轴、偏心率等)
|
||||
|
||||
### 4. 可视化改进
|
||||
- 实时交互式可视化
|
||||
- Web界面支持
|
||||
- 3D WebGL渲染
|
||||
|
||||
## 测试验证
|
||||
|
||||
### 单元测试
|
||||
- 质点类功能测试
|
||||
- 求解器正确性测试
|
||||
- 守恒定律验证测试
|
||||
- 数值精度测试
|
||||
|
||||
### 物理验证
|
||||
- 二体问题极限测试
|
||||
- 开普勒轨道验证
|
||||
- 能量守恒长期测试
|
||||
- 动量守恒验证
|
||||
|
||||
## 参考文献
|
||||
|
||||
1. **经典三体问题**
|
||||
- Poincaré, H. (1890). "Sur le problème des trois corps et les équations de la dynamique"
|
||||
- Chenciner, A., & Montgomery, R. (2000). "A remarkable periodic solution of the three-body problem in the case of equal masses"
|
||||
|
||||
2. **数值方法**
|
||||
- Hairer, E., Nørsett, S. P., & Wanner, G. (1993). "Solving Ordinary Differential Equations I"
|
||||
- Press, W. H., et al. (2007). "Numerical Recipes: The Art of Scientific Computing"
|
||||
|
||||
3. **天体力学**
|
||||
- Murray, C. D., & Dermott, S. F. (1999). "Solar System Dynamics"
|
||||
- Goldstein, H., Poole, C., & Safko, J. (2002). "Classical Mechanics"
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License - 详见LICENSE文件
|
||||
|
||||
|
||||
|
||||
|
||||
## 版本历史
|
||||
|
||||
### v1.0.0 (2024)
|
||||
- 初始版本发布
|
||||
- 实现RK4数值积分器
|
||||
- 提供多种初始条件配置
|
||||
- 完整的可视化功能
|
||||
- 包含测试和示例
|
||||
|
||||
### 未来版本计划
|
||||
- v1.1.0:添加辛积分器
|
||||
- v1.2.0:支持N体问题
|
||||
- v1.3.0:Web界面和实时可视化
|
||||
- v2.0.0:GPU加速和并行计算
|
||||
|
||||
## 致谢
|
||||
|
||||
感谢以下开源项目:
|
||||
- NumPy:数值计算基础
|
||||
- Matplotlib:科学可视化
|
||||
- SciPy:科学计算工具
|
||||
|
||||
## 引用
|
||||
|
||||
如果您在研究中使用了此代码,请引用:
|
||||
|
||||
```
|
||||
@software{three_body_solver_2024,
|
||||
author = {ThreeBodyProblem Team},
|
||||
title = {Three-Body Problem Solver: A pure Python implementation},
|
||||
year = {2024},
|
||||
url = {https://github.com/dison0331/three-body-problem}
|
||||
}
|
||||
```
|
||||
81
simple_test.py
Normal file
81
simple_test.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单测试三体问题求解器
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加当前目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
print("三体问题求解器 - 简单测试")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# 尝试导入
|
||||
print("1. 导入模块...")
|
||||
from three_body_problem.particle import Particle
|
||||
from three_body_problem.config import ThreeBodyConfig
|
||||
print(" ✓ 导入成功")
|
||||
|
||||
# 测试创建质点
|
||||
print("\n2. 测试质点创建...")
|
||||
p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test")
|
||||
print(f" 名称: {p.name}")
|
||||
print(f" 质量: {p.mass}")
|
||||
print(f" 位置: {p.position}")
|
||||
print(f" 速度: {p.velocity}")
|
||||
print(" ✓ 质点创建成功")
|
||||
|
||||
# 测试能量计算
|
||||
energy = p.get_energy()
|
||||
print(f" 动能: {energy:.6f}")
|
||||
print(" ✓ 能量计算成功")
|
||||
|
||||
# 测试更新
|
||||
print("\n3. 测试质点更新...")
|
||||
p.update([2.0, 3.0, 4.0], [0.2, 0.3, 0.4])
|
||||
print(f" 新位置: {p.position}")
|
||||
print(f" 新速度: {p.velocity}")
|
||||
print(f" 历史记录: {len(p.position_history)} 个位置点")
|
||||
print(" ✓ 质点更新成功")
|
||||
|
||||
# 测试配置创建
|
||||
print("\n4. 测试配置创建...")
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
print(f" 创建了 {len(particles)} 个质点")
|
||||
for i, particle in enumerate(particles):
|
||||
print(f" 质点{i+1}: {particle.name}, 质量: {particle.mass:.3f}")
|
||||
print(" ✓ 配置创建成功")
|
||||
|
||||
# 测试随机配置
|
||||
print("\n5. 测试随机配置...")
|
||||
random_particles = ThreeBodyConfig.create_random_config()
|
||||
print(f" 创建了 {len(random_particles)} 个随机质点")
|
||||
total_mass = sum(p.mass for p in random_particles)
|
||||
print(f" 总质量: {total_mass:.3f}")
|
||||
print(" ✓ 随机配置成功")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有基本测试通过!")
|
||||
print("=" * 60)
|
||||
|
||||
# 显示下一步操作
|
||||
print("\n下一步:")
|
||||
print("1. 安装依赖: pip install numpy matplotlib")
|
||||
print("2. 运行完整测试: python three_body_problem/tests/test_solver.py")
|
||||
print("3. 运行示例: python three_body_problem/run_example.py")
|
||||
print("4. 查看文档: 阅读 three_body_problem/README.md")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"\n✗ 导入错误: {e}")
|
||||
print("\n请确保:")
|
||||
print("1. 在项目根目录运行此测试")
|
||||
print("2. 安装了必要的依赖: pip install numpy matplotlib")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
94
test_import.py
Normal file
94
test_import.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试三体问题求解器导入和基本功能
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加当前目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
print("测试三体问题求解器导入...")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# 测试导入
|
||||
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
print("✓ 成功导入核心模块")
|
||||
|
||||
# 测试创建质点
|
||||
p = Particle(mass=1.0, position=[1, 0, 0], velocity=[0, 1, 0], name='Test Star')
|
||||
print(f"✓ 成功创建质点: {p.name}, 质量: {p.mass}, 位置: {p.position}")
|
||||
|
||||
# 测试能量计算
|
||||
energy = p.get_energy()
|
||||
print(f"✓ 质点动能: {energy:.6f}")
|
||||
|
||||
# 测试创建配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
print(f"✓ 成功创建8字形轨道配置: {len(particles)}个质点")
|
||||
for i, particle in enumerate(particles):
|
||||
print(f" 质点{i+1}: {particle.name}, 质量: {particle.mass:.3f}")
|
||||
|
||||
# 测试创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
print(f"✓ 成功创建求解器,时间步长: {solver.dt}")
|
||||
|
||||
# 测试单步积分
|
||||
initial_positions = [particle.position.copy() for particle in particles]
|
||||
new_particles = solver.step()
|
||||
print(f"✓ 单步积分完成,时间: {solver.time:.4f}年")
|
||||
|
||||
# 检查位置是否变化
|
||||
for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)):
|
||||
moved = not all(abs(old_pos[j] - new_particle.position[j]) < 1e-10 for j in range(3))
|
||||
print(f" 质点{i+1} 位置变化: {'是' if moved else '否'}")
|
||||
|
||||
# 测试质心计算
|
||||
com = solver.get_center_of_mass()
|
||||
print(f"✓ 系统质心: [{com[0]:.6f}, {com[1]:.6f}, {com[2]:.6f}]")
|
||||
|
||||
# 测试能量计算
|
||||
energy = solver._calculate_total_energy()
|
||||
print(f"✓ 系统总能量: {energy:.6e}")
|
||||
|
||||
# 测试守恒误差计算
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f"✓ 守恒定律误差:")
|
||||
print(f" 动量误差: {momentum_error:.2e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.2e}")
|
||||
print(f" 能量相对误差: {energy_error:.2e}")
|
||||
|
||||
# 测试配置管理
|
||||
print("\n测试配置管理...")
|
||||
random_particles = ThreeBodyConfig.create_random_config()
|
||||
print(f"✓ 成功创建随机配置: {len(random_particles)}个质点")
|
||||
|
||||
lagrange_particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
print(f"✓ 成功创建拉格朗日点L4配置: {len(lagrange_particles)}个质点")
|
||||
|
||||
# 测试可视化器创建
|
||||
visualizer = ThreeBodyVisualizer()
|
||||
print("✓ 成功创建可视化器")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有测试通过! 三体问题求解器工作正常。")
|
||||
print("=" * 60)
|
||||
|
||||
# 显示使用示例
|
||||
print("\n使用示例:")
|
||||
print("1. 运行简单示例: python three_body_problem/run_example.py")
|
||||
print("2. 运行8字形轨道: python three_body_problem/examples/figure8.py")
|
||||
print("3. 运行拉格朗日点示例: python three_body_problem/examples/lagrange.py")
|
||||
print("4. 运行随机系统示例: python three_body_problem/examples/random.py")
|
||||
print("5. 运行测试: python three_body_problem/tests/test_solver.py")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"✗ 导入失败: {e}")
|
||||
print("请确保在项目根目录下运行此测试")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
41
three_body_problem/__init__.py
Normal file
41
three_body_problem/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
三体问题求解器 - 纯Python方案
|
||||
|
||||
一个用于模拟和可视化三体问题的Python库。
|
||||
使用四阶龙格-库塔法数值求解牛顿引力下的三体运动。
|
||||
|
||||
主要组件:
|
||||
- ThreeBodySolver: 主求解器类
|
||||
- Particle: 质点类
|
||||
- RK4Integrator: 数值积分器
|
||||
- ThreeBodyVisualizer: 可视化工具
|
||||
- ThreeBodyConfig: 配置管理
|
||||
|
||||
示例用法:
|
||||
>>> from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig
|
||||
>>> particles = ThreeBodyConfig.create_figure8_config()
|
||||
>>> solver = ThreeBodySolver(particles, dt=0.001)
|
||||
>>> solver.simulate(total_time=10.0)
|
||||
>>> from three_body_problem import ThreeBodyVisualizer
|
||||
>>> visualizer = ThreeBodyVisualizer()
|
||||
>>> visualizer.plot_trajectories(solver)
|
||||
>>> visualizer.show()
|
||||
"""
|
||||
|
||||
from .particle import Particle
|
||||
from .integrator import RK4Integrator
|
||||
from .solver import ThreeBodySolver
|
||||
from .visualizer import ThreeBodyVisualizer
|
||||
from .config import ThreeBodyConfig
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "ThreeBodyProblem Team"
|
||||
__email__ = "threebody@example.com"
|
||||
|
||||
__all__ = [
|
||||
"Particle",
|
||||
"RK4Integrator",
|
||||
"ThreeBodySolver",
|
||||
"ThreeBodyVisualizer",
|
||||
"ThreeBodyConfig"
|
||||
]
|
||||
BIN
three_body_problem/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
three_body_problem/__pycache__/config.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
three_body_problem/__pycache__/integrator.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/integrator.cpython-312.pyc
Normal file
Binary file not shown.
BIN
three_body_problem/__pycache__/particle.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/particle.cpython-312.pyc
Normal file
Binary file not shown.
BIN
three_body_problem/__pycache__/solver.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/solver.cpython-312.pyc
Normal file
Binary file not shown.
BIN
three_body_problem/__pycache__/visualizer.cpython-312.pyc
Normal file
BIN
three_body_problem/__pycache__/visualizer.cpython-312.pyc
Normal file
Binary file not shown.
280
three_body_problem/config.py
Normal file
280
three_body_problem/config.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
三体问题配置管理模块
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
from .particle import Particle
|
||||
|
||||
|
||||
class ThreeBodyConfig:
|
||||
"""三体问题配置类"""
|
||||
|
||||
@staticmethod
|
||||
def create_figure8_config() -> List[Particle]:
|
||||
"""
|
||||
创建8字形轨道配置(著名的稳定三体轨道)
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
# 8字形轨道的初始条件(等质量)
|
||||
m = 1.0 # 质量
|
||||
|
||||
# 位置 (Chenciner & Montgomery, 2000)
|
||||
r1 = np.array([0.97000436, -0.24308753, 0.0])
|
||||
r2 = np.array([-0.97000436, 0.24308753, 0.0])
|
||||
r3 = np.array([0.0, 0.0, 0.0])
|
||||
|
||||
# 速度
|
||||
v1 = np.array([0.466203685, 0.43236573, 0.0])
|
||||
v2 = np.array([0.466203685, 0.43236573, 0.0])
|
||||
v3 = np.array([-0.93240737, -0.86473146, 0.0])
|
||||
|
||||
particles = [
|
||||
Particle(mass=m, position=r1, velocity=v1, name="Star A", color="red"),
|
||||
Particle(mass=m, position=r2, velocity=v2, name="Star B", color="green"),
|
||||
Particle(mass=m, position=r3, velocity=v3, name="Star C", color="blue")
|
||||
]
|
||||
|
||||
return particles
|
||||
|
||||
@staticmethod
|
||||
def create_lagrange_point_config(lagrange_point: int = 4) -> List[Particle]:
|
||||
"""
|
||||
创建拉格朗日点配置
|
||||
|
||||
参数:
|
||||
lagrange_point: 拉格朗日点编号 (4=L4, 5=L5)
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
if lagrange_point not in [4, 5]:
|
||||
raise ValueError("lagrange_point 必须是 4 (L4) 或 5 (L5)")
|
||||
|
||||
# 大质量天体(类似太阳)
|
||||
m_sun = 1.0
|
||||
# 小质量天体(类似地球)
|
||||
m_earth = 3e-6 # 地球质量/太阳质量
|
||||
# 测试质点(类似小行星)
|
||||
m_test = 1e-8
|
||||
|
||||
# 大质量天体在原点
|
||||
r_sun = np.array([0.0, 0.0, 0.0])
|
||||
v_sun = np.array([0.0, 0.0, 0.0])
|
||||
|
||||
# 小质量天体在圆形轨道上(1 AU距离)
|
||||
r_earth = np.array([1.0, 0.0, 0.0])
|
||||
# 圆形轨道速度:v = sqrt(G*M/r)
|
||||
v_earth = np.array([0.0, 2*np.pi, 0.0]) # 2π AU/年
|
||||
|
||||
# 拉格朗日点位置(等边三角形)
|
||||
if lagrange_point == 4: # L4
|
||||
r_test = np.array([0.5, np.sqrt(3)/2, 0.0])
|
||||
else: # L5
|
||||
r_test = np.array([0.5, -np.sqrt(3)/2, 0.0])
|
||||
|
||||
# 测试质点的速度(与地球相同角速度)
|
||||
v_test = np.array([-np.sqrt(3)/2 * 2*np.pi, 0.5 * 2*np.pi, 0.0])
|
||||
|
||||
particles = [
|
||||
Particle(mass=m_sun, position=r_sun, velocity=v_sun, name="Sun", color="yellow"),
|
||||
Particle(mass=m_earth, position=r_earth, velocity=v_earth, name="Earth", color="blue"),
|
||||
Particle(mass=m_test, position=r_test, velocity=v_test, name="Test", color="gray")
|
||||
]
|
||||
|
||||
return particles
|
||||
|
||||
@staticmethod
|
||||
def create_random_config(masses: Optional[List[float]] = None,
|
||||
position_range: float = 2.0,
|
||||
velocity_scale: float = 1.0) -> List[Particle]:
|
||||
"""
|
||||
创建随机初始条件配置
|
||||
|
||||
参数:
|
||||
masses: 质量列表(如果为None则使用随机质量)
|
||||
position_range: 位置范围(±position_range)
|
||||
velocity_scale: 速度缩放因子
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
if masses is None:
|
||||
# 随机质量(在0.5到2.0之间)
|
||||
masses = np.random.uniform(0.5, 2.0, 3)
|
||||
|
||||
if len(masses) != 3:
|
||||
raise ValueError("需要恰好3个质量值")
|
||||
|
||||
particles = []
|
||||
colors = ['red', 'green', 'blue']
|
||||
names = ['Star A', 'Star B', 'Star C']
|
||||
|
||||
for i in range(3):
|
||||
# 随机位置
|
||||
position = np.random.uniform(-position_range, position_range, 3)
|
||||
|
||||
# 随机速度(确保系统总动量接近零)
|
||||
velocity = np.random.uniform(-velocity_scale, velocity_scale, 3)
|
||||
|
||||
particles.append(
|
||||
Particle(mass=masses[i], position=position, velocity=velocity,
|
||||
name=names[i], color=colors[i])
|
||||
)
|
||||
|
||||
# 调整速度使系统总动量接近零
|
||||
total_momentum = sum(p.mass * p.velocity for p in particles)
|
||||
total_mass = sum(p.mass for p in particles)
|
||||
|
||||
for p in particles:
|
||||
p.velocity -= total_momentum / total_mass
|
||||
|
||||
return particles
|
||||
|
||||
@staticmethod
|
||||
def create_binary_star_config() -> List[Particle]:
|
||||
"""
|
||||
创建双星系统+测试质点配置
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
# 双星质量
|
||||
m1 = 1.0
|
||||
m2 = 0.8
|
||||
|
||||
# 双星位置(在椭圆轨道上)
|
||||
# 半长轴
|
||||
a = 1.0
|
||||
# 偏心率
|
||||
e = 0.3
|
||||
|
||||
# 质心在原点
|
||||
r1 = np.array([-m2/(m1+m2) * a * (1+e), 0.0, 0.0])
|
||||
r2 = np.array([m1/(m1+m2) * a * (1+e), 0.0, 0.0])
|
||||
|
||||
# 计算轨道速度(简化圆形轨道近似)
|
||||
# 对于椭圆轨道,速度更复杂,这里使用简化
|
||||
orbital_speed = np.sqrt(4*np.pi**2 * (m1+m2) / (2*a))
|
||||
v1 = np.array([0.0, orbital_speed * m2/(m1+m2), 0.0])
|
||||
v2 = np.array([0.0, -orbital_speed * m1/(m1+m2), 0.0])
|
||||
|
||||
# 测试质点(小质量)
|
||||
m_test = 0.01
|
||||
r_test = np.array([0.0, 2.0, 0.0])
|
||||
v_test = np.array([0.5, 0.0, 0.0])
|
||||
|
||||
particles = [
|
||||
Particle(mass=m1, position=r1, velocity=v1, name="Primary", color="red"),
|
||||
Particle(mass=m2, position=r2, velocity=v2, name="Secondary", color="green"),
|
||||
Particle(mass=m_test, position=r_test, velocity=v_test, name="Test", color="blue")
|
||||
]
|
||||
|
||||
return particles
|
||||
|
||||
@staticmethod
|
||||
def create_custom_config(config_dict: Dict[str, Any]) -> List[Particle]:
|
||||
"""
|
||||
从字典创建自定义配置
|
||||
|
||||
参数:
|
||||
config_dict: 包含配置信息的字典
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
particles = []
|
||||
|
||||
for i in range(3):
|
||||
key = f"particle_{i+1}"
|
||||
if key not in config_dict:
|
||||
raise ValueError(f"配置中缺少 {key}")
|
||||
|
||||
p_config = config_dict[key]
|
||||
|
||||
particle = Particle(
|
||||
mass=p_config.get('mass', 1.0),
|
||||
position=np.array(p_config.get('position', [0.0, 0.0, 0.0])),
|
||||
velocity=np.array(p_config.get('velocity', [0.0, 0.0, 0.0])),
|
||||
name=p_config.get('name', f"Particle {i+1}"),
|
||||
color=p_config.get('color', None)
|
||||
)
|
||||
|
||||
particles.append(particle)
|
||||
|
||||
return particles
|
||||
|
||||
@staticmethod
|
||||
def save_config(particles: List[Particle], filename: str):
|
||||
"""
|
||||
保存配置到文件
|
||||
|
||||
参数:
|
||||
particles: 质点列表
|
||||
filename: 文件名
|
||||
"""
|
||||
config_dict = {}
|
||||
|
||||
for i, p in enumerate(particles):
|
||||
config_dict[f"particle_{i+1}"] = {
|
||||
'mass': float(p.mass),
|
||||
'position': p.position.tolist(),
|
||||
'velocity': p.velocity.tolist(),
|
||||
'name': p.name,
|
||||
'color': p.color
|
||||
}
|
||||
|
||||
import json
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
print(f"配置已保存到: {filename}")
|
||||
|
||||
@staticmethod
|
||||
def load_config(filename: str) -> List[Particle]:
|
||||
"""
|
||||
从文件加载配置
|
||||
|
||||
参数:
|
||||
filename: 文件名
|
||||
|
||||
返回:
|
||||
三个质点的列表
|
||||
"""
|
||||
import json
|
||||
with open(filename, 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return ThreeBodyConfig.create_custom_config(config_dict)
|
||||
|
||||
@staticmethod
|
||||
def print_config_summary(particles: List[Particle]):
|
||||
"""打印配置摘要"""
|
||||
print("=" * 60)
|
||||
print("三体问题配置摘要")
|
||||
print("=" * 60)
|
||||
|
||||
total_mass = 0.0
|
||||
total_momentum = np.zeros(3)
|
||||
total_angular_momentum = np.zeros(3)
|
||||
|
||||
for i, p in enumerate(particles):
|
||||
print(f"\n质点 {i+1} ({p.name}):")
|
||||
print(f" 质量: {p.mass:.6f} M_sun")
|
||||
print(f" 位置: [{p.position[0]:.6f}, {p.position[1]:.6f}, {p.position[2]:.6f}] AU")
|
||||
print(f" 速度: [{p.velocity[0]:.6f}, {p.velocity[1]:.6f}, {p.velocity[2]:.6f}] AU/yr")
|
||||
|
||||
total_mass += p.mass
|
||||
total_momentum += p.mass * p.velocity
|
||||
angular_momentum = np.cross(p.position, p.mass * p.velocity)
|
||||
total_angular_momentum += angular_momentum
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("系统总质量: {:.6f} M_sun".format(total_mass))
|
||||
print("系统总动量: [{:.6e}, {:.6e}, {:.6e}]".format(
|
||||
total_momentum[0], total_momentum[1], total_momentum[2]))
|
||||
print("系统总角动量: [{:.6e}, {:.6e}, {:.6e}]".format(
|
||||
total_angular_momentum[0], total_angular_momentum[1], total_angular_momentum[2]))
|
||||
print("=" * 60)
|
||||
299
three_body_problem/demo.py
Normal file
299
three_body_problem/demo.py
Normal file
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
三体问题求解器演示脚本
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
|
||||
def demo_basic_usage():
|
||||
"""演示基本用法"""
|
||||
print("=" * 60)
|
||||
print("三体问题求解器 - 基本用法演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 创建自定义三体系统
|
||||
print("\n1. 创建自定义三体系统...")
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0], name="Star A", color="red"),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0], name="Star B", color="green"),
|
||||
Particle(mass=0.1, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0], name="Star C", color="blue")
|
||||
]
|
||||
|
||||
# 2. 创建求解器
|
||||
print("2. 创建求解器...")
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 3. 模拟5年
|
||||
print("3. 模拟5年运动...")
|
||||
solver.simulate(total_time=5.0, progress_interval=1000)
|
||||
|
||||
# 4. 分析结果
|
||||
print("\n4. 分析模拟结果...")
|
||||
trajectories = solver.get_trajectories()
|
||||
print(f" 轨迹点数: {len(trajectories[0])}")
|
||||
|
||||
com = solver.get_center_of_mass()
|
||||
print(f" 系统质心: [{com[0]:.3f}, {com[1]:.3f}, {com[2]:.3f}] AU")
|
||||
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f" 动量误差: {momentum_error:.2e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.2e}")
|
||||
print(f" 能量相对误差: {energy_error:.2e}")
|
||||
|
||||
# 5. 可视化
|
||||
print("\n5. 生成可视化图形...")
|
||||
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
|
||||
|
||||
# 创建3D轨迹图
|
||||
visualizer.create_3d_plot()
|
||||
visualizer.plot_trajectories(solver, title="自定义三体系统")
|
||||
|
||||
# 保存图形
|
||||
visualizer.save_figure("demo_basic_3d.png", dpi=300)
|
||||
|
||||
# 创建2D投影图
|
||||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
|
||||
fig.savefig("demo_basic_2d.png", dpi=300, bbox_inches='tight')
|
||||
|
||||
print("\n图形已保存:")
|
||||
print(" - demo_basic_3d.png (3D轨迹)")
|
||||
print(" - demo_basic_2d.png (2D投影)")
|
||||
|
||||
# 显示图形
|
||||
visualizer.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def demo_figure8():
|
||||
"""演示8字形轨道"""
|
||||
print("\n" + "=" * 60)
|
||||
print("8字形轨道演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用预置配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 打印配置信息
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟10年
|
||||
print("\n模拟10年8字形轨道...")
|
||||
solver.simulate(total_time=10.0, progress_interval=2000)
|
||||
|
||||
# 可视化
|
||||
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
|
||||
visualizer.create_3d_plot()
|
||||
visualizer.plot_trajectories(solver, title="8字形三体轨道")
|
||||
visualizer.save_figure("demo_figure8.png", dpi=300)
|
||||
|
||||
print("\n图形已保存: demo_figure8.png")
|
||||
visualizer.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def demo_lagrange_points():
|
||||
"""演示拉格朗日点"""
|
||||
print("\n" + "=" * 60)
|
||||
print("拉格朗日点演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建L4点配置
|
||||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
|
||||
# 打印配置信息
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟50年
|
||||
print("\n模拟50年拉格朗日点L4稳定性...")
|
||||
solver.simulate(total_time=50.0, progress_interval=10000)
|
||||
|
||||
# 分析测试质点稳定性
|
||||
test_particle = solver.particles[2]
|
||||
trajectory = test_particle.get_trajectory()
|
||||
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
|
||||
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
|
||||
|
||||
print(f"\n测试质点稳定性分析:")
|
||||
print(f" 初始距离L4点: {distances[0]:.6f} AU")
|
||||
print(f" 最终距离L4点: {distances[-1]:.6f} AU")
|
||||
print(f" 最大距离偏差: {np.max(distances):.6f} AU")
|
||||
print(f" 平均距离偏差: {np.mean(distances):.6f} AU")
|
||||
|
||||
# 可视化
|
||||
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
|
||||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="拉格朗日点L4稳定性")
|
||||
fig.savefig("demo_lagrange_l4.png", dpi=300, bbox_inches='tight')
|
||||
|
||||
print("\n图形已保存: demo_lagrange_l4.png")
|
||||
plt.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def demo_random_systems():
|
||||
"""演示随机系统"""
|
||||
print("\n" + "=" * 60)
|
||||
print("随机三体系统演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建3个不同的随机系统
|
||||
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
||||
|
||||
for i, seed in enumerate([42, 123, 456]):
|
||||
np.random.seed(seed)
|
||||
|
||||
# 创建随机配置
|
||||
particles = ThreeBodyConfig.create_random_config(
|
||||
position_range=1.5,
|
||||
velocity_scale=2.0
|
||||
)
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟5年
|
||||
solver.simulate(total_time=5.0, progress_interval=2000)
|
||||
|
||||
# 绘制轨迹
|
||||
trajectories = solver.get_trajectories()
|
||||
colors = ['red', 'green', 'blue']
|
||||
|
||||
ax = axes[i]
|
||||
for j, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[j % len(colors)]
|
||||
ax.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5)
|
||||
ax.scatter(traj[-1, 0], traj[-1, 1], color=color, s=50, edgecolors='black', linewidth=1)
|
||||
|
||||
ax.set_xlabel('X (AU)', fontsize=10)
|
||||
ax.set_ylabel('Y (AU)', fontsize=10)
|
||||
ax.set_title(f'随机系统 #{i+1} (种子: {seed})', fontsize=12, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_aspect('equal', adjustable='box')
|
||||
|
||||
plt.suptitle('随机三体系统示例', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
plt.savefig("demo_random_systems.png", dpi=300, bbox_inches='tight')
|
||||
|
||||
print("\n图形已保存: demo_random_systems.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
def demo_conservation_laws():
|
||||
"""演示守恒定律"""
|
||||
print("\n" + "=" * 60)
|
||||
print("守恒定律验证演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用8字形轨道(理论上应该守恒)
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 测试不同时间步长下的守恒性
|
||||
time_steps = [0.01, 0.005, 0.001, 0.0005]
|
||||
total_time = 1.0
|
||||
|
||||
results = []
|
||||
|
||||
for dt in time_steps:
|
||||
print(f"\n测试时间步长: {dt}")
|
||||
|
||||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||||
solver.simulate(total_time=total_time, progress_interval=1000)
|
||||
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
|
||||
results.append({
|
||||
'dt': dt,
|
||||
'momentum_error': momentum_error,
|
||||
'angular_momentum_error': angular_momentum_error,
|
||||
'energy_error': energy_error
|
||||
})
|
||||
|
||||
print(f" 动量误差: {momentum_error:.2e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.2e}")
|
||||
print(f" 能量相对误差: {energy_error:.2e}")
|
||||
|
||||
# 绘制误差图
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
|
||||
dts = [r['dt'] for r in results]
|
||||
|
||||
# 动量误差
|
||||
axes[0].loglog(dts, [r['momentum_error'] for r in results], 'o-', linewidth=2, markersize=8)
|
||||
axes[0].set_xlabel('时间步长 (年)', fontsize=12)
|
||||
axes[0].set_ylabel('动量误差', fontsize=12)
|
||||
axes[0].set_title('动量守恒', fontsize=14, fontweight='bold')
|
||||
axes[0].grid(True, alpha=0.3, which='both')
|
||||
|
||||
# 角动量误差
|
||||
axes[1].loglog(dts, [r['angular_momentum_error'] for r in results], 's-', linewidth=2, markersize=8, color='green')
|
||||
axes[1].set_xlabel('时间步长 (年)', fontsize=12)
|
||||
axes[1].set_ylabel('角动量误差', fontsize=12)
|
||||
axes[1].set_title('角动量守恒', fontsize=14, fontweight='bold')
|
||||
axes[1].grid(True, alpha=0.3, which='both')
|
||||
|
||||
# 能量误差
|
||||
axes[2].loglog(dts, [r['energy_error'] for r in results], '^-', linewidth=2, markersize=8, color='red')
|
||||
axes[2].set_xlabel('时间步长 (年)', fontsize=12)
|
||||
axes[2].set_ylabel('能量相对误差', fontsize=12)
|
||||
axes[2].set_title('能量守恒', fontsize=14, fontweight='bold')
|
||||
axes[2].grid(True, alpha=0.3, which='both')
|
||||
|
||||
# 添加参考线(四阶精度)
|
||||
ref_dt = np.array(dts)
|
||||
ref_error = 1e-4 * (ref_dt / 0.001)**4
|
||||
axes[2].loglog(ref_dt, ref_error, 'k--', alpha=0.7, label='四阶精度参考线')
|
||||
axes[2].legend()
|
||||
|
||||
plt.suptitle('守恒定律数值验证 (RK4方法)', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
plt.savefig("demo_conservation_laws.png", dpi=300, bbox_inches='tight')
|
||||
|
||||
print("\n图形已保存: demo_conservation_laws.png")
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("三体问题求解器演示")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行所有演示
|
||||
try:
|
||||
# 演示1: 基本用法
|
||||
solver1 = demo_basic_usage()
|
||||
|
||||
# 演示2: 8字形轨道
|
||||
solver2 = demo_figure8()
|
||||
|
||||
# 演示3: 拉格朗日点
|
||||
solver3 = demo_lagrange_points()
|
||||
|
||||
# 演示4: 随机系统
|
||||
demo_random_systems()
|
||||
|
||||
# 演示5: 守恒定律
|
||||
demo_conservation_laws()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有演示完成!")
|
||||
print("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
16
three_body_problem/examples/__init__.py
Normal file
16
three_body_problem/examples/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
三体问题示例模块
|
||||
"""
|
||||
|
||||
from .figure8 import run_figure8_example, analyze_figure8_stability
|
||||
from .lagrange import run_lagrange_example, compare_lagrange_points
|
||||
from .random import run_random_example, run_multiple_random_simulations
|
||||
|
||||
__all__ = [
|
||||
"run_figure8_example",
|
||||
"analyze_figure8_stability",
|
||||
"run_lagrange_example",
|
||||
"compare_lagrange_points",
|
||||
"run_random_example",
|
||||
"run_multiple_random_simulations"
|
||||
]
|
||||
203
three_body_problem/examples/figure8.py
Normal file
203
three_body_problem/examples/figure8.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
8字形轨道示例 - 著名的稳定三体轨道
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
|
||||
def run_figure8_example():
|
||||
"""运行8字形轨道示例"""
|
||||
print("=" * 60)
|
||||
print("8字形轨道示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建8字形轨道配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 打印配置摘要
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
dt = 0.001 # 时间步长(年)
|
||||
solver = ThreeBodySolver(particles, dt=dt)
|
||||
|
||||
# 模拟10年
|
||||
total_time = 10.0
|
||||
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}年")
|
||||
|
||||
solver.simulate(total_time=total_time, progress_interval=2000)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f"\n守恒定律误差:")
|
||||
print(f" 动量误差: {momentum_error:.6e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.6e}")
|
||||
print(f" 能量相对误差: {energy_error:.6e}")
|
||||
|
||||
# 可视化
|
||||
print("\n生成可视化图形...")
|
||||
visualizer = ThreeBodyVisualizer(figsize=(14, 10))
|
||||
|
||||
# 创建3D轨迹图
|
||||
plt.figure(figsize=(14, 10))
|
||||
|
||||
# 3D轨迹
|
||||
ax1 = plt.subplot(2, 2, 1, projection='3d')
|
||||
trajectories = solver.get_trajectories()
|
||||
colors = ['red', 'green', 'blue']
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
ax1.plot(traj[:, 0], traj[:, 1], traj[:, 2],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax1.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心
|
||||
com = solver.get_center_of_mass()
|
||||
ax1.scatter(com[0], com[1], com[2],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
ax1.set_xlabel('X (AU)')
|
||||
ax1.set_ylabel('Y (AU)')
|
||||
ax1.set_zlabel('Z (AU)')
|
||||
ax1.set_title('8字形轨道 - 3D视图')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# XY平面投影
|
||||
ax2 = plt.subplot(2, 2, 2)
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax2.scatter(traj[-1, 0], traj[-1, 1], color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
ax2.scatter(com[0], com[1], color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
ax2.set_xlabel('X (AU)')
|
||||
ax2.set_ylabel('Y (AU)')
|
||||
ax2.set_title('XY平面投影')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_aspect('equal', adjustable='box')
|
||||
|
||||
# XZ平面投影
|
||||
ax3 = plt.subplot(2, 2, 3)
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
ax3.plot(traj[:, 0], traj[:, 2], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax3.scatter(traj[-1, 0], traj[-1, 2], color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
ax3.scatter(com[0], com[2], color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
ax3.set_xlabel('X (AU)')
|
||||
ax3.set_ylabel('Z (AU)')
|
||||
ax3.set_title('XZ平面投影')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
ax3.set_aspect('equal', adjustable='box')
|
||||
|
||||
# YZ平面投影
|
||||
ax4 = plt.subplot(2, 2, 4)
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
ax4.plot(traj[:, 1], traj[:, 2], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax4.scatter(traj[-1, 1], traj[-1, 2], color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
ax4.scatter(com[1], com[2], color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
ax4.set_xlabel('Y (AU)')
|
||||
ax4.set_ylabel('Z (AU)')
|
||||
ax4.set_title('YZ平面投影')
|
||||
ax4.legend()
|
||||
ax4.grid(True, alpha=0.3)
|
||||
ax4.set_aspect('equal', adjustable='box')
|
||||
|
||||
plt.suptitle('8字形三体轨道', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图形
|
||||
output_file = "figure8_orbit.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n图形已保存到: {output_file}")
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def analyze_figure8_stability():
|
||||
"""分析8字形轨道的稳定性"""
|
||||
print("\n" + "=" * 60)
|
||||
print("8字形轨道稳定性分析")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建8字形轨道配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 测试不同时间步长
|
||||
time_steps = [0.01, 0.005, 0.001, 0.0005]
|
||||
total_time = 5.0
|
||||
|
||||
results = []
|
||||
|
||||
for dt in time_steps:
|
||||
print(f"\n测试时间步长: {dt}")
|
||||
|
||||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||||
solver.simulate(total_time=total_time, progress_interval=10000)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
|
||||
results.append({
|
||||
'dt': dt,
|
||||
'momentum_error': momentum_error,
|
||||
'angular_momentum_error': angular_momentum_error,
|
||||
'energy_error': energy_error
|
||||
})
|
||||
|
||||
print(f" 能量相对误差: {energy_error:.6e}")
|
||||
|
||||
# 绘制误差随步长变化
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
dts = [r['dt'] for r in results]
|
||||
energy_errors = [r['energy_error'] for r in results]
|
||||
|
||||
plt.loglog(dts, energy_errors, 'o-', linewidth=2, markersize=8)
|
||||
plt.xlabel('时间步长 (年)', fontsize=12)
|
||||
plt.ylabel('能量相对误差', fontsize=12)
|
||||
plt.title('8字形轨道数值误差分析', fontsize=14, fontweight='bold')
|
||||
plt.grid(True, alpha=0.3, which='both')
|
||||
|
||||
# 添加参考线(四阶精度)
|
||||
ref_dt = np.array(dts)
|
||||
ref_error = 1e-4 * (ref_dt / 0.001)**4
|
||||
plt.loglog(ref_dt, ref_error, 'r--', alpha=0.7, label='四阶精度参考线')
|
||||
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
|
||||
output_file = "figure8_stability_analysis.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n稳定性分析图形已保存到: {output_file}")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行主示例
|
||||
solver = run_figure8_example()
|
||||
|
||||
# 运行稳定性分析(可选)
|
||||
# analyze_figure8_stability()
|
||||
343
three_body_problem/examples/lagrange.py
Normal file
343
three_body_problem/examples/lagrange.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
拉格朗日点示例 - 三体问题的稳定点
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
|
||||
def run_lagrange_example(lagrange_point: int = 4, total_time: float = 100.0):
|
||||
"""
|
||||
运行拉格朗日点示例
|
||||
|
||||
参数:
|
||||
lagrange_point: 拉格朗日点编号 (4=L4, 5=L5)
|
||||
total_time: 总模拟时间(年)
|
||||
"""
|
||||
point_name = "L4" if lagrange_point == 4 else "L5"
|
||||
print("=" * 60)
|
||||
print(f"拉格朗日点 {point_name} 示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建拉格朗日点配置
|
||||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=lagrange_point)
|
||||
|
||||
# 打印配置摘要
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
dt = 0.001 # 时间步长(年)
|
||||
solver = ThreeBodySolver(particles, dt=dt)
|
||||
|
||||
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}年")
|
||||
print(f"模拟拉格朗日点 {point_name} 的稳定性")
|
||||
|
||||
solver.simulate(total_time=total_time, progress_interval=20000)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f"\n守恒定律误差:")
|
||||
print(f" 动量误差: {momentum_error:.6e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.6e}")
|
||||
print(f" 能量相对误差: {energy_error:.6e}")
|
||||
|
||||
# 分析测试质点的轨道稳定性
|
||||
test_particle = solver.particles[2] # 测试质点是第三个
|
||||
trajectory = test_particle.get_trajectory()
|
||||
|
||||
# 计算与L4/L5点的距离变化
|
||||
if lagrange_point == 4:
|
||||
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
|
||||
else: # L5
|
||||
lagrange_position = np.array([0.5, -np.sqrt(3)/2, 0.0])
|
||||
|
||||
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
|
||||
time_points = np.arange(len(distances)) * dt
|
||||
|
||||
print(f"\n测试质点稳定性分析:")
|
||||
print(f" 初始距离L{lagrange_point}点: {distances[0]:.6e} AU")
|
||||
print(f" 最终距离L{lagrange_point}点: {distances[-1]:.6e} AU")
|
||||
print(f" 最大距离偏差: {np.max(distances):.6e} AU")
|
||||
print(f" 平均距离偏差: {np.mean(distances):.6e} AU")
|
||||
|
||||
# 可视化
|
||||
print("\n生成可视化图形...")
|
||||
|
||||
# 创建图形
|
||||
fig = plt.figure(figsize=(16, 10))
|
||||
|
||||
# 1. XY平面轨迹
|
||||
ax1 = plt.subplot(2, 3, 1)
|
||||
|
||||
# 绘制所有质点的轨迹
|
||||
trajectories = solver.get_trajectories()
|
||||
colors = ['gold', 'blue', 'gray']
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 只绘制最后一部分轨迹(更清晰)
|
||||
if len(traj) > 1000:
|
||||
traj_to_plot = traj[-1000:]
|
||||
else:
|
||||
traj_to_plot = traj
|
||||
|
||||
ax1.plot(traj_to_plot[:, 0], traj_to_plot[:, 1],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制最终位置
|
||||
ax1.scatter(traj[-1, 0], traj[-1, 1],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5, zorder=5)
|
||||
|
||||
# 绘制拉格朗日点位置
|
||||
ax1.scatter(lagrange_position[0], lagrange_position[1],
|
||||
color='red', marker='*', s=300, label=f'L{lagrange_point}点', zorder=10)
|
||||
|
||||
# 绘制等边三角形
|
||||
triangle_points = [
|
||||
[0, 0], # 太阳
|
||||
[1, 0], # 地球
|
||||
lagrange_position[:2] # L4或L5点
|
||||
]
|
||||
triangle_points.append(triangle_points[0]) # 闭合三角形
|
||||
triangle_points = np.array(triangle_points)
|
||||
ax1.plot(triangle_points[:, 0], triangle_points[:, 1],
|
||||
'k--', alpha=0.5, linewidth=1, label='等边三角形')
|
||||
|
||||
ax1.set_xlabel('X (AU)', fontsize=12)
|
||||
ax1.set_ylabel('Y (AU)', fontsize=12)
|
||||
ax1.set_title(f'拉格朗日点 {point_name} - XY平面', fontsize=14, fontweight='bold')
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
ax1.set_aspect('equal', adjustable='box')
|
||||
|
||||
# 2. 距离随时间变化
|
||||
ax2 = plt.subplot(2, 3, 2)
|
||||
ax2.plot(time_points, distances, 'b-', linewidth=2, alpha=0.8)
|
||||
ax2.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax2.set_ylabel(f'距离L{lagrange_point}点 (AU)', fontsize=12)
|
||||
ax2.set_title('测试质点轨道稳定性', fontsize=14, fontweight='bold')
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# 添加平均距离线
|
||||
mean_distance = np.mean(distances)
|
||||
ax2.axhline(y=mean_distance, color='r', linestyle='--', alpha=0.7,
|
||||
label=f'平均距离: {mean_distance:.3e}')
|
||||
ax2.legend(fontsize=10)
|
||||
|
||||
# 3. 相空间图 (x vs vx)
|
||||
ax3 = plt.subplot(2, 3, 3)
|
||||
|
||||
# 计算速度(使用位置差分)
|
||||
if len(trajectory) > 1:
|
||||
dt = solver.dt
|
||||
velocities = np.gradient(trajectory, dt, axis=0)
|
||||
x_positions = trajectory[:, 0]
|
||||
x_velocities = velocities[:, 0]
|
||||
|
||||
# 使用颜色表示时间
|
||||
scatter = ax3.scatter(x_positions, x_velocities, c=time_points,
|
||||
cmap='viridis', alpha=0.7, s=20)
|
||||
plt.colorbar(scatter, ax=ax3, label='时间 (年)')
|
||||
|
||||
ax3.set_xlabel('X 位置 (AU)', fontsize=12)
|
||||
ax3.set_ylabel('X 速度 (AU/年)', fontsize=12)
|
||||
ax3.set_title('测试质点相空间 (X维度)', fontsize=14, fontweight='bold')
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# 4. 相对位置图(以地球为参考系)
|
||||
ax4 = plt.subplot(2, 3, 4)
|
||||
|
||||
# 计算相对于地球的位置
|
||||
earth_trajectory = trajectories[1] # 地球是第二个质点
|
||||
sun_trajectory = trajectories[0] # 太阳是第一个质点
|
||||
test_trajectory = trajectories[2] # 测试质点是第三个
|
||||
|
||||
# 转换为以地球为中心的坐标系
|
||||
earth_centered_sun = sun_trajectory - earth_trajectory
|
||||
earth_centered_test = test_trajectory - earth_trajectory
|
||||
|
||||
# 只绘制最后一部分
|
||||
if len(earth_centered_test) > 1000:
|
||||
earth_centered_test = earth_centered_test[-1000:]
|
||||
|
||||
ax4.plot(earth_centered_test[:, 0], earth_centered_test[:, 1],
|
||||
'gray', alpha=0.7, linewidth=1.5, label='测试质点')
|
||||
ax4.scatter(0, 0, color='blue', s=200, label='地球', edgecolors='black', linewidth=1.5)
|
||||
ax4.scatter(earth_centered_sun[-1, 0], earth_centered_sun[-1, 1],
|
||||
color='gold', s=200, label='太阳', edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制理论L4/L5点位置
|
||||
if lagrange_point == 4:
|
||||
l_point_relative = np.array([-0.5, np.sqrt(3)/2])
|
||||
else: # L5
|
||||
l_point_relative = np.array([-0.5, -np.sqrt(3)/2])
|
||||
|
||||
ax4.scatter(l_point_relative[0], l_point_relative[1],
|
||||
color='red', marker='*', s=300, label=f'L{lagrange_point}点', zorder=10)
|
||||
|
||||
ax4.set_xlabel('相对X位置 (AU)', fontsize=12)
|
||||
ax4.set_ylabel('相对Y位置 (AU)', fontsize=12)
|
||||
ax4.set_title('以地球为参考系', fontsize=14, fontweight='bold')
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
ax4.set_aspect('equal', adjustable='box')
|
||||
|
||||
# 5. 能量随时间变化(简化)
|
||||
ax5 = plt.subplot(2, 3, 5)
|
||||
|
||||
# 计算相对能量变化(简化)
|
||||
# 在实际实现中,需要记录能量历史
|
||||
time_array = np.linspace(0, total_time, len(distances))
|
||||
# 使用距离变化作为能量变化的代理
|
||||
energy_proxy = distances / distances[0]
|
||||
|
||||
ax5.plot(time_array, energy_proxy, 'g-', linewidth=2, alpha=0.8)
|
||||
ax5.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax5.set_ylabel('相对能量变化', fontsize=12)
|
||||
ax5.set_title('轨道能量变化', fontsize=14, fontweight='bold')
|
||||
ax5.grid(True, alpha=0.3)
|
||||
ax5.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='初始能量')
|
||||
ax5.legend(fontsize=10)
|
||||
|
||||
# 6. 3D视图
|
||||
ax6 = plt.subplot(2, 3, 6, projection='3d')
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 只绘制最后一部分轨迹
|
||||
if len(traj) > 1000:
|
||||
traj_to_plot = traj[-1000:]
|
||||
else:
|
||||
traj_to_plot = traj
|
||||
|
||||
ax6.plot(traj_to_plot[:, 0], traj_to_plot[:, 1], traj_to_plot[:, 2],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制最终位置
|
||||
ax6.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5, zorder=5)
|
||||
|
||||
ax6.scatter(lagrange_position[0], lagrange_position[1], lagrange_position[2],
|
||||
color='red', marker='*', s=300, label=f'L{lagrange_point}点', zorder=10)
|
||||
|
||||
ax6.set_xlabel('X (AU)', fontsize=10)
|
||||
ax6.set_ylabel('Y (AU)', fontsize=10)
|
||||
ax6.set_zlabel('Z (AU)', fontsize=10)
|
||||
ax6.set_title('3D视图', fontsize=14, fontweight='bold')
|
||||
ax6.legend(fontsize=9, loc='upper left')
|
||||
ax6.grid(True, alpha=0.3)
|
||||
|
||||
plt.suptitle(f'拉格朗日点 {point_name} 稳定性分析', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图形
|
||||
output_file = f"lagrange_point_{point_name}.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n图形已保存到: {output_file}")
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def compare_lagrange_points():
|
||||
"""比较L4和L5点的稳定性"""
|
||||
print("\n" + "=" * 60)
|
||||
print("拉格朗日点L4和L5稳定性比较")
|
||||
print("=" * 60)
|
||||
|
||||
total_time = 50.0
|
||||
dt = 0.001
|
||||
|
||||
results = []
|
||||
|
||||
for lagrange_point in [4, 5]:
|
||||
point_name = f"L{lagrange_point}"
|
||||
print(f"\n模拟 {point_name} 点...")
|
||||
|
||||
particles = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=lagrange_point)
|
||||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||||
solver.simulate(total_time=total_time, progress_interval=25000)
|
||||
|
||||
# 分析测试质点的轨道稳定性
|
||||
test_particle = solver.particles[2]
|
||||
trajectory = test_particle.get_trajectory()
|
||||
|
||||
if lagrange_point == 4:
|
||||
lagrange_position = np.array([0.5, np.sqrt(3)/2, 0.0])
|
||||
else: # L5
|
||||
lagrange_position = np.array([0.5, -np.sqrt(3)/2, 0.0])
|
||||
|
||||
distances = np.linalg.norm(trajectory - lagrange_position, axis=1)
|
||||
|
||||
results.append({
|
||||
'point': point_name,
|
||||
'max_distance': np.max(distances),
|
||||
'mean_distance': np.mean(distances),
|
||||
'std_distance': np.std(distances),
|
||||
'final_distance': distances[-1]
|
||||
})
|
||||
|
||||
print(f" {point_name} 最大距离偏差: {np.max(distances):.6e} AU")
|
||||
print(f" {point_name} 平均距离偏差: {np.mean(distances):.6e} AU")
|
||||
|
||||
# 绘制比较图
|
||||
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
|
||||
|
||||
points = [r['point'] for r in results]
|
||||
max_distances = [r['max_distance'] for r in results]
|
||||
mean_distances = [r['mean_distance'] for r in results]
|
||||
|
||||
x = np.arange(len(points))
|
||||
width = 0.35
|
||||
|
||||
axes[0].bar(x - width/2, max_distances, width, label='最大偏差', color='lightcoral')
|
||||
axes[0].bar(x + width/2, mean_distances, width, label='平均偏差', color='lightblue')
|
||||
axes[0].set_xlabel('拉格朗日点', fontsize=12)
|
||||
axes[0].set_ylabel('距离偏差 (AU)', fontsize=12)
|
||||
axes[0].set_title('L4和L5点稳定性比较', fontsize=14, fontweight='bold')
|
||||
axes[0].set_xticks(x)
|
||||
axes[0].set_xticklabels(points)
|
||||
axes[0].legend()
|
||||
axes[0].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 最终位置偏差
|
||||
final_distances = [r['final_distance'] for r in results]
|
||||
axes[1].bar(points, final_distances, color=['lightgreen', 'lightblue'])
|
||||
axes[1].set_xlabel('拉格朗日点', fontsize=12)
|
||||
axes[1].set_ylabel('最终距离偏差 (AU)', fontsize=12)
|
||||
axes[1].set_title('最终位置稳定性', fontsize=14, fontweight='bold')
|
||||
axes[1].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
output_file = "lagrange_points_comparison.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n比较图形已保存到: {output_file}")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行L4点示例
|
||||
print("运行拉格朗日点L4示例...")
|
||||
solver_l4 = run_lagrange_example(lagrange_point=4, total_time=50.0)
|
||||
|
||||
# 运行L5点示例
|
||||
print("\n" + "="*60)
|
||||
print("运行拉格朗日点L5示例...")
|
||||
solver_l5 = run_lagrange_example(lagrange_point=5, total_time=50.0)
|
||||
|
||||
# 比较L4和L5
|
||||
compare_lagrange_points()
|
||||
426
three_body_problem/examples/random.py
Normal file
426
three_body_problem/examples/random.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
随机初始条件示例 - 探索不同的三体系统
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from three_body_problem import ThreeBodySolver, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
|
||||
def run_random_example(seed: int = 42, total_time: float = 20.0):
|
||||
"""
|
||||
运行随机初始条件示例
|
||||
|
||||
参数:
|
||||
seed: 随机种子
|
||||
total_time: 总模拟时间(年)
|
||||
"""
|
||||
np.random.seed(seed)
|
||||
|
||||
print("=" * 60)
|
||||
print(f"随机初始条件示例 (种子: {seed})")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建随机配置
|
||||
particles = ThreeBodyConfig.create_random_config(
|
||||
masses=None, # 使用随机质量
|
||||
position_range=2.0, # 位置范围 ±2 AU
|
||||
velocity_scale=3.0 # 速度缩放因子
|
||||
)
|
||||
|
||||
# 打印配置摘要
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
dt = 0.001 # 时间步长(年)
|
||||
solver = ThreeBodySolver(particles, dt=dt)
|
||||
|
||||
print(f"\n开始模拟,总时间: {total_time}年,时间步长: {dt}年")
|
||||
|
||||
solver.simulate(total_time=total_time, progress_interval=2000)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f"\n守恒定律误差:")
|
||||
print(f" 动量误差: {momentum_error:.6e}")
|
||||
print(f" 角动量误差: {angular_momentum_error:.6e}")
|
||||
print(f" 能量相对误差: {energy_error:.6e}")
|
||||
|
||||
# 分析系统行为
|
||||
analyze_system_behavior(solver)
|
||||
|
||||
# 可视化
|
||||
print("\n生成可视化图形...")
|
||||
visualize_random_system(solver, seed)
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def analyze_system_behavior(solver: ThreeBodySolver):
|
||||
"""分析三体系统的行为"""
|
||||
print("\n" + "-" * 40)
|
||||
print("系统行为分析")
|
||||
print("-" * 40)
|
||||
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 计算每个质点的运动范围
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
pos_range = np.ptp(traj, axis=0) # 位置范围 (max - min)
|
||||
avg_speed = np.mean(np.linalg.norm(np.gradient(traj, solver.dt, axis=0), axis=1))
|
||||
|
||||
print(f"\n质点 {i+1} ({particle.name}):")
|
||||
print(f" 质量: {particle.mass:.4f} M_sun")
|
||||
print(f" 位置范围: X={pos_range[0]:.3f}, Y={pos_range[1]:.3f}, Z={pos_range[2]:.3f} AU")
|
||||
print(f" 平均速度: {avg_speed:.3f} AU/年")
|
||||
|
||||
# 计算质点之间的最小距离
|
||||
min_distances = []
|
||||
for i in range(3):
|
||||
for j in range(i+1, 3):
|
||||
traj_i = trajectories[i]
|
||||
traj_j = trajectories[j]
|
||||
distances = np.linalg.norm(traj_i - traj_j, axis=1)
|
||||
min_dist = np.min(distances)
|
||||
min_distances.append((i, j, min_dist))
|
||||
|
||||
print(f"\n质点间最小距离:")
|
||||
for i, j, min_dist in min_distances:
|
||||
print(f" 质点{i+1}-质点{j+1}: {min_dist:.4f} AU")
|
||||
|
||||
# 检查是否有碰撞或近距离接近
|
||||
collision_threshold = 0.1 # AU
|
||||
close_encounters = [(i, j, d) for i, j, d in min_distances if d < collision_threshold]
|
||||
|
||||
if close_encounters:
|
||||
print(f"\n警告: 检测到近距离接近!")
|
||||
for i, j, d in close_encounters:
|
||||
print(f" 质点{i+1}和质点{j+1}的最小距离: {d:.4f} AU < {collision_threshold} AU")
|
||||
else:
|
||||
print(f"\n系统稳定: 所有质点间距离都大于 {collision_threshold} AU")
|
||||
|
||||
# 计算系统质心运动
|
||||
com_trajectory = []
|
||||
for t in range(len(trajectories[0])):
|
||||
com = np.zeros(3)
|
||||
total_mass = 0.0
|
||||
for i, traj in enumerate(trajectories):
|
||||
com += solver.particles[i].mass * traj[t]
|
||||
total_mass += solver.particles[i].mass
|
||||
com_trajectory.append(com / total_mass)
|
||||
|
||||
com_trajectory = np.array(com_trajectory)
|
||||
com_range = np.ptp(com_trajectory, axis=0)
|
||||
print(f"\n系统质心运动范围: X={com_range[0]:.4f}, Y={com_range[1]:.4f}, Z={com_range[2]:.4f} AU")
|
||||
|
||||
|
||||
def visualize_random_system(solver: ThreeBodySolver, seed: int):
|
||||
"""可视化随机三体系统"""
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 创建图形
|
||||
fig = plt.figure(figsize=(16, 12))
|
||||
|
||||
# 1. 3D轨迹图
|
||||
ax1 = plt.subplot(2, 3, 1, projection='3d')
|
||||
|
||||
colors = ['red', 'green', 'blue']
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
ax1.plot(traj[:, 0], traj[:, 1], traj[:, 2],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax1.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心轨迹
|
||||
com_trajectory = []
|
||||
for t in range(len(trajectories[0])):
|
||||
com = np.zeros(3)
|
||||
total_mass = 0.0
|
||||
for i, traj in enumerate(trajectories):
|
||||
com += solver.particles[i].mass * traj[t]
|
||||
total_mass += solver.particles[i].mass
|
||||
com_trajectory.append(com / total_mass)
|
||||
|
||||
com_trajectory = np.array(com_trajectory)
|
||||
ax1.plot(com_trajectory[:, 0], com_trajectory[:, 1], com_trajectory[:, 2],
|
||||
'k--', alpha=0.5, linewidth=1, label='质心轨迹')
|
||||
ax1.scatter(com_trajectory[-1, 0], com_trajectory[-1, 1], com_trajectory[-1, 2],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
ax1.set_xlabel('X (AU)', fontsize=12)
|
||||
ax1.set_ylabel('Y (AU)', fontsize=12)
|
||||
ax1.set_zlabel('Z (AU)', fontsize=12)
|
||||
ax1.set_title('3D轨迹图', fontsize=14, fontweight='bold')
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# 2. XY平面投影
|
||||
ax2 = plt.subplot(2, 3, 2)
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
ax2.plot(traj[:, 0], traj[:, 1], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
ax2.scatter(traj[-1, 0], traj[-1, 1], color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
ax2.plot(com_trajectory[:, 0], com_trajectory[:, 1], 'k--', alpha=0.5, linewidth=1)
|
||||
ax2.scatter(com_trajectory[-1, 0], com_trajectory[-1, 1],
|
||||
color='black', marker='x', s=200, linewidth=2)
|
||||
|
||||
ax2.set_xlabel('X (AU)', fontsize=12)
|
||||
ax2.set_ylabel('Y (AU)', fontsize=12)
|
||||
ax2.set_title('XY平面投影', fontsize=14, fontweight='bold')
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_aspect('equal', adjustable='box')
|
||||
|
||||
# 3. 距离随时间变化
|
||||
ax3 = plt.subplot(2, 3, 3)
|
||||
|
||||
time_points = np.arange(len(trajectories[0])) * solver.dt
|
||||
|
||||
# 计算所有质点对之间的距离
|
||||
distances = []
|
||||
labels = []
|
||||
for i in range(3):
|
||||
for j in range(i+1, 3):
|
||||
dist = np.linalg.norm(trajectories[i] - trajectories[j], axis=1)
|
||||
distances.append(dist)
|
||||
labels.append(f"质点{i+1}-质点{j+1}")
|
||||
|
||||
for dist, label in zip(distances, labels):
|
||||
ax3.plot(time_points, dist, linewidth=1.5, alpha=0.8, label=label)
|
||||
|
||||
ax3.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax3.set_ylabel('距离 (AU)', fontsize=12)
|
||||
ax3.set_title('质点间距离变化', fontsize=14, fontweight='bold')
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# 4. 速度大小随时间变化
|
||||
ax4 = plt.subplot(2, 3, 4)
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 计算速度大小(使用位置差分)
|
||||
if len(traj) > 1:
|
||||
velocities = np.gradient(traj, solver.dt, axis=0)
|
||||
speed = np.linalg.norm(velocities, axis=1)
|
||||
ax4.plot(time_points, speed, color=color, linewidth=1.5, alpha=0.8, label=label)
|
||||
|
||||
ax4.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax4.set_ylabel('速度大小 (AU/年)', fontsize=12)
|
||||
ax4.set_title('质点速度变化', fontsize=14, fontweight='bold')
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3)
|
||||
|
||||
# 5. 能量分布饼图
|
||||
ax5 = plt.subplot(2, 3, 5)
|
||||
|
||||
# 计算每个质点的动能和势能
|
||||
kinetic_energies = []
|
||||
potential_energies = []
|
||||
|
||||
for i, particle in enumerate(solver.particles):
|
||||
# 动能
|
||||
v_squared = np.sum(particle.velocity**2)
|
||||
kinetic_energy = 0.5 * particle.mass * v_squared
|
||||
kinetic_energies.append(kinetic_energy)
|
||||
|
||||
# 势能(与其他质点的相互作用)
|
||||
potential_energy = 0.0
|
||||
for j, other in enumerate(solver.particles):
|
||||
if i != j:
|
||||
r_vec = other.position - particle.position
|
||||
r = np.linalg.norm(r_vec)
|
||||
if r > 1e-10:
|
||||
potential_energy -= ThreeBodySolver.G * particle.mass * other.mass / r
|
||||
potential_energies.append(potential_energy)
|
||||
|
||||
# 只考虑势能的一半(每对质点计算了两次)
|
||||
potential_energies = [pe/2 for pe in potential_energies]
|
||||
|
||||
labels = [f"质点{i+1}" for i in range(3)]
|
||||
colors_pie = ['lightcoral', 'lightgreen', 'lightblue']
|
||||
|
||||
ax5.pie(kinetic_energies, labels=labels, autopct='%1.1f%%',
|
||||
colors=colors_pie, startangle=90)
|
||||
ax5.set_title('动能分布', fontsize=14, fontweight='bold')
|
||||
|
||||
# 6. 相空间图(所有质点)
|
||||
ax6 = plt.subplot(2, 3, 6)
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
if len(traj) > 1:
|
||||
velocities = np.gradient(traj, solver.dt, axis=0)
|
||||
x_positions = traj[:, 0]
|
||||
x_velocities = velocities[:, 0]
|
||||
|
||||
# 使用颜色表示时间
|
||||
scatter = ax6.scatter(x_positions, x_velocities, c=time_points,
|
||||
cmap='viridis', alpha=0.6, s=10, label=label)
|
||||
|
||||
plt.colorbar(scatter, ax=ax6, label='时间 (年)')
|
||||
ax6.set_xlabel('X 位置 (AU)', fontsize=12)
|
||||
ax6.set_ylabel('X 速度 (AU/年)', fontsize=12)
|
||||
ax6.set_title('相空间图 (X维度)', fontsize=14, fontweight='bold')
|
||||
ax6.legend(fontsize=10)
|
||||
ax6.grid(True, alpha=0.3)
|
||||
|
||||
plt.suptitle(f'随机三体系统 (种子: {seed})', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图形
|
||||
output_file = f"random_system_seed_{seed}.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n图形已保存到: {output_file}")
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
|
||||
|
||||
def run_multiple_random_simulations(n_simulations: int = 5, total_time: float = 10.0):
|
||||
"""运行多个随机模拟并比较结果"""
|
||||
print("=" * 60)
|
||||
print(f"运行 {n_simulations} 个随机三体系统模拟")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
for sim_idx in range(n_simulations):
|
||||
seed = 100 + sim_idx # 不同的随机种子
|
||||
np.random.seed(seed)
|
||||
|
||||
print(f"\n模拟 {sim_idx+1}/{n_simulations} (种子: {seed})")
|
||||
|
||||
# 创建随机配置
|
||||
particles = ThreeBodyConfig.create_random_config(
|
||||
masses=None,
|
||||
position_range=2.0,
|
||||
velocity_scale=2.0 + np.random.random() * 2.0 # 随机速度缩放
|
||||
)
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
solver.simulate(total_time=total_time, progress_interval=5000)
|
||||
|
||||
# 分析结果
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 计算系统特性
|
||||
final_distances = []
|
||||
for i in range(3):
|
||||
for j in range(i+1, 3):
|
||||
final_dist = np.linalg.norm(trajectories[i][-1] - trajectories[j][-1])
|
||||
final_distances.append(final_dist)
|
||||
|
||||
avg_final_distance = np.mean(final_distances)
|
||||
std_final_distance = np.std(final_distances)
|
||||
|
||||
# 计算质心移动距离
|
||||
initial_com = solver.get_center_of_mass()
|
||||
# 需要重新计算初始质心
|
||||
total_mass = sum(p.mass for p in particles)
|
||||
initial_com = np.zeros(3)
|
||||
for p in particles:
|
||||
initial_com += p.mass * p.position
|
||||
initial_com /= total_mass
|
||||
|
||||
final_com = solver.get_center_of_mass()
|
||||
com_movement = np.linalg.norm(final_com - initial_com)
|
||||
|
||||
results.append({
|
||||
'seed': seed,
|
||||
'avg_final_distance': avg_final_distance,
|
||||
'std_final_distance': std_final_distance,
|
||||
'com_movement': com_movement,
|
||||
'energy_error': solver.get_conservation_errors()[2]
|
||||
})
|
||||
|
||||
print(f" 平均最终距离: {avg_final_distance:.3f} AU")
|
||||
print(f" 质心移动: {com_movement:.3f} AU")
|
||||
print(f" 能量相对误差: {solver.get_conservation_errors()[2]:.6e}")
|
||||
|
||||
# 绘制比较图
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||
|
||||
seeds = [r['seed'] for r in results]
|
||||
avg_distances = [r['avg_final_distance'] for r in results]
|
||||
com_movements = [r['com_movement'] for r in results]
|
||||
energy_errors = [r['energy_error'] for r in results]
|
||||
|
||||
# 平均最终距离
|
||||
axes[0, 0].bar(range(n_simulations), avg_distances, color='skyblue', edgecolor='black')
|
||||
axes[0, 0].set_xlabel('模拟编号', fontsize=12)
|
||||
axes[0, 0].set_ylabel('平均最终距离 (AU)', fontsize=12)
|
||||
axes[0, 0].set_title('质点间平均距离', fontsize=14, fontweight='bold')
|
||||
axes[0, 0].set_xticks(range(n_simulations))
|
||||
axes[0, 0].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
|
||||
axes[0, 0].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 质心移动
|
||||
axes[0, 1].bar(range(n_simulations), com_movements, color='lightgreen', edgecolor='black')
|
||||
axes[0, 1].set_xlabel('模拟编号', fontsize=12)
|
||||
axes[0, 1].set_ylabel('质心移动距离 (AU)', fontsize=12)
|
||||
axes[0, 1].set_title('系统质心移动', fontsize=14, fontweight='bold')
|
||||
axes[0, 1].set_xticks(range(n_simulations))
|
||||
axes[0, 1].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
|
||||
axes[0, 1].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 能量误差
|
||||
axes[1, 0].bar(range(n_simulations), energy_errors, color='lightcoral', edgecolor='black')
|
||||
axes[1, 0].set_xlabel('模拟编号', fontsize=12)
|
||||
axes[1, 0].set_ylabel('能量相对误差', fontsize=12)
|
||||
axes[1, 0].set_title('数值积分误差', fontsize=14, fontweight='bold')
|
||||
axes[1, 0].set_xticks(range(n_simulations))
|
||||
axes[1, 0].set_xticklabels([f"#{i+1}" for i in range(n_simulations)])
|
||||
axes[1, 0].set_yscale('log')
|
||||
axes[1, 0].grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 散点图:质心移动 vs 平均距离
|
||||
axes[1, 1].scatter(avg_distances, com_movements, s=100, c=range(n_simulations),
|
||||
cmap='viridis', edgecolors='black', alpha=0.8)
|
||||
axes[1, 1].set_xlabel('平均最终距离 (AU)', fontsize=12)
|
||||
axes[1, 1].set_ylabel('质心移动距离 (AU)', fontsize=12)
|
||||
axes[1, 1].set_title('系统稳定性关系', fontsize=14, fontweight='bold')
|
||||
|
||||
# 添加标签
|
||||
for i, (x, y) in enumerate(zip(avg_distances, com_movements)):
|
||||
axes[1, 1].annotate(f"#{i+1}", (x, y), textcoords="offset points",
|
||||
xytext=(0, 10), ha='center', fontsize=9)
|
||||
|
||||
axes[1, 1].grid(True, alpha=0.3)
|
||||
|
||||
plt.suptitle(f'{n_simulations}个随机三体系统模拟比较', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
output_file = "multiple_random_simulations.png"
|
||||
plt.savefig(output_file, dpi=300, bbox_inches='tight')
|
||||
print(f"\n比较图形已保存到: {output_file}")
|
||||
plt.show()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行单个随机示例
|
||||
print("运行单个随机三体系统示例...")
|
||||
solver = run_random_example(seed=42, total_time=15.0)
|
||||
|
||||
# 运行多个随机模拟(可选)
|
||||
# print("\n" + "="*60)
|
||||
# print("运行多个随机模拟比较...")
|
||||
# results = run_multiple_random_simulations(n_simulations=5, total_time=5.0)
|
||||
119
three_body_problem/integrator.py
Normal file
119
three_body_problem/integrator.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
数值积分器模块,提供RK4方法求解微分方程
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Callable, Tuple, List
|
||||
from .particle import Particle
|
||||
|
||||
|
||||
class RK4Integrator:
|
||||
"""四阶龙格-库塔法积分器"""
|
||||
|
||||
def __init__(self, dt: float = 0.001):
|
||||
"""
|
||||
初始化积分器
|
||||
|
||||
参数:
|
||||
dt: 时间步长(年)
|
||||
"""
|
||||
self.dt = dt
|
||||
|
||||
def step(self, particles: List[Particle],
|
||||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]]) -> List[Particle]:
|
||||
"""
|
||||
执行一个时间步的积分
|
||||
|
||||
参数:
|
||||
particles: 质点列表
|
||||
acceleration_func: 计算加速度的函数
|
||||
|
||||
返回:
|
||||
更新后的质点列表
|
||||
"""
|
||||
# 保存当前状态
|
||||
current_positions = [p.position.copy() for p in particles]
|
||||
current_velocities = [p.velocity.copy() for p in particles]
|
||||
|
||||
# RK4步骤1: k1
|
||||
k1_v = acceleration_func(particles)
|
||||
k1_r = current_velocities
|
||||
|
||||
# 临时更新位置和速度用于k2计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + 0.5 * self.dt * k1_r[i]
|
||||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k1_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤2: k2
|
||||
k2_v = acceleration_func(temp_particles)
|
||||
k2_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 临时更新位置和速度用于k3计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + 0.5 * self.dt * k2_r[i]
|
||||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k2_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤3: k3
|
||||
k3_v = acceleration_func(temp_particles)
|
||||
k3_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 临时更新位置和速度用于k4计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + self.dt * k3_r[i]
|
||||
temp_p.velocity = current_velocities[i] + self.dt * k3_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤4: k4
|
||||
k4_v = acceleration_func(temp_particles)
|
||||
k4_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 组合所有k值计算最终更新
|
||||
new_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
# 计算新的速度
|
||||
new_velocity = (current_velocities[i] +
|
||||
self.dt / 6.0 * (k1_v[i] + 2*k2_v[i] + 2*k3_v[i] + k4_v[i]))
|
||||
|
||||
# 计算新的位置
|
||||
new_position = (current_positions[i] +
|
||||
self.dt / 6.0 * (k1_r[i] + 2*k2_r[i] + 2*k3_r[i] + k4_r[i]))
|
||||
|
||||
# 创建新的质点对象
|
||||
new_p = p.copy()
|
||||
new_p.update(new_position, new_velocity)
|
||||
new_particles.append(new_p)
|
||||
|
||||
return new_particles
|
||||
|
||||
def integrate(self, particles: List[Particle],
|
||||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]],
|
||||
steps: int) -> List[Particle]:
|
||||
"""
|
||||
执行多步积分
|
||||
|
||||
参数:
|
||||
particles: 初始质点列表
|
||||
acceleration_func: 计算加速度的函数
|
||||
steps: 积分步数
|
||||
|
||||
返回:
|
||||
积分结束后的质点列表
|
||||
"""
|
||||
current_particles = particles
|
||||
|
||||
for step in range(steps):
|
||||
current_particles = self.step(current_particles, acceleration_func)
|
||||
|
||||
# 每1000步打印进度
|
||||
if step % 1000 == 0:
|
||||
print(f"积分进度: {step}/{steps} 步")
|
||||
|
||||
return current_particles
|
||||
64
three_body_problem/particle.py
Normal file
64
three_body_problem/particle.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
质点类,表示三体问题中的一个天体
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
class Particle:
|
||||
"""表示一个质点的类"""
|
||||
|
||||
def __init__(self, mass: float, position: np.ndarray, velocity: np.ndarray,
|
||||
name: str = "", color: str = None):
|
||||
"""
|
||||
初始化质点
|
||||
|
||||
参数:
|
||||
mass: 质量(太阳质量)
|
||||
position: 位置向量 (x, y, z) [AU]
|
||||
velocity: 速度向量 (vx, vy, vz) [AU/年]
|
||||
name: 质点名称
|
||||
color: 可视化颜色
|
||||
"""
|
||||
self.mass = mass
|
||||
self.position = np.array(position, dtype=np.float64)
|
||||
self.velocity = np.array(velocity, dtype=np.float64)
|
||||
self.name = name
|
||||
self.color = color
|
||||
|
||||
# 存储轨迹历史
|
||||
self.position_history = []
|
||||
self.velocity_history = []
|
||||
|
||||
def update(self, new_position: np.ndarray, new_velocity: np.ndarray):
|
||||
"""更新质点的状态并记录历史"""
|
||||
self.position_history.append(self.position.copy())
|
||||
self.velocity_history.append(self.velocity.copy())
|
||||
self.position = new_position
|
||||
self.velocity = new_velocity
|
||||
|
||||
def get_trajectory(self) -> np.ndarray:
|
||||
"""获取轨迹历史"""
|
||||
if not self.position_history:
|
||||
return np.array([self.position])
|
||||
return np.array(self.position_history + [self.position])
|
||||
|
||||
def get_energy(self) -> float:
|
||||
"""计算质点的动能"""
|
||||
v_squared = np.sum(self.velocity**2)
|
||||
return 0.5 * self.mass * v_squared
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"Particle(name='{self.name}', mass={self.mass:.3f}, "
|
||||
f"position={self.position}, velocity={self.velocity})")
|
||||
|
||||
def copy(self) -> 'Particle':
|
||||
"""创建质点的副本"""
|
||||
return Particle(
|
||||
mass=self.mass,
|
||||
position=self.position.copy(),
|
||||
velocity=self.velocity.copy(),
|
||||
name=self.name,
|
||||
color=self.color
|
||||
)
|
||||
6
three_body_problem/requirements.txt
Normal file
6
three_body_problem/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
# 三体问题求解器依赖
|
||||
numpy>=1.21.0
|
||||
matplotlib>=3.5.0
|
||||
|
||||
# 测试依赖(可选)
|
||||
pytest>=7.0.0
|
||||
199
three_body_problem/run_example.py
Normal file
199
three_body_problem/run_example.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
三体问题求解器 - 快速示例
|
||||
运行此文件以查看三体问题的基本模拟
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加当前目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from three_body_problem import ThreeBodySolver, Particle, ThreeBodyConfig, ThreeBodyVisualizer
|
||||
|
||||
|
||||
def simple_example():
|
||||
"""简单示例:自定义三体系统"""
|
||||
print("=" * 60)
|
||||
print("三体问题求解器 - 简单示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建三个质点
|
||||
print("\n1. 创建三个质点...")
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0], name="Star A", color="red"),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0], name="Star B", color="green"),
|
||||
Particle(mass=0.5, position=[0.0, 1.0, 0.0], velocity=[-0.3, 0.0, 0.0], name="Star C", color="blue")
|
||||
]
|
||||
|
||||
for i, p in enumerate(particles):
|
||||
print(f" 质点 {i+1}: {p.name}, 质量: {p.mass:.2f}, 位置: [{p.position[0]:.2f}, {p.position[1]:.2f}, {p.position[2]:.2f}]")
|
||||
|
||||
# 创建求解器
|
||||
print("\n2. 创建求解器...")
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟2年
|
||||
print("\n3. 模拟2年运动...")
|
||||
print(" 开始积分...")
|
||||
solver.simulate(total_time=2.0, progress_interval=500)
|
||||
|
||||
# 分析结果
|
||||
print("\n4. 分析结果...")
|
||||
trajectories = solver.get_trajectories()
|
||||
print(f" 模拟完成! 生成了 {len(trajectories[0])} 个轨迹点")
|
||||
|
||||
com = solver.get_center_of_mass()
|
||||
print(f" 系统质心位置: [{com[0]:.4f}, {com[1]:.4f}, {com[2]:.4f}] AU")
|
||||
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
print(f" 守恒定律误差:")
|
||||
print(f" - 动量误差: {momentum_error:.2e}")
|
||||
print(f" - 角动量误差: {angular_momentum_error:.2e}")
|
||||
print(f" - 能量相对误差: {energy_error:.2e}")
|
||||
|
||||
# 可视化
|
||||
print("\n5. 生成可视化图形...")
|
||||
visualizer = ThreeBodyVisualizer(figsize=(12, 8))
|
||||
|
||||
# 创建3D轨迹图
|
||||
visualizer.create_3d_plot()
|
||||
visualizer.plot_trajectories(solver, title="三体系统运动轨迹")
|
||||
|
||||
# 保存图形
|
||||
output_file = "simple_example_3d.png"
|
||||
visualizer.save_figure(output_file, dpi=200)
|
||||
print(f" 3D轨迹图已保存到: {output_file}")
|
||||
|
||||
# 创建2D投影图
|
||||
fig, ax = visualizer.plot_2d_projection(solver, projection='xy', title="XY平面投影")
|
||||
output_file_2d = "simple_example_2d.png"
|
||||
fig.savefig(output_file_2d, dpi=200, bbox_inches='tight')
|
||||
print(f" 2D投影图已保存到: {output_file_2d}")
|
||||
|
||||
print("\n6. 显示图形...")
|
||||
plt.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def figure8_example():
|
||||
"""8字形轨道示例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("8字形轨道示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用预置的8字形轨道配置
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
# 打印配置信息
|
||||
ThreeBodyConfig.print_config_summary(particles)
|
||||
|
||||
# 创建求解器
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 模拟5年
|
||||
print("\n模拟5年8字形轨道...")
|
||||
solver.simulate(total_time=5.0, progress_interval=1000)
|
||||
|
||||
# 可视化
|
||||
visualizer = ThreeBodyVisualizer(figsize=(10, 8))
|
||||
visualizer.create_3d_plot()
|
||||
visualizer.plot_trajectories(solver, title="8字形三体轨道")
|
||||
|
||||
output_file = "figure8_example.png"
|
||||
visualizer.save_figure(output_file, dpi=200)
|
||||
print(f"\n图形已保存到: {output_file}")
|
||||
|
||||
plt.show()
|
||||
|
||||
return solver
|
||||
|
||||
|
||||
def quick_test():
|
||||
"""快速测试所有模块"""
|
||||
print("=" * 60)
|
||||
print("快速测试所有模块")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试1: 创建质点
|
||||
print("\n1. 测试质点创建...")
|
||||
p1 = Particle(mass=1.0, position=[1, 0, 0], velocity=[0, 1, 0], name="Test Star")
|
||||
print(f" 创建质点: {p1.name}, 质量: {p1.mass}, 位置: {p1.position}")
|
||||
|
||||
# 测试2: 创建配置
|
||||
print("\n2. 测试配置创建...")
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
print(f" 创建了 {len(particles)} 个质点")
|
||||
|
||||
# 测试3: 创建求解器
|
||||
print("\n3. 测试求解器创建...")
|
||||
solver = ThreeBodySolver(particles, dt=0.01)
|
||||
print(f" 求解器创建成功,时间步长: {solver.dt}")
|
||||
|
||||
# 测试4: 单步积分
|
||||
print("\n4. 测试单步积分...")
|
||||
initial_positions = [p.position.copy() for p in particles]
|
||||
solver.step()
|
||||
for i, (old_pos, particle) in enumerate(zip(initial_positions, solver.particles)):
|
||||
moved = not np.allclose(old_pos, particle.position)
|
||||
print(f" 质点 {i+1} 位置变化: {'是' if moved else '否'}")
|
||||
|
||||
# 测试5: 质心计算
|
||||
print("\n5. 测试质心计算...")
|
||||
com = solver.get_center_of_mass()
|
||||
print(f" 系统质心: [{com[0]:.4f}, {com[1]:.4f}, {com[2]:.4f}]")
|
||||
|
||||
# 测试6: 能量计算
|
||||
print("\n6. 测试能量计算...")
|
||||
energy = solver._calculate_total_energy()
|
||||
print(f" 系统总能量: {energy:.6e}")
|
||||
|
||||
print("\n所有测试通过!")
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("三体问题求解器示例")
|
||||
print("=" * 60)
|
||||
print("选择要运行的示例:")
|
||||
print(" 1. 简单示例 (自定义三体系统)")
|
||||
print(" 2. 8字形轨道示例")
|
||||
print(" 3. 快速测试所有模块")
|
||||
print(" 4. 全部运行")
|
||||
|
||||
try:
|
||||
choice = input("\n请输入选择 (1-4): ").strip()
|
||||
|
||||
if choice == "1":
|
||||
simple_example()
|
||||
elif choice == "2":
|
||||
figure8_example()
|
||||
elif choice == "3":
|
||||
quick_test()
|
||||
elif choice == "4":
|
||||
quick_test()
|
||||
simple_example()
|
||||
figure8_example()
|
||||
else:
|
||||
print("无效选择,运行简单示例...")
|
||||
simple_example()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n用户中断")
|
||||
except Exception as e:
|
||||
print(f"\n错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("示例运行完成!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
53
three_body_problem/setup.py
Normal file
53
three_body_problem/setup.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
with open("requirements.txt", "r", encoding="utf-8") as fh:
|
||||
requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
|
||||
|
||||
setup(
|
||||
name="three-body-problem",
|
||||
version="1.0.0",
|
||||
author="ThreeBodyProblem Team",
|
||||
author_email="threebody@example.com",
|
||||
description="A pure Python solver for the three-body problem",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/yourusername/three-body-problem",
|
||||
packages=find_packages(),
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Topic :: Scientific/Engineering",
|
||||
"Topic :: Scientific/Engineering :: Astronomy",
|
||||
"Topic :: Scientific/Engineering :: Physics",
|
||||
],
|
||||
python_requires=">=3.7",
|
||||
install_requires=requirements,
|
||||
extras_require={
|
||||
"dev": [
|
||||
"pytest>=7.0.0",
|
||||
"black>=22.0.0",
|
||||
"flake8>=5.0.0",
|
||||
],
|
||||
"docs": [
|
||||
"sphinx>=5.0.0",
|
||||
"sphinx-rtd-theme>=1.0.0",
|
||||
],
|
||||
},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"three-body-demo=three_body_problem.demo:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
196
three_body_problem/solver.py
Normal file
196
three_body_problem/solver.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
三体问题求解器主模块
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Optional, Callable
|
||||
from .particle import Particle
|
||||
from .integrator import RK4Integrator
|
||||
|
||||
|
||||
class ThreeBodySolver:
|
||||
"""三体问题求解器"""
|
||||
|
||||
# 万有引力常数 (AU^3 / (M_sun * year^2))
|
||||
G = 4 * np.pi**2 # 在天文单位制中
|
||||
|
||||
def __init__(self, particles: List[Particle], dt: float = 0.001):
|
||||
"""
|
||||
初始化三体问题求解器
|
||||
|
||||
参数:
|
||||
particles: 三个质点的列表
|
||||
dt: 时间步长(年)
|
||||
"""
|
||||
if len(particles) != 3:
|
||||
raise ValueError("三体问题需要恰好3个质点")
|
||||
|
||||
self.particles = particles
|
||||
self.dt = dt
|
||||
self.integrator = RK4Integrator(dt)
|
||||
self.time = 0.0
|
||||
|
||||
# 验证总动量和角动量守恒(可选)
|
||||
self.initial_total_momentum = self._calculate_total_momentum()
|
||||
self.initial_angular_momentum = self._calculate_total_angular_momentum()
|
||||
|
||||
def _calculate_accelerations(self, particles: List[Particle]) -> List[np.ndarray]:
|
||||
"""
|
||||
计算所有质点的加速度
|
||||
|
||||
参数:
|
||||
particles: 质点列表
|
||||
|
||||
返回:
|
||||
加速度列表
|
||||
"""
|
||||
accelerations = [np.zeros(3) for _ in range(3)]
|
||||
|
||||
# 计算每对质点之间的引力
|
||||
for i in range(3):
|
||||
for j in range(3):
|
||||
if i != j:
|
||||
# 计算相对位置向量
|
||||
r_vec = particles[j].position - particles[i].position
|
||||
r = np.linalg.norm(r_vec)
|
||||
|
||||
# 避免除以零
|
||||
if r < 1e-10:
|
||||
r = 1e-10
|
||||
|
||||
# 计算加速度 (F = ma -> a = F/m)
|
||||
acceleration = self.G * particles[j].mass * r_vec / (r**3)
|
||||
accelerations[i] += acceleration
|
||||
|
||||
return accelerations
|
||||
|
||||
def _calculate_total_momentum(self) -> np.ndarray:
|
||||
"""计算系统总动量"""
|
||||
total_momentum = np.zeros(3)
|
||||
for p in self.particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
return total_momentum
|
||||
|
||||
def _calculate_total_angular_momentum(self) -> np.ndarray:
|
||||
"""计算系统总角动量"""
|
||||
total_angular_momentum = np.zeros(3)
|
||||
for p in self.particles:
|
||||
# 角动量 L = r × p = r × (m*v)
|
||||
angular_momentum = np.cross(p.position, p.mass * p.velocity)
|
||||
total_angular_momentum += angular_momentum
|
||||
return total_angular_momentum
|
||||
|
||||
def _calculate_total_energy(self) -> float:
|
||||
"""计算系统总能量(动能+势能)"""
|
||||
kinetic_energy = 0.0
|
||||
potential_energy = 0.0
|
||||
|
||||
# 计算动能
|
||||
for p in self.particles:
|
||||
kinetic_energy += 0.5 * p.mass * np.sum(p.velocity**2)
|
||||
|
||||
# 计算势能
|
||||
for i in range(3):
|
||||
for j in range(i+1, 3):
|
||||
r_vec = self.particles[j].position - self.particles[i].position
|
||||
r = np.linalg.norm(r_vec)
|
||||
if r > 1e-10: # 避免除以零
|
||||
potential_energy -= self.G * self.particles[i].mass * self.particles[j].mass / r
|
||||
|
||||
return kinetic_energy + potential_energy
|
||||
|
||||
def step(self) -> List[Particle]:
|
||||
"""
|
||||
执行一个时间步的积分
|
||||
|
||||
返回:
|
||||
更新后的质点列表
|
||||
"""
|
||||
# 计算加速度
|
||||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||||
|
||||
# 执行积分
|
||||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||||
self.time += self.dt
|
||||
|
||||
return self.particles
|
||||
|
||||
def simulate(self, total_time: float, progress_interval: int = 1000) -> List[Particle]:
|
||||
"""
|
||||
模拟三体运动
|
||||
|
||||
参数:
|
||||
total_time: 总模拟时间(年)
|
||||
progress_interval: 进度打印间隔步数
|
||||
|
||||
返回:
|
||||
模拟结束后的质点列表
|
||||
"""
|
||||
steps = int(total_time / self.dt)
|
||||
|
||||
print(f"开始模拟: 总时间={total_time:.2f}年, 步数={steps}, 步长={self.dt:.6f}年")
|
||||
print(f"初始总能量: {self._calculate_total_energy():.6e}")
|
||||
|
||||
acceleration_func = lambda p: self._calculate_accelerations(p)
|
||||
|
||||
for step in range(steps):
|
||||
self.particles = self.integrator.step(self.particles, acceleration_func)
|
||||
self.time += self.dt
|
||||
|
||||
# 打印进度
|
||||
if step % progress_interval == 0:
|
||||
energy = self._calculate_total_energy()
|
||||
print(f"进度: {step}/{steps}步, 时间={self.time:.3f}年, 能量={energy:.6e}")
|
||||
|
||||
final_energy = self._calculate_total_energy()
|
||||
print(f"模拟完成: 最终时间={self.time:.2f}年, 最终能量={final_energy:.6e}")
|
||||
|
||||
return self.particles
|
||||
|
||||
def get_trajectories(self) -> List[np.ndarray]:
|
||||
"""获取所有质点的轨迹"""
|
||||
return [p.get_trajectory() for p in self.particles]
|
||||
|
||||
def get_center_of_mass(self) -> np.ndarray:
|
||||
"""计算系统质心位置"""
|
||||
total_mass = sum(p.mass for p in self.particles)
|
||||
com = np.zeros(3)
|
||||
for p in self.particles:
|
||||
com += p.mass * p.position
|
||||
return com / total_mass if total_mass > 0 else np.zeros(3)
|
||||
|
||||
def get_conservation_errors(self) -> Tuple[float, float, float]:
|
||||
"""
|
||||
计算守恒定律的误差
|
||||
|
||||
返回:
|
||||
(动量误差, 角动量误差, 能量误差)
|
||||
"""
|
||||
# 动量误差
|
||||
current_momentum = self._calculate_total_momentum()
|
||||
momentum_error = np.linalg.norm(current_momentum - self.initial_total_momentum)
|
||||
|
||||
# 角动量误差
|
||||
current_angular_momentum = self._calculate_total_angular_momentum()
|
||||
angular_momentum_error = np.linalg.norm(
|
||||
current_angular_momentum - self.initial_angular_momentum
|
||||
)
|
||||
|
||||
# 能量误差(相对误差)
|
||||
initial_energy = self._calculate_total_energy() # 重新计算初始能量
|
||||
current_energy = self._calculate_total_energy()
|
||||
if abs(initial_energy) > 1e-10:
|
||||
energy_error = abs((current_energy - initial_energy) / initial_energy)
|
||||
else:
|
||||
energy_error = abs(current_energy - initial_energy)
|
||||
|
||||
return momentum_error, angular_momentum_error, energy_error
|
||||
|
||||
def reset(self):
|
||||
"""重置求解器状态(清除历史记录)"""
|
||||
for p in self.particles:
|
||||
p.position_history.clear()
|
||||
p.velocity_history.clear()
|
||||
self.time = 0.0
|
||||
self.initial_total_momentum = self._calculate_total_momentum()
|
||||
self.initial_angular_momentum = self._calculate_total_angular_momentum()
|
||||
3
three_body_problem/tests/__init__.py
Normal file
3
three_body_problem/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
三体问题测试模块
|
||||
"""
|
||||
385
three_body_problem/tests/test_solver.py
Normal file
385
three_body_problem/tests/test_solver.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
三体问题求解器测试
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig
|
||||
|
||||
|
||||
class TestParticle:
|
||||
"""测试质点类"""
|
||||
|
||||
def test_particle_creation(self):
|
||||
"""测试质点创建"""
|
||||
p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test")
|
||||
|
||||
assert p.mass == 1.0
|
||||
assert np.allclose(p.position, [1.0, 2.0, 3.0])
|
||||
assert np.allclose(p.velocity, [0.1, 0.2, 0.3])
|
||||
assert p.name == "Test"
|
||||
assert len(p.position_history) == 0
|
||||
assert len(p.velocity_history) == 0
|
||||
|
||||
def test_particle_update(self):
|
||||
"""测试质点状态更新"""
|
||||
p = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
|
||||
new_position = np.array([1.0, 2.0, 3.0])
|
||||
new_velocity = np.array([0.1, 0.2, 0.3])
|
||||
|
||||
p.update(new_position, new_velocity)
|
||||
|
||||
assert np.allclose(p.position, new_position)
|
||||
assert np.allclose(p.velocity, new_velocity)
|
||||
assert len(p.position_history) == 1
|
||||
assert len(p.velocity_history) == 1
|
||||
assert np.allclose(p.position_history[0], [0.0, 0.0, 0.0])
|
||||
assert np.allclose(p.velocity_history[0], [0.0, 0.0, 0.0])
|
||||
|
||||
def test_particle_energy(self):
|
||||
"""测试质点动能计算"""
|
||||
p = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0])
|
||||
|
||||
# 动能 = 0.5 * m * v^2
|
||||
expected_energy = 0.5 * 2.0 * (1.0**2 + 2.0**2 + 3.0**2)
|
||||
calculated_energy = p.get_energy()
|
||||
|
||||
assert np.isclose(calculated_energy, expected_energy)
|
||||
|
||||
def test_particle_copy(self):
|
||||
"""测试质点复制"""
|
||||
p1 = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Original")
|
||||
p2 = p1.copy()
|
||||
|
||||
# 检查值相等
|
||||
assert p2.mass == p1.mass
|
||||
assert np.allclose(p2.position, p1.position)
|
||||
assert np.allclose(p2.velocity, p1.velocity)
|
||||
assert p2.name == p1.name
|
||||
|
||||
# 检查是深拷贝
|
||||
p2.position[0] = 100.0
|
||||
assert p1.position[0] == 1.0 # 原始对象不应改变
|
||||
|
||||
|
||||
class TestThreeBodySolver:
|
||||
"""测试三体问题求解器"""
|
||||
|
||||
def test_solver_creation(self):
|
||||
"""测试求解器创建"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
|
||||
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
assert len(solver.particles) == 3
|
||||
assert solver.dt == 0.001
|
||||
assert solver.time == 0.0
|
||||
|
||||
def test_solver_wrong_number_of_particles(self):
|
||||
"""测试错误数量的质点"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="三体问题需要恰好3个质点"):
|
||||
ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
def test_acceleration_calculation(self):
|
||||
"""测试加速度计算"""
|
||||
# 创建简单的测试配置:三个质点在等边三角形顶点
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 计算加速度
|
||||
accelerations = solver._calculate_accelerations(particles)
|
||||
|
||||
# 检查加速度形状
|
||||
assert len(accelerations) == 3
|
||||
for acc in accelerations:
|
||||
assert acc.shape == (3,)
|
||||
|
||||
# 由于对称性,第一个质点的加速度应该指向其他两个质点的方向
|
||||
# 这里只检查计算是否成功,不检查具体值
|
||||
|
||||
def test_center_of_mass(self):
|
||||
"""测试质心计算"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
com = solver.get_center_of_mass()
|
||||
|
||||
# 计算期望的质心
|
||||
total_mass = 1.0 + 2.0 + 3.0
|
||||
expected_com = np.array([
|
||||
(1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass,
|
||||
(1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass,
|
||||
(1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass
|
||||
])
|
||||
|
||||
assert np.allclose(com, expected_com)
|
||||
|
||||
def test_energy_calculation(self):
|
||||
"""测试能量计算"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
energy = solver._calculate_total_energy()
|
||||
|
||||
# 能量应该是有限的实数
|
||||
assert np.isfinite(energy)
|
||||
assert isinstance(energy, float)
|
||||
|
||||
def test_single_step(self):
|
||||
"""测试单步积分"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]),
|
||||
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 保存初始位置
|
||||
initial_positions = [p.position.copy() for p in particles]
|
||||
|
||||
# 执行单步积分
|
||||
new_particles = solver.step()
|
||||
|
||||
# 检查位置是否改变
|
||||
for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)):
|
||||
assert not np.allclose(old_pos, new_particle.position)
|
||||
|
||||
# 检查时间是否增加
|
||||
assert solver.time == 0.001
|
||||
|
||||
def test_conservation_laws(self):
|
||||
"""测试守恒定律(动量、角动量、能量)"""
|
||||
# 使用8字形轨道配置(应该是稳定的)
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.0001)
|
||||
|
||||
# 模拟很短时间
|
||||
solver.simulate(total_time=0.1, progress_interval=100)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
|
||||
# 误差应该很小
|
||||
assert momentum_error < 1e-10
|
||||
assert angular_momentum_error < 1e-10
|
||||
assert energy_error < 1e-5 # 能量守恒要求较低,因为数值误差
|
||||
|
||||
|
||||
class TestThreeBodyConfig:
|
||||
"""测试配置管理"""
|
||||
|
||||
def test_figure8_config(self):
|
||||
"""测试8字形轨道配置"""
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
assert len(particles) == 3
|
||||
assert all(p.mass == 1.0 for p in particles)
|
||||
|
||||
# 检查总动量接近零(8字形轨道应该满足)
|
||||
total_momentum = np.zeros(3)
|
||||
for p in particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
|
||||
assert np.linalg.norm(total_momentum) < 1e-10
|
||||
|
||||
def test_lagrange_config(self):
|
||||
"""测试拉格朗日点配置"""
|
||||
# 测试L4点
|
||||
particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
assert len(particles_l4) == 3
|
||||
|
||||
# 检查质量
|
||||
assert particles_l4[0].mass == 1.0 # 太阳
|
||||
assert particles_l4[1].mass == 3e-6 # 地球
|
||||
assert particles_l4[2].mass == 1e-8 # 测试质点
|
||||
|
||||
# 测试L5点
|
||||
particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
|
||||
assert len(particles_l5) == 3
|
||||
|
||||
# 检查位置(L4和L5应该在等边三角形顶点)
|
||||
r_l4 = particles_l4[2].position
|
||||
r_l5 = particles_l5[2].position
|
||||
|
||||
assert np.allclose(r_l4, [0.5, np.sqrt(3)/2, 0.0])
|
||||
assert np.allclose(r_l5, [0.5, -np.sqrt(3)/2, 0.0])
|
||||
|
||||
def test_random_config(self):
|
||||
"""测试随机配置"""
|
||||
particles = ThreeBodyConfig.create_random_config()
|
||||
|
||||
assert len(particles) == 3
|
||||
assert all(0.5 <= p.mass <= 2.0 for p in particles)
|
||||
|
||||
# 检查总动量接近零(随机配置会调整速度使总动量为零)
|
||||
total_momentum = np.zeros(3)
|
||||
for p in particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
|
||||
assert np.linalg.norm(total_momentum) < 1e-10
|
||||
|
||||
def test_custom_config(self):
|
||||
"""测试自定义配置"""
|
||||
config_dict = {
|
||||
'particle_1': {
|
||||
'mass': 1.0,
|
||||
'position': [1.0, 0.0, 0.0],
|
||||
'velocity': [0.0, 1.0, 0.0],
|
||||
'name': 'Star A',
|
||||
'color': 'red'
|
||||
},
|
||||
'particle_2': {
|
||||
'mass': 2.0,
|
||||
'position': [-1.0, 0.0, 0.0],
|
||||
'velocity': [0.0, -0.5, 0.0],
|
||||
'name': 'Star B',
|
||||
'color': 'green'
|
||||
},
|
||||
'particle_3': {
|
||||
'mass': 0.5,
|
||||
'position': [0.0, 1.0, 0.0],
|
||||
'velocity': [0.5, 0.0, 0.0],
|
||||
'name': 'Star C',
|
||||
'color': 'blue'
|
||||
}
|
||||
}
|
||||
|
||||
particles = ThreeBodyConfig.create_custom_config(config_dict)
|
||||
|
||||
assert len(particles) == 3
|
||||
assert particles[0].name == 'Star A'
|
||||
assert particles[1].mass == 2.0
|
||||
assert np.allclose(particles[2].position, [0.0, 1.0, 0.0])
|
||||
|
||||
|
||||
def test_integration_accuracy():
|
||||
"""测试数值积分精度"""
|
||||
# 使用简单的二体问题测试(第三个体质量很小)
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]),
|
||||
Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) # 很小的测试质点
|
||||
]
|
||||
|
||||
# 测试不同时间步长
|
||||
dts = [0.01, 0.005, 0.001, 0.0005]
|
||||
energy_errors = []
|
||||
|
||||
for dt in dts:
|
||||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||||
solver.simulate(total_time=1.0, progress_interval=10000)
|
||||
|
||||
_, _, energy_error = solver.get_conservation_errors()
|
||||
energy_errors.append(energy_error)
|
||||
|
||||
# 检查误差随步长减小而减小(四阶方法)
|
||||
for i in range(len(energy_errors)-1):
|
||||
# 误差应该大致按 dt^4 减小
|
||||
error_ratio = energy_errors[i] / energy_errors[i+1]
|
||||
dt_ratio = (dts[i] / dts[i+1])**4
|
||||
# 允许一定的误差范围
|
||||
assert 0.1 < error_ratio / dt_ratio < 10.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行所有测试
|
||||
print("运行三体问题求解器测试...")
|
||||
|
||||
# 创建测试实例
|
||||
test_particle = TestParticle()
|
||||
test_solver = TestThreeBodySolver()
|
||||
test_config = TestThreeBodyConfig()
|
||||
|
||||
# 运行粒子测试
|
||||
print("\n1. 测试质点类:")
|
||||
test_particle.test_particle_creation()
|
||||
print(" ✓ test_particle_creation 通过")
|
||||
|
||||
test_particle.test_particle_update()
|
||||
print(" ✓ test_particle_update 通过")
|
||||
|
||||
test_particle.test_particle_energy()
|
||||
print(" ✓ test_particle_energy 通过")
|
||||
|
||||
test_particle.test_particle_copy()
|
||||
print(" ✓ test_particle_copy 通过")
|
||||
|
||||
# 运行求解器测试
|
||||
print("\n2. 测试求解器类:")
|
||||
test_solver.test_solver_creation()
|
||||
print(" ✓ test_solver_creation 通过")
|
||||
|
||||
try:
|
||||
test_solver.test_solver_wrong_number_of_particles()
|
||||
print(" ✗ test_solver_wrong_number_of_particles 应该抛出异常")
|
||||
except ValueError:
|
||||
print(" ✓ test_solver_wrong_number_of_particles 通过")
|
||||
|
||||
test_solver.test_acceleration_calculation()
|
||||
print(" ✓ test_acceleration_calculation 通过")
|
||||
|
||||
test_solver.test_center_of_mass()
|
||||
print(" ✓ test_center_of_mass 通过")
|
||||
|
||||
test_solver.test_energy_calculation()
|
||||
print(" ✓ test_energy_calculation 通过")
|
||||
|
||||
test_solver.test_single_step()
|
||||
print(" ✓ test_single_step 通过")
|
||||
|
||||
test_solver.test_conservation_laws()
|
||||
print(" ✓ test_conservation_laws 通过")
|
||||
|
||||
# 运行配置测试
|
||||
print("\n3. 测试配置管理:")
|
||||
test_config.test_figure8_config()
|
||||
print(" ✓ test_figure8_config 通过")
|
||||
|
||||
test_config.test_lagrange_config()
|
||||
print(" ✓ test_lagrange_config 通过")
|
||||
|
||||
test_config.test_random_config()
|
||||
print(" ✓ test_random_config 通过")
|
||||
|
||||
test_config.test_custom_config()
|
||||
print(" ✓ test_custom_config 通过")
|
||||
|
||||
# 运行积分精度测试
|
||||
print("\n4. 测试数值积分精度:")
|
||||
test_integration_accuracy()
|
||||
print(" ✓ test_integration_accuracy 通过")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("所有测试通过!")
|
||||
print("="*60)
|
||||
321
three_body_problem/visualizer.py
Normal file
321
three_body_problem/visualizer.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
三体问题可视化模块
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
from typing import List, Optional, Tuple
|
||||
from .particle import Particle
|
||||
from .solver import ThreeBodySolver
|
||||
|
||||
|
||||
class ThreeBodyVisualizer:
|
||||
"""三体问题可视化类"""
|
||||
|
||||
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
|
||||
"""
|
||||
初始化可视化器
|
||||
|
||||
参数:
|
||||
figsize: 图形大小
|
||||
"""
|
||||
self.figsize = figsize
|
||||
self.fig = None
|
||||
self.ax = None
|
||||
|
||||
def create_3d_plot(self):
|
||||
"""创建3D图形"""
|
||||
self.fig = plt.figure(figsize=self.figsize)
|
||||
self.ax = self.fig.add_subplot(111, projection='3d')
|
||||
|
||||
def plot_trajectories(self, solver: ThreeBodySolver,
|
||||
show_current_positions: bool = True,
|
||||
show_com: bool = True,
|
||||
title: str = "三体问题轨迹"):
|
||||
"""
|
||||
绘制三体运动轨迹
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
show_current_positions: 是否显示当前位置
|
||||
show_com: 是否显示质心
|
||||
title: 图形标题
|
||||
"""
|
||||
if self.ax is None:
|
||||
self.create_3d_plot()
|
||||
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 绘制每个质点的轨迹
|
||||
colors = ['red', 'green', 'blue']
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 绘制轨迹线
|
||||
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制当前位置
|
||||
if show_current_positions and len(traj) > 0:
|
||||
current_pos = traj[-1]
|
||||
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心
|
||||
if show_com:
|
||||
com = solver.get_center_of_mass()
|
||||
self.ax.scatter(com[0], com[1], com[2],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
# 设置图形属性
|
||||
self.ax.set_xlabel('X (AU)', fontsize=12)
|
||||
self.ax.set_ylabel('Y (AU)', fontsize=12)
|
||||
self.ax.set_zlabel('Z (AU)', fontsize=12)
|
||||
self.ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
self.ax.legend(fontsize=10)
|
||||
self.ax.grid(True, alpha=0.3)
|
||||
|
||||
# 设置等比例坐标轴
|
||||
self.ax.set_aspect('auto')
|
||||
|
||||
def plot_2d_projection(self, solver: ThreeBodySolver,
|
||||
projection: str = 'xy',
|
||||
show_current_positions: bool = True,
|
||||
show_com: bool = True,
|
||||
title: str = "三体问题轨迹 (2D投影)"):
|
||||
"""
|
||||
绘制2D投影
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
projection: 投影平面 ('xy', 'xz', 'yz')
|
||||
show_current_positions: 是否显示当前位置
|
||||
show_com: 是否显示质心
|
||||
title: 图形标题
|
||||
"""
|
||||
if projection not in ['xy', 'xz', 'yz']:
|
||||
raise ValueError("projection 必须是 'xy', 'xz' 或 'yz'")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 确定坐标轴索引
|
||||
if projection == 'xy':
|
||||
x_idx, y_idx = 0, 1
|
||||
x_label, y_label = 'X (AU)', 'Y (AU)'
|
||||
elif projection == 'xz':
|
||||
x_idx, y_idx = 0, 2
|
||||
x_label, y_label = 'X (AU)', 'Z (AU)'
|
||||
else: # 'yz'
|
||||
x_idx, y_idx = 1, 2
|
||||
x_label, y_label = 'Y (AU)', 'Z (AU)'
|
||||
|
||||
# 绘制每个质点的轨迹
|
||||
colors = ['red', 'green', 'blue']
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 绘制轨迹线
|
||||
ax.plot(traj[:, x_idx], traj[:, y_idx],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制当前位置
|
||||
if show_current_positions and len(traj) > 0:
|
||||
current_pos = traj[-1]
|
||||
ax.scatter(current_pos[x_idx], current_pos[y_idx],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心
|
||||
if show_com:
|
||||
com = solver.get_center_of_mass()
|
||||
ax.scatter(com[x_idx], com[y_idx],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
# 设置图形属性
|
||||
ax.set_xlabel(x_label, fontsize=12)
|
||||
ax.set_ylabel(y_label, fontsize=12)
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_aspect('equal', adjustable='box')
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_energy_conservation(self, solver: ThreeBodySolver,
|
||||
time_points: Optional[np.ndarray] = None):
|
||||
"""
|
||||
绘制能量守恒图
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
time_points: 时间点数组(如果为None,则使用模拟时间)
|
||||
"""
|
||||
# 从历史记录中提取能量信息
|
||||
# 注意:这里需要修改求解器以记录能量历史
|
||||
# 暂时返回一个简单的图形
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# 这里可以添加能量历史记录功能
|
||||
ax.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax.set_ylabel('总能量', fontsize=12)
|
||||
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_phase_space(self, solver: ThreeBodySolver,
|
||||
particle_index: int = 0,
|
||||
dimension: str = 'x'):
|
||||
"""
|
||||
绘制相空间图(位置 vs 速度)
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
particle_index: 质点索引 (0, 1, 2)
|
||||
dimension: 维度 ('x', 'y', 'z')
|
||||
"""
|
||||
if particle_index not in [0, 1, 2]:
|
||||
raise ValueError("particle_index 必须是 0, 1, 或 2")
|
||||
|
||||
if dimension not in ['x', 'y', 'z']:
|
||||
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
|
||||
|
||||
particle = solver.particles[particle_index]
|
||||
trajectories = particle.get_trajectory()
|
||||
|
||||
if len(trajectories) < 2:
|
||||
print("警告:轨迹数据不足,无法绘制相空间图")
|
||||
return None, None
|
||||
|
||||
# 获取位置和速度历史
|
||||
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
|
||||
|
||||
# 注意:速度历史需要从求解器中获取
|
||||
# 暂时使用位置差分近似速度
|
||||
if len(positions) > 1:
|
||||
dt = solver.dt
|
||||
velocities = np.gradient(positions, dt)
|
||||
else:
|
||||
velocities = np.zeros_like(positions)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# 绘制相空间轨迹
|
||||
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
|
||||
cmap='viridis', alpha=0.7, s=20)
|
||||
|
||||
# 添加颜色条表示时间
|
||||
plt.colorbar(scatter, ax=ax, label='时间步')
|
||||
|
||||
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
|
||||
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
|
||||
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
|
||||
fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
return fig, ax
|
||||
|
||||
def animate_trajectories(self, solver: ThreeBodySolver,
|
||||
save_path: Optional[str] = None,
|
||||
fps: int = 30,
|
||||
dpi: int = 100):
|
||||
"""
|
||||
创建轨迹动画(需要额外安装 matplotlib.animation)
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
save_path: 保存路径(如果为None则显示动画)
|
||||
fps: 帧率
|
||||
dpi: 分辨率
|
||||
"""
|
||||
try:
|
||||
from matplotlib.animation import FuncAnimation
|
||||
except ImportError:
|
||||
print("错误:需要安装 matplotlib.animation 模块")
|
||||
return
|
||||
|
||||
trajectories = solver.get_trajectories()
|
||||
if any(len(traj) < 2 for traj in trajectories):
|
||||
print("错误:轨迹数据不足,无法创建动画")
|
||||
return
|
||||
|
||||
fig = plt.figure(figsize=(10, 8))
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
# 初始化图形元素
|
||||
lines = []
|
||||
points = []
|
||||
colors = ['red', 'green', 'blue']
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 创建轨迹线
|
||||
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
lines.append(line)
|
||||
|
||||
# 创建当前位置点
|
||||
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
|
||||
markeredgecolor='black', markeredgewidth=1)
|
||||
points.append(point)
|
||||
|
||||
# 设置坐标轴范围
|
||||
all_points = np.vstack(trajectories)
|
||||
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
|
||||
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
|
||||
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
|
||||
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
|
||||
|
||||
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
|
||||
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
|
||||
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
|
||||
|
||||
ax.set_xlabel('X (AU)', fontsize=12)
|
||||
ax.set_ylabel('Y (AU)', fontsize=12)
|
||||
ax.set_zlabel('Z (AU)', fontsize=12)
|
||||
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 动画更新函数
|
||||
def update(frame):
|
||||
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
|
||||
# 更新轨迹线(显示到当前帧)
|
||||
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
|
||||
line.set_3d_properties(traj[:frame+1, 2])
|
||||
|
||||
# 更新当前位置点
|
||||
if frame < len(traj):
|
||||
point.set_data([traj[frame, 0]], [traj[frame, 1]])
|
||||
point.set_3d_properties([traj[frame, 2]])
|
||||
|
||||
return lines + points
|
||||
|
||||
# 创建动画
|
||||
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
|
||||
interval=1000/fps, blit=True)
|
||||
|
||||
if save_path:
|
||||
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
|
||||
print(f"动画已保存到: {save_path}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
return anim
|
||||
|
||||
def show(self):
|
||||
"""显示所有图形"""
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
def save_figure(self, filename: str, dpi: int = 300):
|
||||
"""保存当前图形"""
|
||||
if self.fig is not None:
|
||||
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
|
||||
print(f"图形已保存到: {filename}")
|
||||
else:
|
||||
print("警告:没有活动的图形可以保存")
|
||||
Reference in New Issue
Block a user