first
pc-1
This commit is contained in:
321
three_body_problem/visualizer.py
Normal file
321
three_body_problem/visualizer.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
三体问题可视化模块
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
from typing import List, Optional, Tuple
|
||||
from .particle import Particle
|
||||
from .solver import ThreeBodySolver
|
||||
|
||||
|
||||
class ThreeBodyVisualizer:
|
||||
"""三体问题可视化类"""
|
||||
|
||||
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
|
||||
"""
|
||||
初始化可视化器
|
||||
|
||||
参数:
|
||||
figsize: 图形大小
|
||||
"""
|
||||
self.figsize = figsize
|
||||
self.fig = None
|
||||
self.ax = None
|
||||
|
||||
def create_3d_plot(self):
|
||||
"""创建3D图形"""
|
||||
self.fig = plt.figure(figsize=self.figsize)
|
||||
self.ax = self.fig.add_subplot(111, projection='3d')
|
||||
|
||||
def plot_trajectories(self, solver: ThreeBodySolver,
|
||||
show_current_positions: bool = True,
|
||||
show_com: bool = True,
|
||||
title: str = "三体问题轨迹"):
|
||||
"""
|
||||
绘制三体运动轨迹
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
show_current_positions: 是否显示当前位置
|
||||
show_com: 是否显示质心
|
||||
title: 图形标题
|
||||
"""
|
||||
if self.ax is None:
|
||||
self.create_3d_plot()
|
||||
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 绘制每个质点的轨迹
|
||||
colors = ['red', 'green', 'blue']
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 绘制轨迹线
|
||||
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制当前位置
|
||||
if show_current_positions and len(traj) > 0:
|
||||
current_pos = traj[-1]
|
||||
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心
|
||||
if show_com:
|
||||
com = solver.get_center_of_mass()
|
||||
self.ax.scatter(com[0], com[1], com[2],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
# 设置图形属性
|
||||
self.ax.set_xlabel('X (AU)', fontsize=12)
|
||||
self.ax.set_ylabel('Y (AU)', fontsize=12)
|
||||
self.ax.set_zlabel('Z (AU)', fontsize=12)
|
||||
self.ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
self.ax.legend(fontsize=10)
|
||||
self.ax.grid(True, alpha=0.3)
|
||||
|
||||
# 设置等比例坐标轴
|
||||
self.ax.set_aspect('auto')
|
||||
|
||||
def plot_2d_projection(self, solver: ThreeBodySolver,
|
||||
projection: str = 'xy',
|
||||
show_current_positions: bool = True,
|
||||
show_com: bool = True,
|
||||
title: str = "三体问题轨迹 (2D投影)"):
|
||||
"""
|
||||
绘制2D投影
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
projection: 投影平面 ('xy', 'xz', 'yz')
|
||||
show_current_positions: 是否显示当前位置
|
||||
show_com: 是否显示质心
|
||||
title: 图形标题
|
||||
"""
|
||||
if projection not in ['xy', 'xz', 'yz']:
|
||||
raise ValueError("projection 必须是 'xy', 'xz' 或 'yz'")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
trajectories = solver.get_trajectories()
|
||||
|
||||
# 确定坐标轴索引
|
||||
if projection == 'xy':
|
||||
x_idx, y_idx = 0, 1
|
||||
x_label, y_label = 'X (AU)', 'Y (AU)'
|
||||
elif projection == 'xz':
|
||||
x_idx, y_idx = 0, 2
|
||||
x_label, y_label = 'X (AU)', 'Z (AU)'
|
||||
else: # 'yz'
|
||||
x_idx, y_idx = 1, 2
|
||||
x_label, y_label = 'Y (AU)', 'Z (AU)'
|
||||
|
||||
# 绘制每个质点的轨迹
|
||||
colors = ['red', 'green', 'blue']
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 绘制轨迹线
|
||||
ax.plot(traj[:, x_idx], traj[:, y_idx],
|
||||
color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
|
||||
# 绘制当前位置
|
||||
if show_current_positions and len(traj) > 0:
|
||||
current_pos = traj[-1]
|
||||
ax.scatter(current_pos[x_idx], current_pos[y_idx],
|
||||
color=color, s=100, edgecolors='black', linewidth=1.5)
|
||||
|
||||
# 绘制质心
|
||||
if show_com:
|
||||
com = solver.get_center_of_mass()
|
||||
ax.scatter(com[x_idx], com[y_idx],
|
||||
color='black', marker='x', s=200, label='质心', linewidth=2)
|
||||
|
||||
# 设置图形属性
|
||||
ax.set_xlabel(x_label, fontsize=12)
|
||||
ax.set_ylabel(y_label, fontsize=12)
|
||||
ax.set_title(title, fontsize=14, fontweight='bold')
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_aspect('equal', adjustable='box')
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_energy_conservation(self, solver: ThreeBodySolver,
|
||||
time_points: Optional[np.ndarray] = None):
|
||||
"""
|
||||
绘制能量守恒图
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
time_points: 时间点数组(如果为None,则使用模拟时间)
|
||||
"""
|
||||
# 从历史记录中提取能量信息
|
||||
# 注意:这里需要修改求解器以记录能量历史
|
||||
# 暂时返回一个简单的图形
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# 这里可以添加能量历史记录功能
|
||||
ax.set_xlabel('时间 (年)', fontsize=12)
|
||||
ax.set_ylabel('总能量', fontsize=12)
|
||||
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
return fig, ax
|
||||
|
||||
def plot_phase_space(self, solver: ThreeBodySolver,
|
||||
particle_index: int = 0,
|
||||
dimension: str = 'x'):
|
||||
"""
|
||||
绘制相空间图(位置 vs 速度)
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
particle_index: 质点索引 (0, 1, 2)
|
||||
dimension: 维度 ('x', 'y', 'z')
|
||||
"""
|
||||
if particle_index not in [0, 1, 2]:
|
||||
raise ValueError("particle_index 必须是 0, 1, 或 2")
|
||||
|
||||
if dimension not in ['x', 'y', 'z']:
|
||||
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
|
||||
|
||||
particle = solver.particles[particle_index]
|
||||
trajectories = particle.get_trajectory()
|
||||
|
||||
if len(trajectories) < 2:
|
||||
print("警告:轨迹数据不足,无法绘制相空间图")
|
||||
return None, None
|
||||
|
||||
# 获取位置和速度历史
|
||||
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
|
||||
|
||||
# 注意:速度历史需要从求解器中获取
|
||||
# 暂时使用位置差分近似速度
|
||||
if len(positions) > 1:
|
||||
dt = solver.dt
|
||||
velocities = np.gradient(positions, dt)
|
||||
else:
|
||||
velocities = np.zeros_like(positions)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# 绘制相空间轨迹
|
||||
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
|
||||
cmap='viridis', alpha=0.7, s=20)
|
||||
|
||||
# 添加颜色条表示时间
|
||||
plt.colorbar(scatter, ax=ax, label='时间步')
|
||||
|
||||
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
|
||||
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
|
||||
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
|
||||
fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
return fig, ax
|
||||
|
||||
def animate_trajectories(self, solver: ThreeBodySolver,
|
||||
save_path: Optional[str] = None,
|
||||
fps: int = 30,
|
||||
dpi: int = 100):
|
||||
"""
|
||||
创建轨迹动画(需要额外安装 matplotlib.animation)
|
||||
|
||||
参数:
|
||||
solver: 三体问题求解器
|
||||
save_path: 保存路径(如果为None则显示动画)
|
||||
fps: 帧率
|
||||
dpi: 分辨率
|
||||
"""
|
||||
try:
|
||||
from matplotlib.animation import FuncAnimation
|
||||
except ImportError:
|
||||
print("错误:需要安装 matplotlib.animation 模块")
|
||||
return
|
||||
|
||||
trajectories = solver.get_trajectories()
|
||||
if any(len(traj) < 2 for traj in trajectories):
|
||||
print("错误:轨迹数据不足,无法创建动画")
|
||||
return
|
||||
|
||||
fig = plt.figure(figsize=(10, 8))
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
# 初始化图形元素
|
||||
lines = []
|
||||
points = []
|
||||
colors = ['red', 'green', 'blue']
|
||||
|
||||
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
|
||||
color = particle.color if particle.color else colors[i % len(colors)]
|
||||
label = particle.name if particle.name else f"质点 {i+1}"
|
||||
|
||||
# 创建轨迹线
|
||||
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
|
||||
lines.append(line)
|
||||
|
||||
# 创建当前位置点
|
||||
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
|
||||
markeredgecolor='black', markeredgewidth=1)
|
||||
points.append(point)
|
||||
|
||||
# 设置坐标轴范围
|
||||
all_points = np.vstack(trajectories)
|
||||
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
|
||||
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
|
||||
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
|
||||
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
|
||||
|
||||
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
|
||||
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
|
||||
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
|
||||
|
||||
ax.set_xlabel('X (AU)', fontsize=12)
|
||||
ax.set_ylabel('Y (AU)', fontsize=12)
|
||||
ax.set_zlabel('Z (AU)', fontsize=12)
|
||||
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
|
||||
ax.legend(fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# 动画更新函数
|
||||
def update(frame):
|
||||
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
|
||||
# 更新轨迹线(显示到当前帧)
|
||||
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
|
||||
line.set_3d_properties(traj[:frame+1, 2])
|
||||
|
||||
# 更新当前位置点
|
||||
if frame < len(traj):
|
||||
point.set_data([traj[frame, 0]], [traj[frame, 1]])
|
||||
point.set_3d_properties([traj[frame, 2]])
|
||||
|
||||
return lines + points
|
||||
|
||||
# 创建动画
|
||||
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
|
||||
interval=1000/fps, blit=True)
|
||||
|
||||
if save_path:
|
||||
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
|
||||
print(f"动画已保存到: {save_path}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
return anim
|
||||
|
||||
def show(self):
|
||||
"""显示所有图形"""
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
def save_figure(self, filename: str, dpi: int = 300):
|
||||
"""保存当前图形"""
|
||||
if self.fig is not None:
|
||||
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
|
||||
print(f"图形已保存到: {filename}")
|
||||
else:
|
||||
print("警告:没有活动的图形可以保存")
|
||||
Reference in New Issue
Block a user