Files
dison0331-ThinkPad 8c8ad9fe07 first
pc-1
2026-03-11 21:32:58 +08:00

385 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
三体问题求解器测试
"""
import numpy as np
import pytest
import sys
import os
# 添加父目录到路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig
class TestParticle:
    """Unit tests for the ``Particle`` point-mass class."""

    def test_particle_creation(self):
        """A freshly built particle stores its inputs and starts with empty histories."""
        particle = Particle(
            mass=1.0,
            position=[1.0, 2.0, 3.0],
            velocity=[0.1, 0.2, 0.3],
            name="Test",
        )

        assert particle.mass == 1.0
        assert np.allclose(particle.position, [1.0, 2.0, 3.0])
        assert np.allclose(particle.velocity, [0.1, 0.2, 0.3])
        assert particle.name == "Test"
        # update() has never been called, so nothing is recorded yet.
        assert len(particle.position_history) == 0
        assert len(particle.velocity_history) == 0

    def test_particle_update(self):
        """update() installs the new state and records the previous one in the histories."""
        particle = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0])
        target_pos = np.array([1.0, 2.0, 3.0])
        target_vel = np.array([0.1, 0.2, 0.3])

        particle.update(target_pos, target_vel)

        # The live state now reflects the new values ...
        assert np.allclose(particle.position, target_pos)
        assert np.allclose(particle.velocity, target_vel)
        # ... and exactly one snapshot of the old (all-zero) state was kept.
        assert len(particle.position_history) == 1
        assert len(particle.velocity_history) == 1
        assert np.allclose(particle.position_history[0], [0.0, 0.0, 0.0])
        assert np.allclose(particle.velocity_history[0], [0.0, 0.0, 0.0])

    def test_particle_energy(self):
        """get_energy() returns the kinetic energy 1/2 * m * |v|^2."""
        particle = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0])

        speed_squared = 1.0**2 + 2.0**2 + 3.0**2
        expected = 0.5 * 2.0 * speed_squared

        assert np.isclose(particle.get_energy(), expected)

    def test_particle_copy(self):
        """copy() yields an equal particle whose arrays are independent of the source."""
        source = Particle(
            mass=1.0,
            position=[1.0, 2.0, 3.0],
            velocity=[0.1, 0.2, 0.3],
            name="Original",
        )
        duplicate = source.copy()

        # Same values ...
        assert duplicate.mass == source.mass
        assert np.allclose(duplicate.position, source.position)
        assert np.allclose(duplicate.velocity, source.velocity)
        assert duplicate.name == source.name

        # ... but separate storage: writing into the duplicate must not touch the source.
        duplicate.position[0] = 100.0
        assert source.position[0] == 1.0
class TestThreeBodySolver:
    """Tests for ``ThreeBodySolver``: construction, physics helpers and integration."""

    def test_solver_creation(self):
        """The solver keeps the three particles and the time step, starting at t = 0."""
        particles = [
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
            Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
            Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0]),
        ]
        solver = ThreeBodySolver(particles, dt=0.001)

        assert len(solver.particles) == 3
        assert solver.dt == 0.001
        assert solver.time == 0.0

    def test_solver_wrong_number_of_particles(self):
        """Passing anything other than exactly three particles raises ValueError."""
        particles = [
            Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
        ]
        # The match pattern must agree with the solver's error message.
        with pytest.raises(ValueError, match="三体问题需要恰好3个质点"):
            ThreeBodySolver(particles, dt=0.001)

    def test_acceleration_calculation(self):
        """Accelerations for equal masses on an equilateral triangle respect symmetry.

        For three equal masses at the vertices of an equilateral triangle:

        * pairwise forces cancel, so the net force sum(m_i * a_i) is zero;
        * every body feels the same magnitude of acceleration;
        * each acceleration points from the body towards the triangle's centroid.

        These properties are independent of the gravitational constant and of
        any symmetric softening the solver may use.
        """
        particles = [
            Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
            Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0]),
        ]
        solver = ThreeBodySolver(particles, dt=0.001)

        accelerations = solver._calculate_accelerations(particles)

        # Shape: one 3-vector per particle.
        assert len(accelerations) == 3
        for acc in accelerations:
            assert acc.shape == (3,)

        acc = np.array(accelerations)                       # (3, 3)
        masses = np.array([p.mass for p in particles])      # (3,)
        positions = np.array([p.position for p in particles])
        magnitudes = np.linalg.norm(acc, axis=1)

        # Gravity must actually pull on every body.
        assert np.all(magnitudes > 0)

        # Newton's third law: the total force on the isolated system vanishes
        # (tolerance scaled to the acceleration size, so it is G-independent).
        net_force = (masses[:, None] * acc).sum(axis=0)
        assert np.linalg.norm(net_force) <= 1e-10 * magnitudes.max()

        # Symmetry: equal magnitudes, each directed at the centroid.
        assert np.allclose(magnitudes, magnitudes[0])
        centroid = positions.mean(axis=0)
        for a, r in zip(acc, positions):
            to_centroid = centroid - r
            cosine = np.dot(a, to_centroid) / (np.linalg.norm(a) * np.linalg.norm(to_centroid))
            assert np.isclose(cosine, 1.0)

    def test_center_of_mass(self):
        """get_center_of_mass() returns the mass-weighted mean position."""
        particles = [
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
            Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),
            Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0]),
        ]
        solver = ThreeBodySolver(particles, dt=0.001)

        com = solver.get_center_of_mass()

        total_mass = 1.0 + 2.0 + 3.0
        expected_com = np.array([
            (1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass,
            (1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass,
            (1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass,
        ])
        assert np.allclose(com, expected_com)

    def test_energy_calculation(self):
        """_calculate_total_energy() returns a finite float."""
        particles = [
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]),
            Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]),
            Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0]),
        ]
        solver = ThreeBodySolver(particles, dt=0.001)

        energy = solver._calculate_total_energy()

        assert np.isfinite(energy)
        # np.float64 subclasses float, so a NumPy double also satisfies this.
        assert isinstance(energy, float)

    def test_single_step(self):
        """One step() moves every particle and advances the clock by dt."""
        particles = [
            Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]),
            Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]),
            Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0]),
        ]
        solver = ThreeBodySolver(particles, dt=0.001)

        # Snapshot positions before stepping (copies, in case the solver
        # updates the particle arrays in place).
        initial_positions = [p.position.copy() for p in particles]

        new_particles = solver.step()

        for old_pos, new_particle in zip(initial_positions, new_particles):
            assert not np.allclose(old_pos, new_particle.position)

        # Time is accumulated in floating point; compare approximately.
        assert solver.time == pytest.approx(0.001)

    def test_conservation_laws(self):
        """Momentum, angular momentum and energy are conserved over a short figure-8 run."""
        # The figure-eight orbit is expected to be a stable configuration.
        particles = ThreeBodyConfig.create_figure8_config()
        solver = ThreeBodySolver(particles, dt=0.0001)

        # Integrate for a short time only, to keep the test fast.
        solver.simulate(total_time=0.1, progress_interval=100)

        momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors()

        # NOTE(review): whether these errors are relative or absolute is defined
        # by get_conservation_errors (not visible here); thresholds assume the former.
        assert momentum_error < 1e-10
        assert angular_momentum_error < 1e-10
        # Energy drifts through truncation error, so its tolerance is looser.
        assert energy_error < 1e-5
class TestThreeBodyConfig:
    """Tests for the ``ThreeBodyConfig`` initial-condition factories."""

    @staticmethod
    def _total_momentum(particles):
        """Return sum(m * v) over *particles*, accumulated onto a zero 3-vector."""
        return sum((p.mass * p.velocity for p in particles), np.zeros(3))

    def test_figure8_config(self):
        """The figure-eight preset has three unit masses and zero net momentum."""
        particles = ThreeBodyConfig.create_figure8_config()

        assert len(particles) == 3
        assert all(p.mass == 1.0 for p in particles)
        # The figure-eight solution should carry no total momentum.
        assert np.linalg.norm(self._total_momentum(particles)) < 1e-10

    def test_lagrange_config(self):
        """The Lagrange-point presets place the test body at L4 / L5."""
        l4_bodies = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4)
        assert len(l4_bodies) == 3

        # Expected masses, in order: Sun, Earth, test particle.
        for body, expected_mass in zip(l4_bodies, (1.0, 3e-6, 1e-8)):
            assert body.mass == expected_mass

        l5_bodies = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5)
        assert len(l5_bodies) == 3

        # L4 and L5 sit at the apex of an equilateral triangle, above and
        # below the x-axis respectively.
        apex_height = np.sqrt(3) / 2
        assert np.allclose(l4_bodies[2].position, [0.5, apex_height, 0.0])
        assert np.allclose(l5_bodies[2].position, [0.5, -apex_height, 0.0])

    def test_random_config(self):
        """Random presets use masses in [0.5, 2.0] and are momentum-balanced."""
        particles = ThreeBodyConfig.create_random_config()

        assert len(particles) == 3
        assert all(0.5 <= p.mass <= 2.0 for p in particles)
        # The factory is expected to adjust velocities so total momentum is zero.
        assert np.linalg.norm(self._total_momentum(particles)) < 1e-10

    def test_custom_config(self):
        """create_custom_config() builds particles from a per-particle dict."""
        # (key, mass, position, velocity, name, color) for each particle.
        specs = [
            ('particle_1', 1.0, [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], 'Star A', 'red'),
            ('particle_2', 2.0, [-1.0, 0.0, 0.0], [0.0, -0.5, 0.0], 'Star B', 'green'),
            ('particle_3', 0.5, [0.0, 1.0, 0.0], [0.5, 0.0, 0.0], 'Star C', 'blue'),
        ]
        config_dict = {
            key: {
                'mass': mass,
                'position': position,
                'velocity': velocity,
                'name': name,
                'color': color,
            }
            for key, mass, position, velocity, name, color in specs
        }

        particles = ThreeBodyConfig.create_custom_config(config_dict)

        assert len(particles) == 3
        assert particles[0].name == 'Star A'
        assert particles[1].mass == 2.0
        assert np.allclose(particles[2].position, [0.0, 1.0, 0.0])
def test_integration_accuracy():
    """Check that the energy error shrinks like dt**4 as the step size is refined.

    Two unit masses start on opposite sides of the origin with opposite
    velocities. The third body has negligible mass (1e-6), so the dynamics are
    effectively a two-body problem. The same system is integrated to t = 1.0
    with several step sizes. The ratio of successive energy errors is then
    compared with the ratio expected from a fourth-order integrator.
    """
    particles = [
        Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]),
        Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]),
        Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]),  # near-massless test body
    ]

    dts = [0.01, 0.005, 0.001, 0.0005]
    # Below this level, round-off dominates the energy drift rather than
    # truncation error, so the dt**4 scaling no longer holds. An exact 0.0
    # would also make the ratio below divide by zero.
    # NOTE(review): tune if get_conservation_errors reports absolute rather
    # than relative error.
    noise_floor = 1e-11

    energy_errors = []
    for dt in dts:
        # Fresh copies so every run starts from the same initial state.
        solver = ThreeBodySolver([p.copy() for p in particles], dt=dt)
        solver.simulate(total_time=1.0, progress_interval=10000)
        _, _, energy_error = solver.get_conservation_errors()
        assert np.isfinite(energy_error)
        # Only the magnitude of the drift matters for the scaling check.
        energy_errors.append(abs(energy_error))

    pairs = zip(dts, dts[1:], energy_errors, energy_errors[1:])
    for dt_coarse, dt_fine, err_coarse, err_fine in pairs:
        if err_fine < noise_floor:
            # Round-off regime: the ratio carries no information; skip the pair.
            continue
        error_ratio = err_coarse / err_fine
        dt_ratio = (dt_coarse / dt_fine) ** 4
        # Allow an order of magnitude either way around the ideal dt**4 ratio.
        assert 0.1 < error_ratio / dt_ratio < 10.0
if __name__ == "__main__":
    # Lightweight runner so the suite can be executed as a plain script
    # without the pytest CLI. A failing test raises (AssertionError, or
    # pytest's own failure exception from pytest.raises), which aborts the
    # run with a traceback. Reaching the final banner therefore means every
    # listed test passed.
    print("运行三体问题求解器测试...")

    test_particle = TestParticle()
    test_solver = TestThreeBodySolver()
    test_config = TestThreeBodyConfig()

    # (section header, tests in that section), executed in order.
    sections = [
        ("\n1. 测试质点类:", [
            test_particle.test_particle_creation,
            test_particle.test_particle_update,
            test_particle.test_particle_energy,
            test_particle.test_particle_copy,
        ]),
        ("\n2. 测试求解器类:", [
            test_solver.test_solver_creation,
            # Uses pytest.raises internally: it returns normally when the
            # expected ValueError occurs, so no try/except is needed here.
            test_solver.test_solver_wrong_number_of_particles,
            test_solver.test_acceleration_calculation,
            test_solver.test_center_of_mass,
            test_solver.test_energy_calculation,
            test_solver.test_single_step,
            test_solver.test_conservation_laws,
        ]),
        ("\n3. 测试配置管理:", [
            test_config.test_figure8_config,
            test_config.test_lagrange_config,
            test_config.test_random_config,
            test_config.test_custom_config,
        ]),
        ("\n4. 测试数值积分精度:", [
            test_integration_accuracy,
        ]),
    ]

    for header, tests in sections:
        print(header)
        for test in tests:
            test()
            # Bound methods and plain functions both expose __name__.
            print(f" ✓ {test.__name__} 通过")

    print("\n" + "="*60)
    print("所有测试通过!")
    print("="*60)