first
pc-1
This commit is contained in:
3
three_body_problem/tests/__init__.py
Normal file
3
three_body_problem/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
三体问题测试模块
|
||||
"""
|
||||
385
three_body_problem/tests/test_solver.py
Normal file
385
three_body_problem/tests/test_solver.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
三体问题求解器测试
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig
|
||||
|
||||
|
||||
class TestParticle:
|
||||
"""测试质点类"""
|
||||
|
||||
def test_particle_creation(self):
|
||||
"""测试质点创建"""
|
||||
p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test")
|
||||
|
||||
assert p.mass == 1.0
|
||||
assert np.allclose(p.position, [1.0, 2.0, 3.0])
|
||||
assert np.allclose(p.velocity, [0.1, 0.2, 0.3])
|
||||
assert p.name == "Test"
|
||||
assert len(p.position_history) == 0
|
||||
assert len(p.velocity_history) == 0
|
||||
|
||||
def test_particle_update(self):
|
||||
"""测试质点状态更新"""
|
||||
p = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
|
||||
new_position = np.array([1.0, 2.0, 3.0])
|
||||
new_velocity = np.array([0.1, 0.2, 0.3])
|
||||
|
||||
p.update(new_position, new_velocity)
|
||||
|
||||
assert np.allclose(p.position, new_position)
|
||||
assert np.allclose(p.velocity, new_velocity)
|
||||
assert len(p.position_history) == 1
|
||||
assert len(p.velocity_history) == 1
|
||||
assert np.allclose(p.position_history[0], [0.0, 0.0, 0.0])
|
||||
assert np.allclose(p.velocity_history[0], [0.0, 0.0, 0.0])
|
||||
|
||||
def test_particle_energy(self):
|
||||
"""测试质点动能计算"""
|
||||
p = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0])
|
||||
|
||||
# 动能 = 0.5 * m * v^2
|
||||
expected_energy = 0.5 * 2.0 * (1.0**2 + 2.0**2 + 3.0**2)
|
||||
calculated_energy = p.get_energy()
|
||||
|
||||
assert np.isclose(calculated_energy, expected_energy)
|
||||
|
||||
def test_particle_copy(self):
|
||||
"""测试质点复制"""
|
||||
p1 = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Original")
|
||||
p2 = p1.copy()
|
||||
|
||||
# 检查值相等
|
||||
assert p2.mass == p1.mass
|
||||
assert np.allclose(p2.position, p1.position)
|
||||
assert np.allclose(p2.velocity, p1.velocity)
|
||||
assert p2.name == p1.name
|
||||
|
||||
# 检查是深拷贝
|
||||
p2.position[0] = 100.0
|
||||
assert p1.position[0] == 1.0 # 原始对象不应改变
|
||||
|
||||
|
||||
class TestThreeBodySolver:
|
||||
"""测试三体问题求解器"""
|
||||
|
||||
def test_solver_creation(self):
|
||||
"""测试求解器创建"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
|
||||
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
assert len(solver.particles) == 3
|
||||
assert solver.dt == 0.001
|
||||
assert solver.time == 0.0
|
||||
|
||||
def test_solver_wrong_number_of_particles(self):
|
||||
"""测试错误数量的质点"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="三体问题需要恰好3个质点"):
|
||||
ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
def test_acceleration_calculation(self):
|
||||
"""测试加速度计算"""
|
||||
# 创建简单的测试配置:三个质点在等边三角形顶点
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 计算加速度
|
||||
accelerations = solver._calculate_accelerations(particles)
|
||||
|
||||
# 检查加速度形状
|
||||
assert len(accelerations) == 3
|
||||
for acc in accelerations:
|
||||
assert acc.shape == (3,)
|
||||
|
||||
# 由于对称性,第一个质点的加速度应该指向其他两个质点的方向
|
||||
# 这里只检查计算是否成功,不检查具体值
|
||||
|
||||
def test_center_of_mass(self):
|
||||
"""测试质心计算"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
|
||||
Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
com = solver.get_center_of_mass()
|
||||
|
||||
# 计算期望的质心
|
||||
total_mass = 1.0 + 2.0 + 3.0
|
||||
expected_com = np.array([
|
||||
(1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass,
|
||||
(1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass,
|
||||
(1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass
|
||||
])
|
||||
|
||||
assert np.allclose(com, expected_com)
|
||||
|
||||
def test_energy_calculation(self):
|
||||
"""测试能量计算"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
|
||||
Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
energy = solver._calculate_total_energy()
|
||||
|
||||
# 能量应该是有限的实数
|
||||
assert np.isfinite(energy)
|
||||
assert isinstance(energy, float)
|
||||
|
||||
def test_single_step(self):
|
||||
"""测试单步积分"""
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]),
|
||||
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0])
|
||||
]
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.001)
|
||||
|
||||
# 保存初始位置
|
||||
initial_positions = [p.position.copy() for p in particles]
|
||||
|
||||
# 执行单步积分
|
||||
new_particles = solver.step()
|
||||
|
||||
# 检查位置是否改变
|
||||
for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)):
|
||||
assert not np.allclose(old_pos, new_particle.position)
|
||||
|
||||
# 检查时间是否增加
|
||||
assert solver.time == 0.001
|
||||
|
||||
def test_conservation_laws(self):
|
||||
"""测试守恒定律(动量、角动量、能量)"""
|
||||
# 使用8字形轨道配置(应该是稳定的)
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
solver = ThreeBodySolver(particles, dt=0.0001)
|
||||
|
||||
# 模拟很短时间
|
||||
solver.simulate(total_time=0.1, progress_interval=100)
|
||||
|
||||
# 计算守恒误差
|
||||
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
|
||||
|
||||
# 误差应该很小
|
||||
assert momentum_error < 1e-10
|
||||
assert angular_momentum_error < 1e-10
|
||||
assert energy_error < 1e-5 # 能量守恒要求较低,因为数值误差
|
||||
|
||||
|
||||
class TestThreeBodyConfig:
|
||||
"""测试配置管理"""
|
||||
|
||||
def test_figure8_config(self):
|
||||
"""测试8字形轨道配置"""
|
||||
particles = ThreeBodyConfig.create_figure8_config()
|
||||
|
||||
assert len(particles) == 3
|
||||
assert all(p.mass == 1.0 for p in particles)
|
||||
|
||||
# 检查总动量接近零(8字形轨道应该满足)
|
||||
total_momentum = np.zeros(3)
|
||||
for p in particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
|
||||
assert np.linalg.norm(total_momentum) < 1e-10
|
||||
|
||||
def test_lagrange_config(self):
|
||||
"""测试拉格朗日点配置"""
|
||||
# 测试L4点
|
||||
particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
|
||||
assert len(particles_l4) == 3
|
||||
|
||||
# 检查质量
|
||||
assert particles_l4[0].mass == 1.0 # 太阳
|
||||
assert particles_l4[1].mass == 3e-6 # 地球
|
||||
assert particles_l4[2].mass == 1e-8 # 测试质点
|
||||
|
||||
# 测试L5点
|
||||
particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
|
||||
assert len(particles_l5) == 3
|
||||
|
||||
# 检查位置(L4和L5应该在等边三角形顶点)
|
||||
r_l4 = particles_l4[2].position
|
||||
r_l5 = particles_l5[2].position
|
||||
|
||||
assert np.allclose(r_l4, [0.5, np.sqrt(3)/2, 0.0])
|
||||
assert np.allclose(r_l5, [0.5, -np.sqrt(3)/2, 0.0])
|
||||
|
||||
def test_random_config(self):
|
||||
"""测试随机配置"""
|
||||
particles = ThreeBodyConfig.create_random_config()
|
||||
|
||||
assert len(particles) == 3
|
||||
assert all(0.5 <= p.mass <= 2.0 for p in particles)
|
||||
|
||||
# 检查总动量接近零(随机配置会调整速度使总动量为零)
|
||||
total_momentum = np.zeros(3)
|
||||
for p in particles:
|
||||
total_momentum += p.mass * p.velocity
|
||||
|
||||
assert np.linalg.norm(total_momentum) < 1e-10
|
||||
|
||||
def test_custom_config(self):
|
||||
"""测试自定义配置"""
|
||||
config_dict = {
|
||||
'particle_1': {
|
||||
'mass': 1.0,
|
||||
'position': [1.0, 0.0, 0.0],
|
||||
'velocity': [0.0, 1.0, 0.0],
|
||||
'name': 'Star A',
|
||||
'color': 'red'
|
||||
},
|
||||
'particle_2': {
|
||||
'mass': 2.0,
|
||||
'position': [-1.0, 0.0, 0.0],
|
||||
'velocity': [0.0, -0.5, 0.0],
|
||||
'name': 'Star B',
|
||||
'color': 'green'
|
||||
},
|
||||
'particle_3': {
|
||||
'mass': 0.5,
|
||||
'position': [0.0, 1.0, 0.0],
|
||||
'velocity': [0.5, 0.0, 0.0],
|
||||
'name': 'Star C',
|
||||
'color': 'blue'
|
||||
}
|
||||
}
|
||||
|
||||
particles = ThreeBodyConfig.create_custom_config(config_dict)
|
||||
|
||||
assert len(particles) == 3
|
||||
assert particles[0].name == 'Star A'
|
||||
assert particles[1].mass == 2.0
|
||||
assert np.allclose(particles[2].position, [0.0, 1.0, 0.0])
|
||||
|
||||
|
||||
def test_integration_accuracy():
|
||||
"""测试数值积分精度"""
|
||||
# 使用简单的二体问题测试(第三个体质量很小)
|
||||
particles = [
|
||||
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]),
|
||||
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]),
|
||||
Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) # 很小的测试质点
|
||||
]
|
||||
|
||||
# 测试不同时间步长
|
||||
dts = [0.01, 0.005, 0.001, 0.0005]
|
||||
energy_errors = []
|
||||
|
||||
for dt in dts:
|
||||
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
|
||||
solver.simulate(total_time=1.0, progress_interval=10000)
|
||||
|
||||
_, _, energy_error = solver.get_conservation_errors()
|
||||
energy_errors.append(energy_error)
|
||||
|
||||
# 检查误差随步长减小而减小(四阶方法)
|
||||
for i in range(len(energy_errors)-1):
|
||||
# 误差应该大致按 dt^4 减小
|
||||
error_ratio = energy_errors[i] / energy_errors[i+1]
|
||||
dt_ratio = (dts[i] / dts[i+1])**4
|
||||
# 允许一定的误差范围
|
||||
assert 0.1 < error_ratio / dt_ratio < 10.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行所有测试
|
||||
print("运行三体问题求解器测试...")
|
||||
|
||||
# 创建测试实例
|
||||
test_particle = TestParticle()
|
||||
test_solver = TestThreeBodySolver()
|
||||
test_config = TestThreeBodyConfig()
|
||||
|
||||
# 运行粒子测试
|
||||
print("\n1. 测试质点类:")
|
||||
test_particle.test_particle_creation()
|
||||
print(" ✓ test_particle_creation 通过")
|
||||
|
||||
test_particle.test_particle_update()
|
||||
print(" ✓ test_particle_update 通过")
|
||||
|
||||
test_particle.test_particle_energy()
|
||||
print(" ✓ test_particle_energy 通过")
|
||||
|
||||
test_particle.test_particle_copy()
|
||||
print(" ✓ test_particle_copy 通过")
|
||||
|
||||
# 运行求解器测试
|
||||
print("\n2. 测试求解器类:")
|
||||
test_solver.test_solver_creation()
|
||||
print(" ✓ test_solver_creation 通过")
|
||||
|
||||
try:
|
||||
test_solver.test_solver_wrong_number_of_particles()
|
||||
print(" ✗ test_solver_wrong_number_of_particles 应该抛出异常")
|
||||
except ValueError:
|
||||
print(" ✓ test_solver_wrong_number_of_particles 通过")
|
||||
|
||||
test_solver.test_acceleration_calculation()
|
||||
print(" ✓ test_acceleration_calculation 通过")
|
||||
|
||||
test_solver.test_center_of_mass()
|
||||
print(" ✓ test_center_of_mass 通过")
|
||||
|
||||
test_solver.test_energy_calculation()
|
||||
print(" ✓ test_energy_calculation 通过")
|
||||
|
||||
test_solver.test_single_step()
|
||||
print(" ✓ test_single_step 通过")
|
||||
|
||||
test_solver.test_conservation_laws()
|
||||
print(" ✓ test_conservation_laws 通过")
|
||||
|
||||
# 运行配置测试
|
||||
print("\n3. 测试配置管理:")
|
||||
test_config.test_figure8_config()
|
||||
print(" ✓ test_figure8_config 通过")
|
||||
|
||||
test_config.test_lagrange_config()
|
||||
print(" ✓ test_lagrange_config 通过")
|
||||
|
||||
test_config.test_random_config()
|
||||
print(" ✓ test_random_config 通过")
|
||||
|
||||
test_config.test_custom_config()
|
||||
print(" ✓ test_custom_config 通过")
|
||||
|
||||
# 运行积分精度测试
|
||||
print("\n4. 测试数值积分精度:")
|
||||
test_integration_accuracy()
|
||||
print(" ✓ test_integration_accuracy 通过")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("所有测试通过!")
|
||||
print("="*60)
|
||||
Reference in New Issue
Block a user