Files
dison0331-ThinkPad 8c8ad9fe07 first
pc-1
2026-03-11 21:32:58 +08:00

321 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
三体问题可视化模块
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import List, Optional, Tuple
from .particle import Particle
from .solver import ThreeBodySolver
class ThreeBodyVisualizer:
"""三体问题可视化类"""
def __init__(self, figsize: Tuple[int, int] = (12, 10)):
"""
初始化可视化器
参数:
figsize: 图形大小
"""
self.figsize = figsize
self.fig = None
self.ax = None
def create_3d_plot(self):
"""创建3D图形"""
self.fig = plt.figure(figsize=self.figsize)
self.ax = self.fig.add_subplot(111, projection='3d')
def plot_trajectories(self, solver: ThreeBodySolver,
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹"):
"""
绘制三体运动轨迹
参数:
solver: 三体问题求解器
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if self.ax is None:
self.create_3d_plot()
trajectories = solver.get_trajectories()
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
self.ax.scatter(current_pos[0], current_pos[1], current_pos[2],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
self.ax.scatter(com[0], com[1], com[2],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
self.ax.set_xlabel('X (AU)', fontsize=12)
self.ax.set_ylabel('Y (AU)', fontsize=12)
self.ax.set_zlabel('Z (AU)', fontsize=12)
self.ax.set_title(title, fontsize=14, fontweight='bold')
self.ax.legend(fontsize=10)
self.ax.grid(True, alpha=0.3)
# 设置等比例坐标轴
self.ax.set_aspect('auto')
def plot_2d_projection(self, solver: ThreeBodySolver,
projection: str = 'xy',
show_current_positions: bool = True,
show_com: bool = True,
title: str = "三体问题轨迹 (2D投影)"):
"""
绘制2D投影
参数:
solver: 三体问题求解器
projection: 投影平面 ('xy', 'xz', 'yz')
show_current_positions: 是否显示当前位置
show_com: 是否显示质心
title: 图形标题
"""
if projection not in ['xy', 'xz', 'yz']:
raise ValueError("projection 必须是 'xy', 'xz''yz'")
fig, ax = plt.subplots(figsize=(10, 8))
trajectories = solver.get_trajectories()
# 确定坐标轴索引
if projection == 'xy':
x_idx, y_idx = 0, 1
x_label, y_label = 'X (AU)', 'Y (AU)'
elif projection == 'xz':
x_idx, y_idx = 0, 2
x_label, y_label = 'X (AU)', 'Z (AU)'
else: # 'yz'
x_idx, y_idx = 1, 2
x_label, y_label = 'Y (AU)', 'Z (AU)'
# 绘制每个质点的轨迹
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 绘制轨迹线
ax.plot(traj[:, x_idx], traj[:, y_idx],
color=color, alpha=0.7, linewidth=1.5, label=label)
# 绘制当前位置
if show_current_positions and len(traj) > 0:
current_pos = traj[-1]
ax.scatter(current_pos[x_idx], current_pos[y_idx],
color=color, s=100, edgecolors='black', linewidth=1.5)
# 绘制质心
if show_com:
com = solver.get_center_of_mass()
ax.scatter(com[x_idx], com[y_idx],
color='black', marker='x', s=200, label='质心', linewidth=2)
# 设置图形属性
ax.set_xlabel(x_label, fontsize=12)
ax.set_ylabel(y_label, fontsize=12)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')
return fig, ax
def plot_energy_conservation(self, solver: ThreeBodySolver,
time_points: Optional[np.ndarray] = None):
"""
绘制能量守恒图
参数:
solver: 三体问题求解器
time_points: 时间点数组如果为None则使用模拟时间
"""
# 从历史记录中提取能量信息
# 注意:这里需要修改求解器以记录能量历史
# 暂时返回一个简单的图形
fig, ax = plt.subplots(figsize=(10, 6))
# 这里可以添加能量历史记录功能
ax.set_xlabel('时间 (年)', fontsize=12)
ax.set_ylabel('总能量', fontsize=12)
ax.set_title('能量守恒', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def plot_phase_space(self, solver: ThreeBodySolver,
particle_index: int = 0,
dimension: str = 'x'):
"""
绘制相空间图(位置 vs 速度)
参数:
solver: 三体问题求解器
particle_index: 质点索引 (0, 1, 2)
dimension: 维度 ('x', 'y', 'z')
"""
if particle_index not in [0, 1, 2]:
raise ValueError("particle_index 必须是 0, 1, 或 2")
if dimension not in ['x', 'y', 'z']:
raise ValueError("dimension 必须是 'x', 'y', 或 'z'")
particle = solver.particles[particle_index]
trajectories = particle.get_trajectory()
if len(trajectories) < 2:
print("警告:轨迹数据不足,无法绘制相空间图")
return None, None
# 获取位置和速度历史
positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]]
# 注意:速度历史需要从求解器中获取
# 暂时使用位置差分近似速度
if len(positions) > 1:
dt = solver.dt
velocities = np.gradient(positions, dt)
else:
velocities = np.zeros_like(positions)
fig, ax = plt.subplots(figsize=(10, 6))
# 绘制相空间轨迹
scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)),
cmap='viridis', alpha=0.7, s=20)
# 添加颜色条表示时间
plt.colorbar(scatter, ax=ax, label='时间步')
ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12)
ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12)
ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)',
fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
return fig, ax
def animate_trajectories(self, solver: ThreeBodySolver,
save_path: Optional[str] = None,
fps: int = 30,
dpi: int = 100):
"""
创建轨迹动画(需要额外安装 matplotlib.animation
参数:
solver: 三体问题求解器
save_path: 保存路径如果为None则显示动画
fps: 帧率
dpi: 分辨率
"""
try:
from matplotlib.animation import FuncAnimation
except ImportError:
print("错误:需要安装 matplotlib.animation 模块")
return
trajectories = solver.get_trajectories()
if any(len(traj) < 2 for traj in trajectories):
print("错误:轨迹数据不足,无法创建动画")
return
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# 初始化图形元素
lines = []
points = []
colors = ['red', 'green', 'blue']
for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)):
color = particle.color if particle.color else colors[i % len(colors)]
label = particle.name if particle.name else f"质点 {i+1}"
# 创建轨迹线
line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label)
lines.append(line)
# 创建当前位置点
point, = ax.plot([], [], [], 'o', color=color, markersize=8,
markeredgecolor='black', markeredgewidth=1)
points.append(point)
# 设置坐标轴范围
all_points = np.vstack(trajectories)
max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)])
mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5
mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5
mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5
ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2)
ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2)
ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2)
ax.set_xlabel('X (AU)', fontsize=12)
ax.set_ylabel('Y (AU)', fontsize=12)
ax.set_zlabel('Z (AU)', fontsize=12)
ax.set_title('三体问题动画', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
# 动画更新函数
def update(frame):
for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)):
# 更新轨迹线(显示到当前帧)
line.set_data(traj[:frame+1, 0], traj[:frame+1, 1])
line.set_3d_properties(traj[:frame+1, 2])
# 更新当前位置点
if frame < len(traj):
point.set_data([traj[frame, 0]], [traj[frame, 1]])
point.set_3d_properties([traj[frame, 2]])
return lines + points
# 创建动画
anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories),
interval=1000/fps, blit=True)
if save_path:
anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi)
print(f"动画已保存到: {save_path}")
else:
plt.show()
return anim
def show(self):
"""显示所有图形"""
plt.tight_layout()
plt.show()
def save_figure(self, filename: str, dpi: int = 300):
"""保存当前图形"""
if self.fig is not None:
self.fig.savefig(filename, dpi=dpi, bbox_inches='tight')
print(f"图形已保存到: {filename}")
else:
print("警告:没有活动的图形可以保存")