""" 三体问题求解器测试 """ import numpy as np import pytest import sys import os # 添加父目录到路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from three_body_problem import Particle, ThreeBodySolver, ThreeBodyConfig class TestParticle: """测试质点类""" def test_particle_creation(self): """测试质点创建""" p = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Test") assert p.mass == 1.0 assert np.allclose(p.position, [1.0, 2.0, 3.0]) assert np.allclose(p.velocity, [0.1, 0.2, 0.3]) assert p.name == "Test" assert len(p.position_history) == 0 assert len(p.velocity_history) == 0 def test_particle_update(self): """测试质点状态更新""" p = Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) new_position = np.array([1.0, 2.0, 3.0]) new_velocity = np.array([0.1, 0.2, 0.3]) p.update(new_position, new_velocity) assert np.allclose(p.position, new_position) assert np.allclose(p.velocity, new_velocity) assert len(p.position_history) == 1 assert len(p.velocity_history) == 1 assert np.allclose(p.position_history[0], [0.0, 0.0, 0.0]) assert np.allclose(p.velocity_history[0], [0.0, 0.0, 0.0]) def test_particle_energy(self): """测试质点动能计算""" p = Particle(mass=2.0, position=[0.0, 0.0, 0.0], velocity=[1.0, 2.0, 3.0]) # 动能 = 0.5 * m * v^2 expected_energy = 0.5 * 2.0 * (1.0**2 + 2.0**2 + 3.0**2) calculated_energy = p.get_energy() assert np.isclose(calculated_energy, expected_energy) def test_particle_copy(self): """测试质点复制""" p1 = Particle(mass=1.0, position=[1.0, 2.0, 3.0], velocity=[0.1, 0.2, 0.3], name="Original") p2 = p1.copy() # 检查值相等 assert p2.mass == p1.mass assert np.allclose(p2.position, p1.position) assert np.allclose(p2.velocity, p1.velocity) assert p2.name == p1.name # 检查是深拷贝 p2.position[0] = 100.0 assert p1.position[0] == 1.0 # 原始对象不应改变 class TestThreeBodySolver: """测试三体问题求解器""" def test_solver_creation(self): """测试求解器创建""" particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]), Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]), Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-1.0, 0.0, 0.0]) ] solver = ThreeBodySolver(particles, dt=0.001) assert len(solver.particles) == 3 assert solver.dt == 0.001 assert solver.time == 0.0 def test_solver_wrong_number_of_particles(self): """测试错误数量的质点""" particles = [ Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]), Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) ] with pytest.raises(ValueError, match="三体问题需要恰好3个质点"): ThreeBodySolver(particles, dt=0.001) def test_acceleration_calculation(self): """测试加速度计算""" # 创建简单的测试配置:三个质点在等边三角形顶点 particles = [ Particle(mass=1.0, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]), Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]), Particle(mass=1.0, position=[0.5, np.sqrt(3)/2, 0.0], velocity=[0.0, 0.0, 0.0]) ] solver = ThreeBodySolver(particles, dt=0.001) # 计算加速度 accelerations = solver._calculate_accelerations(particles) # 检查加速度形状 assert len(accelerations) == 3 for acc in accelerations: assert acc.shape == (3,) # 由于对称性,第一个质点的加速度应该指向其他两个质点的方向 # 这里只检查计算是否成功,不检查具体值 def test_center_of_mass(self): """测试质心计算""" particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]), Particle(mass=2.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]), Particle(mass=3.0, position=[0.0, 1.0, 0.0], velocity=[0.0, 0.0, 0.0]) ] solver = ThreeBodySolver(particles, dt=0.001) com = solver.get_center_of_mass() # 计算期望的质心 total_mass = 1.0 + 2.0 + 3.0 expected_com = np.array([ (1.0*1.0 + 2.0*(-1.0) + 3.0*0.0) / total_mass, (1.0*0.0 + 2.0*0.0 + 3.0*1.0) / total_mass, (1.0*0.0 + 2.0*0.0 + 3.0*0.0) / total_mass ]) assert np.allclose(com, expected_com) def test_energy_calculation(self): """测试能量计算""" particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 1.0, 0.0]), Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -1.0, 0.0]), Particle(mass=1.0, position=[0.0, 0.0, 1.0], velocity=[0.0, 0.0, 0.0]) ] solver = ThreeBodySolver(particles, dt=0.001) energy = solver._calculate_total_energy() # 能量应该是有限的实数 assert np.isfinite(energy) assert isinstance(energy, float) def test_single_step(self): """测试单步积分""" particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 0.5, 0.0]), Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -0.5, 0.0]), Particle(mass=0.001, position=[0.0, 1.0, 0.0], velocity=[-0.5, 0.0, 0.0]) ] solver = ThreeBodySolver(particles, dt=0.001) # 保存初始位置 initial_positions = [p.position.copy() for p in particles] # 执行单步积分 new_particles = solver.step() # 检查位置是否改变 for i, (old_pos, new_particle) in enumerate(zip(initial_positions, new_particles)): assert not np.allclose(old_pos, new_particle.position) # 检查时间是否增加 assert solver.time == 0.001 def test_conservation_laws(self): """测试守恒定律(动量、角动量、能量)""" # 使用8字形轨道配置(应该是稳定的) particles = ThreeBodyConfig.create_figure8_config() solver = ThreeBodySolver(particles, dt=0.0001) # 模拟很短时间 solver.simulate(total_time=0.1, progress_interval=100) # 计算守恒误差 momentum_error, angular_momentum_error, energy_error = solver.get_conservation_errors() # 误差应该很小 assert momentum_error < 1e-10 assert angular_momentum_error < 1e-10 assert energy_error < 1e-5 # 能量守恒要求较低,因为数值误差 class TestThreeBodyConfig: """测试配置管理""" def test_figure8_config(self): """测试8字形轨道配置""" particles = ThreeBodyConfig.create_figure8_config() assert len(particles) == 3 assert all(p.mass == 1.0 for p in particles) # 检查总动量接近零(8字形轨道应该满足) total_momentum = np.zeros(3) for p in particles: total_momentum += p.mass * p.velocity assert np.linalg.norm(total_momentum) < 1e-10 def test_lagrange_config(self): """测试拉格朗日点配置""" # 测试L4点 particles_l4 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=4) assert len(particles_l4) == 3 # 检查质量 assert particles_l4[0].mass == 1.0 # 太阳 assert particles_l4[1].mass == 3e-6 # 地球 assert particles_l4[2].mass == 1e-8 # 测试质点 # 测试L5点 particles_l5 = ThreeBodyConfig.create_lagrange_point_config(lagrange_point=5) assert len(particles_l5) == 3 # 检查位置(L4和L5应该在等边三角形顶点) r_l4 = particles_l4[2].position r_l5 = particles_l5[2].position assert np.allclose(r_l4, [0.5, np.sqrt(3)/2, 0.0]) assert np.allclose(r_l5, [0.5, -np.sqrt(3)/2, 0.0]) def test_random_config(self): """测试随机配置""" particles = ThreeBodyConfig.create_random_config() assert len(particles) == 3 assert all(0.5 <= p.mass <= 2.0 for p in particles) # 检查总动量接近零(随机配置会调整速度使总动量为零) total_momentum = np.zeros(3) for p in particles: total_momentum += p.mass * p.velocity assert np.linalg.norm(total_momentum) < 1e-10 def test_custom_config(self): """测试自定义配置""" config_dict = { 'particle_1': { 'mass': 1.0, 'position': [1.0, 0.0, 0.0], 'velocity': [0.0, 1.0, 0.0], 'name': 'Star A', 'color': 'red' }, 'particle_2': { 'mass': 2.0, 'position': [-1.0, 0.0, 0.0], 'velocity': [0.0, -0.5, 0.0], 'name': 'Star B', 'color': 'green' }, 'particle_3': { 'mass': 0.5, 'position': [0.0, 1.0, 0.0], 'velocity': [0.5, 0.0, 0.0], 'name': 'Star C', 'color': 'blue' } } particles = ThreeBodyConfig.create_custom_config(config_dict) assert len(particles) == 3 assert particles[0].name == 'Star A' assert particles[1].mass == 2.0 assert np.allclose(particles[2].position, [0.0, 1.0, 0.0]) def test_integration_accuracy(): """测试数值积分精度""" # 使用简单的二体问题测试(第三个体质量很小) particles = [ Particle(mass=1.0, position=[1.0, 0.0, 0.0], velocity=[0.0, 2*np.pi, 0.0]), Particle(mass=1.0, position=[-1.0, 0.0, 0.0], velocity=[0.0, -2*np.pi, 0.0]), Particle(mass=1e-6, position=[0.0, 0.0, 0.0], velocity=[0.0, 0.0, 0.0]) # 很小的测试质点 ] # 测试不同时间步长 dts = [0.01, 0.005, 0.001, 0.0005] energy_errors = [] for dt in dts: solver = ThreeBodySolver([p.copy() for p in particles], dt=dt) solver.simulate(total_time=1.0, progress_interval=10000) _, _, energy_error = solver.get_conservation_errors() energy_errors.append(energy_error) # 检查误差随步长减小而减小(四阶方法) for i in range(len(energy_errors)-1): # 误差应该大致按 dt^4 减小 error_ratio = energy_errors[i] / energy_errors[i+1] dt_ratio = (dts[i] / dts[i+1])**4 # 允许一定的误差范围 assert 0.1 < error_ratio / dt_ratio < 10.0 if __name__ == "__main__": # 运行所有测试 print("运行三体问题求解器测试...") # 创建测试实例 test_particle = TestParticle() test_solver = TestThreeBodySolver() test_config = TestThreeBodyConfig() # 运行粒子测试 print("\n1. 测试质点类:") test_particle.test_particle_creation() print(" ✓ test_particle_creation 通过") test_particle.test_particle_update() print(" ✓ test_particle_update 通过") test_particle.test_particle_energy() print(" ✓ test_particle_energy 通过") test_particle.test_particle_copy() print(" ✓ test_particle_copy 通过") # 运行求解器测试 print("\n2. 测试求解器类:") test_solver.test_solver_creation() print(" ✓ test_solver_creation 通过") try: test_solver.test_solver_wrong_number_of_particles() print(" ✗ test_solver_wrong_number_of_particles 应该抛出异常") except ValueError: print(" ✓ test_solver_wrong_number_of_particles 通过") test_solver.test_acceleration_calculation() print(" ✓ test_acceleration_calculation 通过") test_solver.test_center_of_mass() print(" ✓ test_center_of_mass 通过") test_solver.test_energy_calculation() print(" ✓ test_energy_calculation 通过") test_solver.test_single_step() print(" ✓ test_single_step 通过") test_solver.test_conservation_laws() print(" ✓ test_conservation_laws 通过") # 运行配置测试 print("\n3. 测试配置管理:") test_config.test_figure8_config() print(" ✓ test_figure8_config 通过") test_config.test_lagrange_config() print(" ✓ test_lagrange_config 通过") test_config.test_random_config() print(" ✓ test_random_config 通过") test_config.test_custom_config() print(" ✓ test_custom_config 通过") # 运行积分精度测试 print("\n4. 测试数值积分精度:") test_integration_accuracy() print(" ✓ test_integration_accuracy 通过") print("\n" + "="*60) print("所有测试通过!") print("="*60)