119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
"""
|
||
数值积分器模块,提供RK4方法求解微分方程
|
||
"""
|
||
|
||
import numpy as np
|
||
from typing import Callable, Tuple, List
|
||
from .particle import Particle
|
||
|
||
|
||
class RK4Integrator:
|
||
"""四阶龙格-库塔法积分器"""
|
||
|
||
def __init__(self, dt: float = 0.001):
|
||
"""
|
||
初始化积分器
|
||
|
||
参数:
|
||
dt: 时间步长(年)
|
||
"""
|
||
self.dt = dt
|
||
|
||
def step(self, particles: List[Particle],
|
||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]]) -> List[Particle]:
|
||
"""
|
||
执行一个时间步的积分
|
||
|
||
参数:
|
||
particles: 质点列表
|
||
acceleration_func: 计算加速度的函数
|
||
|
||
返回:
|
||
更新后的质点列表
|
||
"""
|
||
# 保存当前状态
|
||
current_positions = [p.position.copy() for p in particles]
|
||
current_velocities = [p.velocity.copy() for p in particles]
|
||
|
||
# RK4步骤1: k1
|
||
k1_v = acceleration_func(particles)
|
||
k1_r = current_velocities
|
||
|
||
# 临时更新位置和速度用于k2计算
|
||
temp_particles = []
|
||
for i, p in enumerate(particles):
|
||
temp_p = p.copy()
|
||
temp_p.position = current_positions[i] + 0.5 * self.dt * k1_r[i]
|
||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k1_v[i]
|
||
temp_particles.append(temp_p)
|
||
|
||
# RK4步骤2: k2
|
||
k2_v = acceleration_func(temp_particles)
|
||
k2_r = [p.velocity for p in temp_particles]
|
||
|
||
# 临时更新位置和速度用于k3计算
|
||
temp_particles = []
|
||
for i, p in enumerate(particles):
|
||
temp_p = p.copy()
|
||
temp_p.position = current_positions[i] + 0.5 * self.dt * k2_r[i]
|
||
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k2_v[i]
|
||
temp_particles.append(temp_p)
|
||
|
||
# RK4步骤3: k3
|
||
k3_v = acceleration_func(temp_particles)
|
||
k3_r = [p.velocity for p in temp_particles]
|
||
|
||
# 临时更新位置和速度用于k4计算
|
||
temp_particles = []
|
||
for i, p in enumerate(particles):
|
||
temp_p = p.copy()
|
||
temp_p.position = current_positions[i] + self.dt * k3_r[i]
|
||
temp_p.velocity = current_velocities[i] + self.dt * k3_v[i]
|
||
temp_particles.append(temp_p)
|
||
|
||
# RK4步骤4: k4
|
||
k4_v = acceleration_func(temp_particles)
|
||
k4_r = [p.velocity for p in temp_particles]
|
||
|
||
# 组合所有k值计算最终更新
|
||
new_particles = []
|
||
for i, p in enumerate(particles):
|
||
# 计算新的速度
|
||
new_velocity = (current_velocities[i] +
|
||
self.dt / 6.0 * (k1_v[i] + 2*k2_v[i] + 2*k3_v[i] + k4_v[i]))
|
||
|
||
# 计算新的位置
|
||
new_position = (current_positions[i] +
|
||
self.dt / 6.0 * (k1_r[i] + 2*k2_r[i] + 2*k3_r[i] + k4_r[i]))
|
||
|
||
# 创建新的质点对象
|
||
new_p = p.copy()
|
||
new_p.update(new_position, new_velocity)
|
||
new_particles.append(new_p)
|
||
|
||
return new_particles
|
||
|
||
def integrate(self, particles: List[Particle],
|
||
acceleration_func: Callable[[List[Particle]], List[np.ndarray]],
|
||
steps: int) -> List[Particle]:
|
||
"""
|
||
执行多步积分
|
||
|
||
参数:
|
||
particles: 初始质点列表
|
||
acceleration_func: 计算加速度的函数
|
||
steps: 积分步数
|
||
|
||
返回:
|
||
积分结束后的质点列表
|
||
"""
|
||
current_particles = particles
|
||
|
||
for step in range(steps):
|
||
current_particles = self.step(current_particles, acceleration_func)
|
||
|
||
# 每1000步打印进度
|
||
if step % 1000 == 0:
|
||
print(f"积分进度: {step}/{steps} 步")
|
||
|
||
return current_particles |