pc-1
This commit is contained in:
dison0331-ThinkPad
2026-03-11 21:32:58 +08:00
commit 8c8ad9fe07
29 changed files with 4005 additions and 0 deletions

View File

@@ -0,0 +1,385 @@
"""
三体问题求解器测试
"""
import numpy as np
import pytest
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig
class TestParticle:
"""测试质点类"""
def test_particle_creation(self):
"""测试质点创建"""
p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test")
assert p.mass == 1.0
assert np.allclose(p.position, [1.0, 2.0, 3.0])
assert np.allclose(p.velocity, [0.1, 0.2, 0.3])
assert p.name == "Test"
assert len(p.position_history) == 0
assert len(p.velocity_history) == 0
def test_particle_update(self):
"""测试质点状态更新"""
p = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
new_position = np.array([1.0, 2.0, 3.0])
new_velocity = np.array([0.1, 0.2, 0.3])
p.update(new_position, new_velocity)
assert np.allclose(p.position, new_position)
assert np.allclose(p.velocity, new_velocity)
assert len(p.position_history) == 1
assert len(p.velocity_history) == 1
assert np.allclose(p.position_history[0], [0.0, 0.0, 0.0])
assert np.allclose(p.velocity_history[0], [0.0, 0.0, 0.0])
def test_particle_energy(self):
"""测试质点动能计算"""
p = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0])
# 动能 = 0.5 * m * v^2
expected_energy = 0.5 * 2.0 * (1.0**2 + 2.0**2 + 3.0**2)
calculated_energy = p.get_energy()
assert np.isclose(calculated_energy, expected_energy)
def test_particle_copy(self):
"""测试质点复制"""
p1 = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Original")
p2 = p1.copy()
# 检查值相等
assert p2.mass == p1.mass
assert np.allclose(p2.position, p1.position)
assert np.allclose(p2.velocity, p1.velocity)
assert p2.name == p1.name
# 检查是深拷贝
p2.position[0] = 100.0
assert p1.position[0] == 1.0 # 原始对象不应改变
class TestThreeBodySolver:
"""测试三体问题求解器"""
def test_solver_creation(self):
"""测试求解器创建"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
assert len(solver.particles) == 3
assert solver.dt == 0.001
assert solver.time == 0.0
def test_solver_wrong_number_of_particles(self):
"""测试错误数量的质点"""
particles = [
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
]
with pytest.raises(ValueError, match="三体问题需要恰好3个质点"):
ThreeBodySolver(particles, dt=0.001)
def test_acceleration_calculation(self):
"""测试加速度计算"""
# 创建简单的测试配置:三个质点在等边三角形顶点
particles = [
Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
# 计算加速度
accelerations = solver._calculate_accelerations(particles)
# 检查加速度形状
assert len(accelerations) == 3
for acc in accelerations:
assert acc.shape == (3,)
# 由于对称性,第一个质点的加速度应该指向其他两个质点的方向
# 这里只检查计算是否成功,不检查具体值
def test_center_of_mass(self):
"""测试质心计算"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
com = solver.get_center_of_mass()
# 计算期望的质心
total_mass = 1.0 + 2.0 + 3.0
expected_com = np.array([
(1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass,
(1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass,
(1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass
])
assert np.allclose(com, expected_com)
def test_energy_calculation(self):
"""测试能量计算"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
energy = solver._calculate_total_energy()
# 能量应该是有限的实数
assert np.isfinite(energy)
assert isinstance(energy, float)
def test_single_step(self):
"""测试单步积分"""
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]),
Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0])
]
solver = ThreeBodySolver(particles, dt=0.001)
# 保存初始位置
initial_positions = [p.position.copy() for p in particles]
# 执行单步积分
new_particles = solver.step()
# 检查位置是否改变
for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)):
assert not np.allclose(old_pos, new_particle.position)
# 检查时间是否增加
assert solver.time == 0.001
def test_conservation_laws(self):
"""测试守恒定律(动量、角动量、能量)"""
# 使用8字形轨道配置应该是稳定的
particles = ThreeBodyConfig.create_figure8_config()
solver = ThreeBodySolver(particles, dt=0.0001)
# 模拟很短时间
solver.simulate(total_time=0.1, progress_interval=100)
# 计算守恒误差
momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()
# 误差应该很小
assert momentum_error < 1e-10
assert angular_momentum_error < 1e-10
assert energy_error < 1e-5 # 能量守恒要求较低,因为数值误差
class TestThreeBodyConfig:
"""测试配置管理"""
def test_figure8_config(self):
"""测试8字形轨道配置"""
particles = ThreeBodyConfig.create_figure8_config()
assert len(particles) == 3
assert all(p.mass == 1.0 for p in particles)
# 检查总动量接近零8字形轨道应该满足
total_momentum = np.zeros(3)
for p in particles:
total_momentum += p.mass * p.velocity
assert np.linalg.norm(total_momentum) < 1e-10
def test_lagrange_config(self):
"""测试拉格朗日点配置"""
# 测试L4点
particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
assert len(particles_l4) == 3
# 检查质量
assert particles_l4[0].mass == 1.0 # 太阳
assert particles_l4[1].mass == 3e-6 # 地球
assert particles_l4[2].mass == 1e-8 # 测试质点
# 测试L5点
particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
assert len(particles_l5) == 3
# 检查位置L4和L5应该在等边三角形顶点
r_l4 = particles_l4[2].position
r_l5 = particles_l5[2].position
assert np.allclose(r_l4, [0.5, np.sqrt(3)/2, 0.0])
assert np.allclose(r_l5, [0.5, -np.sqrt(3)/2, 0.0])
def test_random_config(self):
"""测试随机配置"""
particles = ThreeBodyConfig.create_random_config()
assert len(particles) == 3
assert all(0.5 <= p.mass <= 2.0 for p in particles)
# 检查总动量接近零(随机配置会调整速度使总动量为零)
total_momentum = np.zeros(3)
for p in particles:
total_momentum += p.mass * p.velocity
assert np.linalg.norm(total_momentum) < 1e-10
def test_custom_config(self):
"""测试自定义配置"""
config_dict = {
'particle_1': {
'mass': 1.0,
'position': [1.0, 0.0, 0.0],
'velocity': [0.0, 1.0, 0.0],
'name': 'Star A',
'color': 'red'
},
'particle_2': {
'mass': 2.0,
'position': [-1.0, 0.0, 0.0],
'velocity': [0.0, -0.5, 0.0],
'name': 'Star B',
'color': 'green'
},
'particle_3': {
'mass': 0.5,
'position': [0.0, 1.0, 0.0],
'velocity': [0.5, 0.0, 0.0],
'name': 'Star C',
'color': 'blue'
}
}
particles = ThreeBodyConfig.create_custom_config(config_dict)
assert len(particles) == 3
assert particles[0].name == 'Star A'
assert particles[1].mass == 2.0
assert np.allclose(particles[2].position, [0.0, 1.0, 0.0])
def test_integration_accuracy():
"""测试数值积分精度"""
# 使用简单的二体问题测试(第三个体质量很小)
particles = [
Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]),
Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]),
Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) # 很小的测试质点
]
# 测试不同时间步长
dts = [0.01, 0.005, 0.001, 0.0005]
energy_errors = []
for dt in dts:
solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
solver.simulate(total_time=1.0, progress_interval=10000)
_, _, energy_error = solver.get_conservation_errors()
energy_errors.append(energy_error)
# 检查误差随步长减小而减小(四阶方法)
for i in range(len(energy_errors)-1):
# 误差应该大致按 dt^4 减小
error_ratio = energy_errors[i] / energy_errors[i+1]
dt_ratio = (dts[i] / dts[i+1])**4
# 允许一定的误差范围
assert 0.1 < error_ratio / dt_ratio < 10.0
if __name__ == "__main__":
# 运行所有测试
print("运行三体问题求解器测试...")
# 创建测试实例
test_particle = TestParticle()
test_solver = TestThreeBodySolver()
test_config = TestThreeBodyConfig()
# 运行粒子测试
print("\n1. 测试质点类:")
test_particle.test_particle_creation()
print(" ✓ test_particle_creation 通过")
test_particle.test_particle_update()
print(" ✓ test_particle_update 通过")
test_particle.test_particle_energy()
print(" ✓ test_particle_energy 通过")
test_particle.test_particle_copy()
print(" ✓ test_particle_copy 通过")
# 运行求解器测试
print("\n2. 测试求解器类:")
test_solver.test_solver_creation()
print(" ✓ test_solver_creation 通过")
try:
test_solver.test_solver_wrong_number_of_particles()
print(" ✗ test_solver_wrong_number_of_particles 应该抛出异常")
except ValueError:
print(" ✓ test_solver_wrong_number_of_particles 通过")
test_solver.test_acceleration_calculation()
print(" ✓ test_acceleration_calculation 通过")
test_solver.test_center_of_mass()
print(" ✓ test_center_of_mass 通过")
test_solver.test_energy_calculation()
print(" ✓ test_energy_calculation 通过")
test_solver.test_single_step()
print(" ✓ test_single_step 通过")
test_solver.test_conservation_laws()
print(" ✓ test_conservation_laws 通过")
# 运行配置测试
print("\n3. 测试配置管理:")
test_config.test_figure8_config()
print(" ✓ test_figure8_config 通过")
test_config.test_lagrange_config()
print(" ✓ test_lagrange_config 通过")
test_config.test_random_config()
print(" ✓ test_random_config 通过")
test_config.test_custom_config()
print(" ✓ test_custom_config 通过")
# 运行积分精度测试
print("\n4. 测试数值积分精度:")
test_integration_accuracy()
print(" ✓ test_integration_accuracy 通过")
print("\n" + "="*60)
print("所有测试通过!")
print("="*60)