""" 三体问题可视化模块 """ import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from typing import List, Optional, Tuple from .particle import Particle from .solver import ThreeBodySolver class ThreeBodyVisualizer: """三体问题可视化类""" def __init__(self, figsize: Tuple[int, int] = (12, 10)): """ 初始化可视化器 参数: figsize: 图形大小 """ self.figsize = figsize self.fig = None self.ax = None def create_3d_plot(self): """创建3D图形""" self.fig = plt.figure(figsize=self.figsize) self.ax = self.fig.add_subplot(111, projection='3d') def plot_trajectories(self, solver: ThreeBodySolver, show_current_positions: bool = True, show_com: bool = True, title: str = "三体问题轨迹"): """ 绘制三体运动轨迹 参数: solver: 三体问题求解器 show_current_positions: 是否显示当前位置 show_com: 是否显示质心 title: 图形标题 """ if self.ax is None: self.create_3d_plot() trajectories = solver.get_trajectories() # 绘制每个质点的轨迹 colors = ['red', 'green', 'blue'] for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)): color = particle.color if particle.color else colors[i % len(colors)] label = particle.name if particle.name else f"质点 {i+1}" # 绘制轨迹线 self.ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=color, alpha=0.7, linewidth=1.5, label=label) # 绘制当前位置 if show_current_positions and len(traj) > 0: current_pos = traj[-1] self.ax.scatter(current_pos[0], current_pos[1], current_pos[2], color=color, s=100, edgecolors='black', linewidth=1.5) # 绘制质心 if show_com: com = solver.get_center_of_mass() self.ax.scatter(com[0], com[1], com[2], color='black', marker='x', s=200, label='质心', linewidth=2) # 设置图形属性 self.ax.set_xlabel('X (AU)', fontsize=12) self.ax.set_ylabel('Y (AU)', fontsize=12) self.ax.set_zlabel('Z (AU)', fontsize=12) self.ax.set_title(title, fontsize=14, fontweight='bold') self.ax.legend(fontsize=10) self.ax.grid(True, alpha=0.3) # 设置等比例坐标轴 self.ax.set_aspect('auto') def plot_2d_projection(self, solver: ThreeBodySolver, projection: str = 'xy', show_current_positions: bool = True, show_com: bool = True, title: str = "三体问题轨迹 (2D投影)"): """ 绘制2D投影 参数: solver: 三体问题求解器 projection: 投影平面 ('xy', 'xz', 'yz') show_current_positions: 是否显示当前位置 show_com: 是否显示质心 title: 图形标题 """ if projection not in ['xy', 'xz', 'yz']: raise ValueError("projection 必须是 'xy', 'xz' 或 'yz'") fig, ax = plt.subplots(figsize=(10, 8)) trajectories = solver.get_trajectories() # 确定坐标轴索引 if projection == 'xy': x_idx, y_idx = 0, 1 x_label, y_label = 'X (AU)', 'Y (AU)' elif projection == 'xz': x_idx, y_idx = 0, 2 x_label, y_label = 'X (AU)', 'Z (AU)' else: # 'yz' x_idx, y_idx = 1, 2 x_label, y_label = 'Y (AU)', 'Z (AU)' # 绘制每个质点的轨迹 colors = ['red', 'green', 'blue'] for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)): color = particle.color if particle.color else colors[i % len(colors)] label = particle.name if particle.name else f"质点 {i+1}" # 绘制轨迹线 ax.plot(traj[:, x_idx], traj[:, y_idx], color=color, alpha=0.7, linewidth=1.5, label=label) # 绘制当前位置 if show_current_positions and len(traj) > 0: current_pos = traj[-1] ax.scatter(current_pos[x_idx], current_pos[y_idx], color=color, s=100, edgecolors='black', linewidth=1.5) # 绘制质心 if show_com: com = solver.get_center_of_mass() ax.scatter(com[x_idx], com[y_idx], color='black', marker='x', s=200, label='质心', linewidth=2) # 设置图形属性 ax.set_xlabel(x_label, fontsize=12) ax.set_ylabel(y_label, fontsize=12) ax.set_title(title, fontsize=14, fontweight='bold') ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.set_aspect('equal', adjustable='box') return fig, ax def plot_energy_conservation(self, solver: ThreeBodySolver, time_points: Optional[np.ndarray] = None): """ 绘制能量守恒图 参数: solver: 三体问题求解器 time_points: 时间点数组(如果为None,则使用模拟时间) """ # 从历史记录中提取能量信息 # 注意:这里需要修改求解器以记录能量历史 # 暂时返回一个简单的图形 fig, ax = plt.subplots(figsize=(10, 6)) # 这里可以添加能量历史记录功能 ax.set_xlabel('时间 (年)', fontsize=12) ax.set_ylabel('总能量', fontsize=12) ax.set_title('能量守恒', fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3) return fig, ax def plot_phase_space(self, solver: ThreeBodySolver, particle_index: int = 0, dimension: str = 'x'): """ 绘制相空间图(位置 vs 速度) 参数: solver: 三体问题求解器 particle_index: 质点索引 (0, 1, 2) dimension: 维度 ('x', 'y', 'z') """ if particle_index not in [0, 1, 2]: raise ValueError("particle_index 必须是 0, 1, 或 2") if dimension not in ['x', 'y', 'z']: raise ValueError("dimension 必须是 'x', 'y', 或 'z'") particle = solver.particles[particle_index] trajectories = particle.get_trajectory() if len(trajectories) < 2: print("警告:轨迹数据不足,无法绘制相空间图") return None, None # 获取位置和速度历史 positions = trajectories[:, {'x': 0, 'y': 1, 'z': 2}[dimension]] # 注意:速度历史需要从求解器中获取 # 暂时使用位置差分近似速度 if len(positions) > 1: dt = solver.dt velocities = np.gradient(positions, dt) else: velocities = np.zeros_like(positions) fig, ax = plt.subplots(figsize=(10, 6)) # 绘制相空间轨迹 scatter = ax.scatter(positions, velocities, c=np.arange(len(positions)), cmap='viridis', alpha=0.7, s=20) # 添加颜色条表示时间 plt.colorbar(scatter, ax=ax, label='时间步') ax.set_xlabel(f'{dimension.upper()} 位置 (AU)', fontsize=12) ax.set_ylabel(f'{dimension.upper()} 速度 (AU/年)', fontsize=12) ax.set_title(f'质点 {particle_index+1} 的相空间图 ({dimension.upper()}维度)', fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3) return fig, ax def animate_trajectories(self, solver: ThreeBodySolver, save_path: Optional[str] = None, fps: int = 30, dpi: int = 100): """ 创建轨迹动画(需要额外安装 matplotlib.animation) 参数: solver: 三体问题求解器 save_path: 保存路径(如果为None则显示动画) fps: 帧率 dpi: 分辨率 """ try: from matplotlib.animation import FuncAnimation except ImportError: print("错误:需要安装 matplotlib.animation 模块") return trajectories = solver.get_trajectories() if any(len(traj) < 2 for traj in trajectories): print("错误:轨迹数据不足,无法创建动画") return fig = plt.figure(figsize=(10, 8)) ax = fig.add_subplot(111, projection='3d') # 初始化图形元素 lines = [] points = [] colors = ['red', 'green', 'blue'] for i, (traj, particle) in enumerate(zip(trajectories, solver.particles)): color = particle.color if particle.color else colors[i % len(colors)] label = particle.name if particle.name else f"质点 {i+1}" # 创建轨迹线 line, = ax.plot([], [], [], color=color, alpha=0.7, linewidth=1.5, label=label) lines.append(line) # 创建当前位置点 point, = ax.plot([], [], [], 'o', color=color, markersize=8, markeredgecolor='black', markeredgewidth=1) points.append(point) # 设置坐标轴范围 all_points = np.vstack(trajectories) max_range = np.max([all_points[:, i].max() - all_points[:, i].min() for i in range(3)]) mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5 mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5 mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5 ax.set_xlim(mid_x - max_range/2, mid_x + max_range/2) ax.set_ylim(mid_y - max_range/2, mid_y + max_range/2) ax.set_zlim(mid_z - max_range/2, mid_z + max_range/2) ax.set_xlabel('X (AU)', fontsize=12) ax.set_ylabel('Y (AU)', fontsize=12) ax.set_zlabel('Z (AU)', fontsize=12) ax.set_title('三体问题动画', fontsize=14, fontweight='bold') ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # 动画更新函数 def update(frame): for i, (traj, line, point) in enumerate(zip(trajectories, lines, points)): # 更新轨迹线(显示到当前帧) line.set_data(traj[:frame+1, 0], traj[:frame+1, 1]) line.set_3d_properties(traj[:frame+1, 2]) # 更新当前位置点 if frame < len(traj): point.set_data([traj[frame, 0]], [traj[frame, 1]]) point.set_3d_properties([traj[frame, 2]]) return lines + points # 创建动画 anim = FuncAnimation(fig, update, frames=min(len(traj) for traj in trajectories), interval=1000/fps, blit=True) if save_path: anim.save(save_path, writer='ffmpeg', fps=fps, dpi=dpi) print(f"动画已保存到: {save_path}") else: plt.show() return anim def show(self): """显示所有图形""" plt.tight_layout() plt.show() def save_figure(self, filename: str, dpi: int = 300): """保存当前图形""" if self.fig is not None: self.fig.savefig(filename, dpi=dpi, bbox_inches='tight') print(f"图形已保存到: {filename}") else: print("警告:没有活动的图形可以保存")