321 lines
12 KiB
Python
321 lines
12 KiB
Python
"""
|
||
三体问题可视化模块
|
||
"""
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from mpl_toolkits.mplot3d import Axes3D
|
||
from typing import List, Optional, Tuple
|
||
from .particle import Particle
|
||
from .solver import ThreeBodySolver
|
||
|
||
|
||
class ThreeBodyVisualizer:
|
||
"""三体问题可视化类"""
|
||
|
||
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
|
||
"""
|
||
初始化可视化器
|
||
|
||
参数:
|
||
figsize: 图形大小
|
||
"""
|
||
self.figsize = figsize
|
||
self.fig = None
|
||
self.ax = None
|
||
|
||
def create_3d_plot(self):
|
||
"""创建3D图形"""
|
||
self.fig = plt.figure(figsize=self.figsize)
|
||
self.ax = self.fig.add_subplot(111, projection='3d')
|
||
|
||
def plot_trajectories(self, solver: ThreeBodySolver,
|
||
show_current_positions: bool = True,
|
||
show_com: bool = True,
|
||
title: str = "三体问题轨迹"):
|
||
"""
|
||
绘制三体运动轨迹
|
||
|
||
参数:
|
||
solver: 三体问题求解器
|
||
show_current_positions: 是否显示当前位置
|
||
show_com: 是否显示质心
|
||
title: 图形标题
|
||
"""
|
||
if self.ax is None:
|
||
self.create_3d_plot()
|
||
|
||
trajectories = solver.get_trajectories()
|
||
|
||
# 绘制每个质点的轨迹
|
||
colors = ['red', 'green', 'blue']
|
||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||
color = particle.color if particle.color else colors[i % len(colors)]
|
||
label = particle.name if particle.name else f"质点 {i+1}"
|
||
|
||
# 绘制轨迹线
|
||
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
|
||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||
|
||
# 绘制当前位置
|
||
if show_current_positions and len(traj) > 0:
|
||
current_pos = traj[-1]
|
||
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
|
||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||
|
||
# 绘制质心
|
||
if show_com:
|
||
com = solver.get_center_of_mass()
|
||
self.ax.scatter(com[0], com[1], com[2],
|
||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||
|
||
# 设置图形属性
|
||
self.ax.set_xlabel('X (AU)', fontsize=12)
|
||
self.ax.set_ylabel('Y (AU)', fontsize=12)
|
||
self.ax.set_zlabel('Z (AU)', fontsize=12)
|
||
self.ax.set_title(title, fontsize=14, fontweight='bold')
|
||
self.ax.legend(fontsize=10)
|
||
self.ax.grid(True, alpha=0.3)
|
||
|
||
# 设置等比例坐标轴
|
||
self.ax.set_aspect('auto')
|
||
|
||
def plot_2d_projection(self, solver: ThreeBodySolver,
|
||
projection: str = 'xy',
|
||
show_current_positions: bool = True,
|
||
show_com: bool = True,
|
||
title: str = "三体问题轨迹 (2D投影)"):
|
||
"""
|
||
绘制2D投影
|
||
|
||
参数:
|
||
solver: 三体问题求解器
|
||
projection: 投影平面 ('xy', 'xz', 'yz')
|
||
show_current_positions: 是否显示当前位置
|
||
show_com: 是否显示质心
|
||
title: 图形标题
|
||
"""
|
||
if projection not in ['xy', 'xz', 'yz']:
|
||
raise ValueError("projection 必须是 'xy', 'xz' 或 'yz'")
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 8))
|
||
trajectories = solver.get_trajectories()
|
||
|
||
# 确定坐标轴索引
|
||
if projection == 'xy':
|
||
x_idx, y_idx = 0, 1
|
||
x_label, y_label = 'X (AU)', 'Y (AU)'
|
||
elif projection == 'xz':
|
||
x_idx, y_idx = 0, 2
|
||
x_label, y_label = 'X (AU)', 'Z (AU)'
|
||
else: # 'yz'
|
||
x_idx, y_idx = 1, 2
|
||
x_label, y_label = 'Y (AU)', 'Z (AU)'
|
||
|
||
# 绘制每个质点的轨迹
|
||
colors = ['red', 'green', 'blue']
|
||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||
color = particle.color if particle.color else colors[i % len(colors)]
|
||
label = particle.name if particle.name else f"质点 {i+1}"
|
||
|
||
# 绘制轨迹线
|
||
ax.plot(traj[:, x_idx], traj[:, y_idx],
|
||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||
|
||
# 绘制当前位置
|
||
if show_current_positions and len(traj) > 0:
|
||
current_pos = traj[-1]
|
||
ax.scatter(current_pos[x_idx], current_pos[y_idx],
|
||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||
|
||
# 绘制质心
|
||
if show_com:
|
||
com = solver.get_center_of_mass()
|
||
ax.scatter(com[x_idx], com[y_idx],
|
||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||
|
||
# 设置图形属性
|
||
ax.set_xlabel(x_label, fontsize=12)
|
||
ax.set_ylabel(y_label, fontsize=12)
|
||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||
ax.legend(fontsize=10)
|
||
ax.grid(True, alpha=0.3)
|
||
ax.set_aspect('equal', adjustable='box')
|
||
|
||
return fig, ax
|
||
|
||
def plot_energy_conservation(self, solver: ThreeBodySolver,
|
||
time_points: Optional[np.ndarray] = None):
|
||
"""
|
||
绘制能量守恒图
|
||
|
||
参数:
|
||
solver: 三体问题求解器
|
||
time_points: 时间点数组(如果为None,则使用模拟时间)
|
||
"""
|
||
# 从历史记录中提取能量信息
|
||
# 注意:这里需要修改求解器以记录能量历史
|
||
# 暂时返回一个简单的图形
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
# 这里可以添加能量历史记录功能
|
||
ax.set_xlabel('时间 (年)', fontsize=12)
|
||
ax.set_ylabel('总能量', fontsize=12)
|
||
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
return fig, ax
|
||
|
||
def plot_phase_space(self, solver: ThreeBodySolver,
|
||
particle_index: int = 0,
|
||
dimension: str = 'x'):
|
||
"""
|
||
绘制相空间图(位置 vs 速度)
|
||
|
||
参数:
|
||
solver: 三体问题求解器
|
||
particle_index: 质点索引 (0, 1, 2)
|
||
dimension: 维度 ('x', 'y', 'z')
|
||
"""
|
||
if particle_index not in [0, 1, 2]:
|
||
raise ValueError("particle_index 必须是 0, 1, 或 2")
|
||
|
||
if dimension not in ['x', 'y', 'z']:
|
||
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
|
||
|
||
particle = solver.particles[particle_index]
|
||
trajectories = particle.get_trajectory()
|
||
|
||
if len(trajectories) < 2:
|
||
print("警告:轨迹数据不足,无法绘制相空间图")
|
||
return None, None
|
||
|
||
# 获取位置和速度历史
|
||
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
|
||
|
||
# 注意:速度历史需要从求解器中获取
|
||
# 暂时使用位置差分近似速度
|
||
if len(positions) > 1:
|
||
dt = solver.dt
|
||
velocities = np.gradient(positions, dt)
|
||
else:
|
||
velocities = np.zeros_like(positions)
|
||
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
# 绘制相空间轨迹
|
||
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
|
||
cmap='viridis', alpha=0.7, s=20)
|
||
|
||
# 添加颜色条表示时间
|
||
plt.colorbar(scatter, ax=ax, label='时间步')
|
||
|
||
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
|
||
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
|
||
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
|
||
fontsize=14, fontweight='bold')
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
return fig, ax
|
||
|
||
def animate_trajectories(self, solver: ThreeBodySolver,
|
||
save_path: Optional[str] = None,
|
||
fps: int = 30,
|
||
dpi: int = 100):
|
||
"""
|
||
创建轨迹动画(需要额外安装 matplotlib.animation)
|
||
|
||
参数:
|
||
solver: 三体问题求解器
|
||
save_path: 保存路径(如果为None则显示动画)
|
||
fps: 帧率
|
||
dpi: 分辨率
|
||
"""
|
||
try:
|
||
from matplotlib.animation import FuncAnimation
|
||
except ImportError:
|
||
print("错误:需要安装 matplotlib.animation 模块")
|
||
return
|
||
|
||
trajectories = solver.get_trajectories()
|
||
if any(len(traj) < 2 for traj in trajectories):
|
||
print("错误:轨迹数据不足,无法创建动画")
|
||
return
|
||
|
||
fig = plt.figure(figsize=(10, 8))
|
||
ax = fig.add_subplot(111, projection='3d')
|
||
|
||
# 初始化图形元素
|
||
lines = []
|
||
points = []
|
||
colors = ['red', 'green', 'blue']
|
||
|
||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||
color = particle.color if particle.color else colors[i % len(colors)]
|
||
label = particle.name if particle.name else f"质点 {i+1}"
|
||
|
||
# 创建轨迹线
|
||
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||
lines.append(line)
|
||
|
||
# 创建当前位置点
|
||
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
|
||
markeredgecolor='black', markeredgewidth=1)
|
||
points.append(point)
|
||
|
||
# 设置坐标轴范围
|
||
all_points = np.vstack(trajectories)
|
||
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
|
||
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
|
||
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
|
||
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
|
||
|
||
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
|
||
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
|
||
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
|
||
|
||
ax.set_xlabel('X (AU)', fontsize=12)
|
||
ax.set_ylabel('Y (AU)', fontsize=12)
|
||
ax.set_zlabel('Z (AU)', fontsize=12)
|
||
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
|
||
ax.legend(fontsize=10)
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
# 动画更新函数
|
||
def update(frame):
|
||
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
|
||
# 更新轨迹线(显示到当前帧)
|
||
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
|
||
line.set_3d_properties(traj[:frame+1, 2])
|
||
|
||
# 更新当前位置点
|
||
if frame < len(traj):
|
||
point.set_data([traj[frame, 0]], [traj[frame, 1]])
|
||
point.set_3d_properties([traj[frame, 2]])
|
||
|
||
return lines + points
|
||
|
||
# 创建动画
|
||
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
|
||
interval=1000/fps, blit=True)
|
||
|
||
if save_path:
|
||
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
|
||
print(f"动画已保存到: {save_path}")
|
||
else:
|
||
plt.show()
|
||
|
||
return anim
|
||
|
||
def show(self):
|
||
"""显示所有图形"""
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
def save_figure(self, filename: str, dpi: int = 300):
|
||
"""保存当前图形"""
|
||
if self.fig is not None:
|
||
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
|
||
print(f"图形已保存到: {filename}")
|
||
else:
|
||
print("警告:没有活动的图形可以保存") |