pc-1
This commit is contained in:
dison0331-ThinkPad
2026-03-11 21:32:58 +08:00
commit 8c8ad9fe07
29 changed files with 4005 additions and 0 deletions

View File

@@ -0,0 +1,321 @@
"""
三体问题可视化模块
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Optional, Tuple
from .particle import Particle
from .solver import ThreeBodySolver
class ThreeBodyVisualizer:
"""三体问题可视化类"""
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
"""
初始化可视化器
参数:
figsize: 图形大小
"""
self.figsize = figsize
self.fig = None
self.ax = None
def create_3d_plot(self):
"""创建3D图形"""
self.fig = plt.figure(figsize=self.figsize)
self.ax = self.fig.add_subplot(111, projection='3d')
def plot_trajectories(self, solver: ThreeBodySolver,
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹"):
"""
绘制三体运动轨迹
参数:
solver: 三体问题求解器
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if self.ax is None:
self.create_3d_plot()
trajectories = solver.get_trajectories()
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
self.ax.scatter(com[0], com[1], com[2],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
self.ax.set_xlabel('X (AU)', fontsize=12)
self.ax.set_ylabel('Y (AU)', fontsize=12)
self.ax.set_zlabel('Z (AU)', fontsize=12)
self.ax.set_title(title, fontsize=14, fontweight='bold')
self.ax.legend(fontsize=10)
self.ax.grid(True, alpha=0.3)
# 设置等比例坐标轴
self.ax.set_aspect('auto')
def plot_2d_projection(self, solver: ThreeBodySolver,
projection: str = 'xy',
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹 (2D投影)"):
"""
绘制2D投影
参数:
solver: 三体问题求解器
projection: 投影平面 ('xy', 'xz', 'yz')
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if projection not in ['xy', 'xz', 'yz']:
raise ValueError("projection 必须是 'xy', 'xz''yz'")
fig, ax = plt.subplots(figsize=(10, 8))
trajectories = solver.get_trajectories()
# 确定坐标轴索引
if projection == 'xy':
x_idx, y_idx = 0, 1
x_label, y_label = 'X (AU)', 'Y (AU)'
elif projection == 'xz':
x_idx, y_idx = 0, 2
x_label, y_label = 'X (AU)', 'Z (AU)'
else: # 'yz'
x_idx, y_idx = 1, 2
x_label, y_label = 'Y (AU)', 'Z (AU)'
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
ax.plot(traj[:, x_idx], traj[:, y_idx],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
ax.scatter(current_pos[x_idx], current_pos[y_idx],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
ax.scatter(com[x_idx], com[y_idx],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
ax.set_xlabel(x_label, fontsize=12)
ax.set_ylabel(y_label, fontsize=12)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
return fig, ax
def plot_energy_conservation(self, solver: ThreeBodySolver,
time_points: Optional[np.ndarray] = None):
"""
绘制能量守恒图
参数:
solver: 三体问题求解器
time_points: 时间点数组如果为None则使用模拟时间
"""
# 从历史记录中提取能量信息
# 注意:这里需要修改求解器以记录能量历史
# 暂时返回一个简单的图形
fig, ax = plt.subplots(figsize=(10, 6))
# 这里可以添加能量历史记录功能
ax.set_xlabel('时间 (年)', fontsize=12)
ax.set_ylabel('总能量', fontsize=12)
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def plot_phase_space(self, solver: ThreeBodySolver,
particle_index: int = 0,
dimension: str = 'x'):
"""
绘制相空间图(位置 vs 速度)
参数:
solver: 三体问题求解器
particle_index: 质点索引 (0, 1, 2)
dimension: 维度 ('x', 'y', 'z')
"""
if particle_index not in [0, 1, 2]:
raise ValueError("particle_index 必须是 0, 1, 或 2")
if dimension not in ['x', 'y', 'z']:
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
particle = solver.particles[particle_index]
trajectories = particle.get_trajectory()
if len(trajectories) < 2:
print("警告:轨迹数据不足,无法绘制相空间图")
return None, None
# 获取位置和速度历史
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
# 注意:速度历史需要从求解器中获取
# 暂时使用位置差分近似速度
if len(positions) > 1:
dt = solver.dt
velocities = np.gradient(positions, dt)
else:
velocities = np.zeros_like(positions)
fig, ax = plt.subplots(figsize=(10, 6))
# 绘制相空间轨迹
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
cmap='viridis', alpha=0.7, s=20)
# 添加颜色条表示时间
plt.colorbar(scatter, ax=ax, label='时间步')
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def animate_trajectories(self, solver: ThreeBodySolver,
save_path: Optional[str] = None,
fps: int = 30,
dpi: int = 100):
"""
创建轨迹动画(需要额外安装 matplotlib.animation
参数:
solver: 三体问题求解器
save_path: 保存路径如果为None则显示动画
fps: 帧率
dpi: 分辨率
"""
try:
from matplotlib.animation import FuncAnimation
except ImportError:
print("错误:需要安装 matplotlib.animation 模块")
return
trajectories = solver.get_trajectories()
if any(len(traj) < 2 for traj in trajectories):
print("错误:轨迹数据不足,无法创建动画")
return
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# 初始化图形元素
lines = []
points = []
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 创建轨迹线
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
lines.append(line)
# 创建当前位置点
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
markeredgecolor='black', markeredgewidth=1)
points.append(point)
# 设置坐标轴范围
all_points = np.vstack(trajectories)
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
ax.set_xlabel('X (AU)', fontsize=12)
ax.set_ylabel('Y (AU)', fontsize=12)
ax.set_zlabel('Z (AU)', fontsize=12)
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
# 动画更新函数
def update(frame):
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
# 更新轨迹线(显示到当前帧)
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
line.set_3d_properties(traj[:frame+1, 2])
# 更新当前位置点
if frame < len(traj):
point.set_data([traj[frame, 0]], [traj[frame, 1]])
point.set_3d_properties([traj[frame, 2]])
return lines + points
# 创建动画
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
interval=1000/fps, blit=True)
if save_path:
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
print(f"动画已保存到: {save_path}")
else:
plt.show()
return anim
def show(self):
"""显示所有图形"""
plt.tight_layout()
plt.show()
def save_figure(self, filename: str, dpi: int = 300):
"""保存当前图形"""
if self.fig is not None:
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
print(f"图形已保存到: {filename}")
else:
print("警告:没有活动的图形可以保存")