first
pc-1
This commit is contained in:
119
three_body_problem/integrator.py
Normal file
119
three_body_problem/integrator.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
数值积分器模块,提供RK4方法求解微分方程
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Callable, Tuple, List
|
||||
from .particle import Particle
|
||||
|
||||
|
||||
class RK4Integrator:
|
||||
"""四阶龙格-库塔法积分器"""
|
||||
|
||||
def __init__(self, dt: float = 0.001):
|
||||
"""
|
||||
初始化积分器
|
||||
|
||||
参数:
|
||||
dt: 时间步长(年)
|
||||
"""
|
||||
self.dt = dt
|
||||
|
||||
def step(self, particles: List[Particle],
|
||||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]]) -> List[Particle]:
|
||||
"""
|
||||
执行一个时间步的积分
|
||||
|
||||
参数:
|
||||
particles: 质点列表
|
||||
acceleration_func: 计算加速度的函数
|
||||
|
||||
返回:
|
||||
更新后的质点列表
|
||||
"""
|
||||
# 保存当前状态
|
||||
current_positions = [p.position.copy() for p in particles]
|
||||
current_velocities = [p.velocity.copy() for p in particles]
|
||||
|
||||
# RK4步骤1: k1
|
||||
k1_v = acceleration_func(particles)
|
||||
k1_r = current_velocities
|
||||
|
||||
# 临时更新位置和速度用于k2计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + 0.5 * self.dt * k1_r[i]
|
||||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k1_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤2: k2
|
||||
k2_v = acceleration_func(temp_particles)
|
||||
k2_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 临时更新位置和速度用于k3计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + 0.5 * self.dt * k2_r[i]
|
||||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k2_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤3: k3
|
||||
k3_v = acceleration_func(temp_particles)
|
||||
k3_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 临时更新位置和速度用于k4计算
|
||||
temp_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
temp_p = p.copy()
|
||||
temp_p.position = current_positions[i] + self.dt * k3_r[i]
|
||||
temp_p.velocity = current_velocities[i] + self.dt * k3_v[i]
|
||||
temp_particles.append(temp_p)
|
||||
|
||||
# RK4步骤4: k4
|
||||
k4_v = acceleration_func(temp_particles)
|
||||
k4_r = [p.velocity for p in temp_particles]
|
||||
|
||||
# 组合所有k值计算最终更新
|
||||
new_particles = []
|
||||
for i, p in enumerate(particles):
|
||||
# 计算新的速度
|
||||
new_velocity = (current_velocities[i] +
|
||||
self.dt / 6.0 * (k1_v[i] + 2*k2_v[i] + 2*k3_v[i] + k4_v[i]))
|
||||
|
||||
# 计算新的位置
|
||||
new_position = (current_positions[i] +
|
||||
self.dt / 6.0 * (k1_r[i] + 2*k2_r[i] + 2*k3_r[i] + k4_r[i]))
|
||||
|
||||
# 创建新的质点对象
|
||||
new_p = p.copy()
|
||||
new_p.update(new_position, new_velocity)
|
||||
new_particles.append(new_p)
|
||||
|
||||
return new_particles
|
||||
|
||||
def integrate(self, particles: List[Particle],
|
||||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]],
|
||||
steps: int) -> List[Particle]:
|
||||
"""
|
||||
执行多步积分
|
||||
|
||||
参数:
|
||||
particles: 初始质点列表
|
||||
acceleration_func: 计算加速度的函数
|
||||
steps: 积分步数
|
||||
|
||||
返回:
|
||||
积分结束后的质点列表
|
||||
"""
|
||||
current_particles = particles
|
||||
|
||||
for step in range(steps):
|
||||
current_particles = self.step(current_particles, acceleration_func)
|
||||
|
||||
# 每1000步打印进度
|
||||
if step % 1000 == 0:
|
||||
print(f"积分进度: {step}/{steps} 步")
|
||||
|
||||
return current_particles
|
||||
Reference in New Issue
Block a user