# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.
# Names exported by ``from <module> import *``. Besides the public plotting
# API this list also exports the two underscore-prefixed surface helpers.
__all__ = [
"ray_paths_one_bin",
"ray_paths",
"surface",
"get_optical_system_layout",
"_plot_surface",
"_plot_surface_recursively",
"plot"
]
import pandas as pd
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
#import matplotlib.pyplot as plt
#from PIL import Image
#import tempfile
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import matplotlib.colors as mcolors
import plotly.io as pio
import copy
##2013FF
# [docs]  -- Sphinx "view source" link artifact from HTML extraction, not code
def ray_paths_one_bin(rays, ray_color, ray_linewidth):
    """
    Generate a Plotly 3D line plot for a group of equally sized ray tensors.

    The tensors are stacked into one array of shape ``(len(rays), n, >=3)``.
    Rows are grouped into lines by their index along the second axis, so line
    ``k`` connects ``rays[0][k]``, ``rays[1][k]``, ... in list order.

    Args:
        rays (list[torch.Tensor]): Tensors of identical shape whose last axis
            holds at least the x, y and z components. Tensors that require
            grad or live on another device are detached and copied to the CPU.
        ray_color (str): Color applied to every line.
        ray_linewidth (float): Line width applied to every line.

    Returns:
        plotly.graph_objs.Figure: Plotly figure containing the ray lines.
    """
    # ``Tensor.numpy()`` refuses tensors that require grad or are not on the
    # CPU, so detach and move each one before converting.
    points = np.array([elem.detach().cpu().numpy() for elem in rays])
    n_entries, n_lines = points.shape[0], points.shape[1]

    # Flattening is row-major: for every list entry the ids 0..n_lines-1
    # repeat, which tags each point with its index along the second axis.
    line_ids = torch.arange(n_lines).repeat(n_entries).numpy()

    df = pd.DataFrame({
        "X": points[:, :, 0].reshape(-1),
        "Y": points[:, :, 1].reshape(-1),
        "Z": points[:, :, 2].reshape(-1),
        "ray id": line_ids,
    })
    line_fig = px.line_3d(df, x='X', y='Y', z='Z', line_group="ray id")
    for trace in line_fig.data:
        trace.line.color = ray_color
        trace.line.width = ray_linewidth
    return line_fig
# [docs]  -- Sphinx "view source" link artifact from HTML extraction, not code
def ray_paths(rays, ray_color="#9673A6", ray_linewidth=3):
    """
    Build the Plotly line traces for a collection of ray tensors.

    The tensors are bucketed by ``len(tensor)``; every bucket is rendered
    with ``ray_paths_one_bin`` and the resulting traces are concatenated.

    Args:
        rays (list[torch.Tensor] or None): Tensors to plot. ``None`` yields
            an empty list.
        ray_color (str): Color of the rays (converted to a hex string via
            ``matplotlib.colors.to_hex``).
        ray_linewidth (float): Line width of the rays.

    Returns:
        list: Plotly trace objects of all buckets, in bucket insertion order.
    """
    # The color is normalized first, so an invalid color is reported even
    # when there is nothing to draw.
    hex_color = mcolors.to_hex(ray_color)
    if rays is None:
        return []

    # Bucket the tensors by their leading dimension.
    buckets = {}
    for tensor in rays:
        buckets.setdefault(len(tensor), []).append(tensor)

    traces = []
    for group in buckets.values():
        bucket_fig = ray_paths_one_bin(group, hex_color, ray_linewidth)
        traces.extend(bucket_fig.data)
    return traces
