"""
Particle class representing a single body in the three-body problem.

质点类,表示三体问题中的一个天体
"""

from typing import List, Optional, Tuple

import numpy as np


class Particle:
    """A point mass representing one body in a three-body simulation.

    Attributes:
        mass: Mass of the body (solar masses).
        position: Current position vector as a float64 array [AU].
        velocity: Current velocity vector as a float64 array [AU/year].
        name: Display name of the body.
        color: Colour used for visualisation, or ``None`` if unset.
        position_history: Positions recorded before each ``update`` call.
        velocity_history: Velocities recorded before each ``update`` call.
    """

    def __init__(self, mass: float, position: np.ndarray, velocity: np.ndarray,
                 name: str = "", color: Optional[str] = None):
        """Initialise the particle.

        Args:
            mass: Mass in solar masses.
            position: Position vector (x, y, z) in AU. Any array-like is
                accepted and copied into a new float64 array.
            velocity: Velocity vector (vx, vy, vz) in AU/year. Copied into a
                new float64 array.
            name: Name of the particle.
            color: Colour used for visualisation; ``None`` leaves it unset.
        """
        self.mass = mass
        self.position = np.array(position, dtype=np.float64)
        self.velocity = np.array(velocity, dtype=np.float64)
        self.name = name
        self.color = color

        # Trajectory history: each entry is the state *before* an update()
        # call. The current state is only appended on the next update().
        self.position_history: List[np.ndarray] = []
        self.velocity_history: List[np.ndarray] = []

    def update(self, new_position: np.ndarray, new_velocity: np.ndarray) -> None:
        """Record the current state in the history, then switch to a new state.

        The new vectors are copied into float64 arrays, matching ``__init__``.
        This means later in-place changes to the caller's arrays cannot alter
        this particle's state. It also means list inputs still support the
        array arithmetic used by ``get_energy``.

        Args:
            new_position: New position vector [AU].
            new_velocity: New velocity vector [AU/year].
        """
        self.position_history.append(self.position.copy())
        self.velocity_history.append(self.velocity.copy())
        self.position = np.array(new_position, dtype=np.float64)
        self.velocity = np.array(new_velocity, dtype=np.float64)

    def get_trajectory(self) -> np.ndarray:
        """Return the full trajectory, ending with the current position.

        Returns:
            Array with ``len(position_history) + 1`` rows: the recorded
            positions in order, followed by the current position. With no
            history this is a single row holding the current position.
        """
        return np.array(self.position_history + [self.position])

    def get_energy(self) -> float:
        """Return the kinetic energy ``0.5 * mass * |velocity|**2``.

        Only the kinetic term is computed; no potential energy is included.
        """
        v_squared = np.sum(self.velocity**2)
        return 0.5 * self.mass * v_squared

    def __repr__(self) -> str:
        return (f"Particle(name='{self.name}', mass={self.mass:.3f}, "
                f"position={self.position}, velocity={self.velocity})")

    def copy(self) -> 'Particle':
        """Return an independent copy of the particle's current state.

        The position and velocity arrays are copied. The trajectory history
        is *not* carried over, so the copy starts with empty histories.
        """
        return Particle(
            mass=self.mass,
            position=self.position.copy(),
            velocity=self.velocity.copy(),
            name=self.name,
            color=self.color
        )