Source code for diffinytrace.plotting.system2D

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = [
    "annotate_position_simple",
    "annotate_position",
    "annotated_arrow",
    "layout",
    "ray_paths",
    "_plot_surface",
    "_plot_surface_recursively",
    "plot"
]

import torch
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from copy import deepcopy

[docs] def annotate_position_simple(nz,ny,name): """ Annotate the position of a point in 2D space using its coordinates. Args: nz (torch.Tensor): z-coordinates of the point. ny (torch.Tensor): y-coordinates of the point. name (str): Text label to annotate at the position. Returns: None """ zdiff = (torch.max(nz)-torch.min(nz)) ydiff = (torch.max(ny)-torch.min(ny)) offset = max(ydiff*0.05,zdiff*0.025) argmax = torch.argmax(ny) zpos = nz[argmax]#torch.min(nz)+zdiff*0.5 ypos = ny[argmax] fontsize = 10 #-len(name)*fontsize/4.0 plt.annotate(name,xy=(zpos, ypos),fontsize=fontsize,xytext=(0.0,fontsize*0.5), textcoords='offset points')
[docs] def annotate_position(position,offset,name,color="black",**kwargs): """ Annotate a point in 2D space with an arrow and label. Args: position (tuple): (z, y) coordinates of the point. offset (tuple): Offset for the annotation text. name (str): Text label to annotate. color (str): Color of the annotation and arrow. Returns: None """ plt.annotate(name,color=color,xy=position,xytext=offset, textcoords='offset points',arrowprops=dict(arrowstyle="->",color=color,linewidth=1.5, mutation_scale=10), **kwargs)
[docs] def annotated_arrow(start,end,offset,name,arrowstyle,color="black",**kwargs): """ Draw and annotate an arrow between two points in 2D space. Args: start (tuple): Start position (z, y) of the arrow. end (tuple): End position (z, y) of the arrow. offset (tuple): Offset for the annotation text. name (str): Text label to annotate. arrowstyle (str): Matplotlib arrow style string. color (str): Color of the arrow and annotation. Returns: None """ arrow_patch = patches.FancyArrowPatch(start, end, arrowstyle=arrowstyle,linewidth=1.5, mutation_scale=10,color=color) plt.gca().add_patch(arrow_patch) middle = (start[0]+ (end[0]-start[0])*0.5,start[1]+ (end[1]-start[1])*0.5) plt.annotate(name,xy=middle,xytext=offset, textcoords='offset points',color=color,**kwargs)
[docs] def layout(): """ Set up the layout for the 2D plot, including margins, aspect ratio, and axis labels. Returns: None """ #plt.grid(True) plt.margins(x=0.1,y=0.1) plt.gca().set_aspect('equal') plt.ylabel("y [mm]") plt.xlabel("z [mm]")
[docs] def ray_paths(rays,ray_color="#85549c",ray_linewidth=1.25): """ Plot ray paths projected onto the y-z plane. Args: rays (list[torch.Tensor]): List of ray paths to plot. ray_color (str): Color of the rays. ray_linewidth (float): Line width of the rays. Returns: None """ ray_color = mcolors.to_hex(ray_color) print("WARNING: ray_paths will project the ray position onto the y-z plane!") pathsA = rays if torch.is_tensor(rays[0]): pathsA = np.array([elem.numpy() for elem in rays]) pathsA = np.array(pathsA) for iray in range(pathsA.shape[1]): plt.plot(pathsA[:,iray,2],pathsA[:,iray,1],color=ray_color,linewidth=ray_linewidth)
def _plot_surface(surface,name,resolution,annotate,fill_color,outline_color,linewidth): """ Plot a 2D surface and optionally annotate it. Args: surface: Object with get_plot_points_2D method. name (str): Name for annotation. resolution (int): Resolution for the surface plot. annotate (bool): Whether to annotate the surface. fill_color (str): Fill color for the surface. outline_color (str): Outline color for the surface. linewidth (float): Line width for the surface. Returns: None """ surface_list = surface.get_plot_points_2D(resolution) if len(surface_list)==0: return if fill_color is None: fill_color = surface.fill_color if outline_color is None: outline_color = surface.outline_color zs,ys = torch.cat([z for z,y in surface_list]),torch.cat([y for z,y in surface_list]) if annotate: annotate_position_simple(zs,ys,name) if surface.is_volume: ax = plt.gca() ax.fill(zs, ys, facecolor=fill_color, edgecolor=outline_color, linewidth=linewidth) else: for z,y in surface_list: plt.plot(z,y,color=outline_color,label="",linewidth=linewidth) def _plot_surface_recursively(current_elem,name,resolution=200,annotate=False,fill_color=None,outline_color=None,linewidth=None): """ Recursively plot a surface and its plotable children in 2D. Args: current_elem: The current plotable element. name (str): Name for annotation. resolution (int): Resolution for the surface plot. annotate (bool): Whether to annotate the surface. fill_color (str): Fill color for the surface. outline_color (str): Outline color for the surface. linewidth (float): Line width for the surface. Returns: None """ _plot_surface(current_elem,name,resolution,annotate,fill_color,outline_color,linewidth) for elem,elem_name in current_elem.get_plotable_childs(): _plot_surface_recursively(elem,elem_name,resolution,annotate,fill_color,outline_color,linewidth)
[docs] def plot(element=None,rays=None,resolution=200,annotate=False,ray_color="#85549c",ray_linewidth=1.25,fill_color=None,outline_color=None,linewidth=None,show=True): """ Plot a 2D surface and optionally ray paths. Args: element: The element to plot (must implement Plotable interface). rays (list[torch.Tensor]): List of ray paths to plot. resolution (int): Resolution for the surface plot. annotate (bool): Whether to annotate the surface. ray_color (str): Color of the rays. ray_linewidth (float): Line width of the rays. fill_color (str): Fill color for the surface. outline_color (str): Outline color for the surface. linewidth (float): Line width for the surface. show (bool): Whether to display the plot immediately. Returns: None """ layout() if isinstance(element,(list,tuple)): for subelem in element: subelem = deepcopy(subelem) subelem = subelem.to("cpu") _plot_surface_recursively(subelem,"",resolution,annotate,fill_color,outline_color,linewidth) elif not element is None: element = deepcopy(element) element = element.to("cpu") _plot_surface_recursively(element,"",resolution,annotate,fill_color,outline_color,linewidth) if not rays is None: if isinstance(rays,dict): rays = rays["ray_paths"] if torch.is_tensor(rays[0]): rays = [elem.cpu() for elem in rays] ray_paths(rays,ray_color=ray_color,ray_linewidth=ray_linewidth) if show: plt.show()