# [docs]  -- Sphinx "view source" link artifact from HTML extraction, not code
def surface(transformation, surface, name, aperture_radius, resolution, colorscale, is_square=False):
    """
    Generate a Plotly surface plot for an optical element.

    A ``resolution x resolution`` grid spanning
    ``[-aperture_radius, aperture_radius]`` in x and y is sampled, the sag is
    obtained from ``surface.explicit`` and the resulting points are mapped
    through the homogeneous matrix returned by
    ``transformation.get_transformation_matrix()``.

    Args:
        transformation: Object providing ``get_transformation_matrix()``; the
            matrix is applied to homogeneous points ``(x, y, z, 1)``.
        surface: Object providing ``explicit(points)``, called with an
            ``(N, 3)`` tensor whose z column is zero.
        name (str): Name of the surface trace.
        aperture_radius (float): Half-width of the sampled grid; also the
            radius of the circular aperture when ``is_square`` is False.
        resolution (int): Number of samples along each grid axis.
        colorscale (list): Color scale for the surface plot.
        is_square (bool): If False, grid points outside the circle of radius
            ``aperture_radius`` are pulled radially onto its rim.

    Returns:
        list: List containing one ``plotly.graph_objects.Surface``.
    """
    axis = torch.linspace(-aperture_radius, aperture_radius, resolution)
    # "ij" is what torch.meshgrid did implicitly; stating it avoids the
    # deprecation warning and keeps the layout stable.
    grid_x, grid_y = torch.meshgrid(axis, axis, indexing="ij")
    x = grid_x.reshape(-1)
    y = grid_y.reshape(-1)

    if not is_square:
        r = torch.sqrt(x * x + y * y)
        # Points outside the aperture are scaled onto the rim, all others are
        # left untouched. The single where() avoids the 0/0 NaN at the grid
        # centre and keeps points with r == aperture_radius in place instead
        # of collapsing them to the origin.
        scale = torch.where(r > aperture_radius, aperture_radius / r, torch.ones_like(r))
        x = x * scale
        y = y * scale

    local_points = torch.zeros((x.shape[0], 3))
    local_points[:, 0] = x
    local_points[:, 1] = y

    with torch.no_grad():
        z = surface.explicit(local_points)
        M = transformation.get_transformation_matrix().detach()

    # Homogeneous coordinates (x, y, z, 1) so the matrix can translate too.
    v = torch.zeros((x.shape[0], 4))
    v[:, 0] = x.detach().reshape(-1)
    v[:, 1] = y.detach().reshape(-1)
    v[:, 2] = z.detach().reshape(-1)
    v[:, 3] = 1.0

    with torch.no_grad():
        Mv = v @ M.T

    n = axis.shape[0]
    x = Mv[:, 0].reshape(n, n)
    y = Mv[:, 1].reshape(n, n)
    z = Mv[:, 2].reshape(n, n)
    return [go.Surface(x=x, y=y, z=z, showscale=False, name=name, colorscale=colorscale)]
# [docs]  -- Sphinx "view source" link artifact from HTML extraction, not code
def get_optical_system_layout(show_grid,xlabel="x [mm]",ylabel="y [mm]",zlabel="z [mm]",xticks=None,yticks=None,zticks=None,axislabel_font_size=10,tick_font_size=10):
    """
    Create the Plotly layout used for the 3D view of the optical system.

    All three axes share the same visibility and font sizes; only the label
    and the optional tick positions differ. The scene uses
    ``aspectmode='data'`` and the camera's up vector is the x axis.

    Args:
        show_grid (bool): Whether the axes (and their grid) are visible.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        xticks (list[float], optional): Custom x tick positions.
        yticks (list[float], optional): Custom y tick positions.
        zticks (list[float], optional): Custom z tick positions.
        axislabel_font_size (int): Font size for axis labels.
        tick_font_size (int): Font size for tick labels.

    Returns:
        plotly.graph_objs.Layout: Layout object for the plot.
    """
    #TODO write wrapper for plot3D!
    per_axis = (
        ("xaxis", xlabel, xticks),
        ("yaxis", ylabel, yticks),
        ("zaxis", zlabel, zticks),
    )

    scene = {}
    for axis_key, label, ticks in per_axis:
        axis_cfg = {
            "visible": show_grid,
            "title": {"text": label, "font": {"size": axislabel_font_size}},
            "tickfont": {"size": tick_font_size},
        }
        # Tick positions are only set when the caller supplied them.
        if ticks is not None:
            axis_cfg["tickvals"] = ticks
        scene[axis_key] = axis_cfg

    scene["aspectmode"] = 'data'
    scene["aspectratio"] = {"x": 1, "y": 1, "z": 1}

    camera = {"up": {"x": 1., "y": 0., "z": 0}}
    return go.Layout(scene_camera=camera, scene=scene)
