Files
dison0331-ThinkPad 8c8ad9fe07 first
pc-1
2026-03-11 21:32:58 +08:00

119 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数值积分器模块提供RK4方法求解微分方程
"""
import numpy as np
from typing import Callable, Tuple, List
from .particle import Particle
class RK4Integrator:
"""四阶龙格-库塔法积分器"""
def __init__(self, dt: float = 0.001):
"""
初始化积分器
参数:
dt: 时间步长(年)
"""
self.dt = dt
def step(self, particles: List[Particle],
acceleration_func: Callable[[List[Particle]], List[np.ndarray]]) -> List[Particle]:
"""
执行一个时间步的积分
参数:
particles: 质点列表
acceleration_func: 计算加速度的函数
返回:
更新后的质点列表
"""
# 保存当前状态
current_positions = [p.position.copy() for p in particles]
current_velocities = [p.velocity.copy() for p in particles]
# RK4步骤1: k1
k1_v = acceleration_func(particles)
k1_r = current_velocities
# 临时更新位置和速度用于k2计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + 0.5 * self.dt * k1_r[i]
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k1_v[i]
temp_particles.append(temp_p)
# RK4步骤2: k2
k2_v = acceleration_func(temp_particles)
k2_r = [p.velocity for p in temp_particles]
# 临时更新位置和速度用于k3计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + 0.5 * self.dt * k2_r[i]
temp_p.velocity = current_velocities[i] + 0.5 * self.dt * k2_v[i]
temp_particles.append(temp_p)
# RK4步骤3: k3
k3_v = acceleration_func(temp_particles)
k3_r = [p.velocity for p in temp_particles]
# 临时更新位置和速度用于k4计算
temp_particles = []
for i, p in enumerate(particles):
temp_p = p.copy()
temp_p.position = current_positions[i] + self.dt * k3_r[i]
temp_p.velocity = current_velocities[i] + self.dt * k3_v[i]
temp_particles.append(temp_p)
# RK4步骤4: k4
k4_v = acceleration_func(temp_particles)
k4_r = [p.velocity for p in temp_particles]
# 组合所有k值计算最终更新
new_particles = []
for i, p in enumerate(particles):
# 计算新的速度
new_velocity = (current_velocities[i] +
self.dt / 6.0 * (k1_v[i] + 2*k2_v[i] + 2*k3_v[i] + k4_v[i]))
# 计算新的位置
new_position = (current_positions[i] +
self.dt / 6.0 * (k1_r[i] + 2*k2_r[i] + 2*k3_r[i] + k4_r[i]))
# 创建新的质点对象
new_p = p.copy()
new_p.update(new_position, new_velocity)
new_particles.append(new_p)
return new_particles
def integrate(self, particles: List[Particle],
acceleration_func: Callable[[List[Particle]], List[np.ndarray]],
steps: int) -> List[Particle]:
"""
执行多步积分
参数:
particles: 初始质点列表
acceleration_func: 计算加速度的函数
steps: 积分步数
返回:
积分结束后的质点列表
"""
current_particles = particles
for step in range(steps):
current_particles = self.step(current_particles, acceleration_func)
# 每1000步打印进度
if step % 1000 == 0:
print(f"积分进度: {step}/{steps}")
return current_particles