def _plot_surface(surface, name, resolution):
    """
    Generate Plotly surface objects for all 3D surface segments of an element.

    Segment ``k`` is drawn with ``colorscale[k]``. If that fails (missing
    entry or a color scale Plotly rejects) a message is printed and the
    segment is drawn with ``colorscale[0]`` instead.

    Args:
        surface: Object with ``get_plot_points_3D(resolution)`` (returning a
            sequence of ``(x, y, z)`` triples) and ``get_plotly_color_scale()``.
        name (str): Base name; segment ``k`` is named ``f"{name}_{k}"``.
        resolution (int): Resolution passed to ``get_plot_points_3D``.

    Returns:
        list: One ``plotly.graph_objects.Surface`` per segment; empty if the
        element has no segments.
    """
    surface_list = surface.get_plot_points_3D(resolution)
    if len(surface_list) == 0:
        return []
    colorscale = surface.get_plotly_color_scale()
    data = []
    for k, (x, y, z) in enumerate(surface_list):
        segment_name = name + f"_{k}"
        try:
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=segment_name, colorscale=colorscale[k])
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit must not be swallowed by a plotting helper.
            print("Wrong number of colorscales or colorscales is not correct, fallback to first colorscale!")
            trace = go.Surface(x=x, y=y, z=z, showscale=False, name=segment_name, colorscale=colorscale[0])
        data.append(trace)
    return data
def _plot_surface_recursively(current_elem, name, resolution):
    """
    Collect the Plotly surfaces of an element and of all its plotable children.

    The element's own surfaces come first, followed depth-first by those of
    each ``(child, child_name)`` pair from ``get_plotable_childs()``.

    Args:
        current_elem: Element providing the plotting methods used by
            ``_plot_surface`` plus ``get_plotable_childs()``.
        name (str): Name for the element's own surfaces.
        resolution (int): Resolution for the surface plot.

    Returns:
        list: Plotly surface objects for the element and its descendants.
    """
    collected = _plot_surface(current_elem, name, resolution)
    for child, child_name in current_elem.get_plotable_childs():
        child_surfaces = _plot_surface_recursively(child, child_name, resolution)
        collected.extend(child_surfaces)
    return collected
# [docs]  -- Sphinx "view source" link artifact from HTML extraction, not code
def plot(element=None,
         rays=None,
         resolution=32,
         show_grid=True,
         xlabel="x [mm]",
         ylabel="y [mm]",
         zlabel="z [mm]",
         xticks=None,
         yticks=None,
         zticks=None,
         axislabel_font_size=10,
         tick_font_size=10,
         ray_color="#9673A6",
         ray_linewidth=3.,
         show=True,
         html_file_name=None):
    """
    Visualize the optical system and ray paths in 3D using Plotly.

    Args:
        element: Element to plot, or a list/tuple of elements. Each one is
            deep-copied and moved to the CPU before plotting, so the caller's
            objects are left untouched.
        rays (list[torch.Tensor] or dict): Ray tensors, or a dict holding
            them under the key ``"ray_paths"``. Tensors are detached and
            moved to the CPU.
        resolution (int): Resolution for the surface plot.
        show_grid (bool): Whether to show the grid.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
        xticks (list[float], optional): Custom x-ticks.
        yticks (list[float], optional): Custom y-ticks.
        zticks (list[float], optional): Custom z-ticks.
        axislabel_font_size (int): Font size for axis labels.
        tick_font_size (int): Font size for tick labels.
        ray_color (str): Color of the rays.
        ray_linewidth (float): Line width of the rays.
        show (bool): Whether to display the plot immediately.
        html_file_name (str, optional): If provided, saves the plot as an
            HTML file; the name must end with ``.html``.

    Returns:
        plotly.graph_objs.Figure or None: The Plotly figure object if show is
        False, otherwise None.

    Raises:
        RuntimeError: If ``html_file_name`` does not end with ``.html``.
    """
    # Validate the output name up front so a typo is reported before any
    # figure is built or shown.
    if html_file_name is not None and not str(html_file_name).endswith(".html"):
        raise RuntimeError("html_file_name should end with .html!")

    # Normalize to a sequence so single elements and collections share a path.
    if element is None:
        elements = ()
    elif isinstance(element, (list, tuple)):
        elements = element
    else:
        elements = (element,)

    data = []
    for elem in elements:
        elem_cpu = copy.deepcopy(elem).to("cpu")
        data += _plot_surface_recursively(elem_cpu, "", resolution)

    if rays is not None:
        if isinstance(rays, dict):
            rays = rays["ray_paths"]
        # detach(): tensors that still require grad cannot be converted to
        # numpy further down the plotting pipeline.
        rays = [elem.detach().cpu() for elem in rays]
        data += ray_paths(rays, ray_color, ray_linewidth)

    layout = get_optical_system_layout(show_grid, xlabel, ylabel, zlabel,
                                       xticks, yticks, zticks,
                                       axislabel_font_size, tick_font_size)
    fig = go.Figure(data=data, layout=layout)

    if show:
        fig.show()
    if html_file_name is not None:
        pio.write_html(fig, file=html_file_name, auto_open=False)
    if not show:
        return fig
    return None