# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "gaussian_func_1D",
    "gaussian_func_2D",
    "calc_smooth_desired_irradiance",
    "GaussianSmoother",
    "make_evaluation_function",
    "make_merit_function",
    "GaussianSmootherSquare",
]
import torch
import math
import numpy as np
from .integrators import Cube
import torch
from typing import Callable,List,Tuple,Optional
import gc
import torch
import numpy as np
from .element import trace_to_detector,SequentialOpticalSystem
import math
from .source import LightSource
from .target_grid import Grid
import gc
import warnings
from .render import binned_irradiance
from torchmetrics.image import StructuralSimilarityIndexMeasure
#pip install torchmetrics
[docs]
def gaussian_func_1D(eval_points:torch.Tensor,
x_range,
num_gauss_points:int,
sigma:float,
include_boundary=False)->torch.Tensor:
"""
Gaussian function for 1D convolution.
Args:
eval_points (torch.Tensor): Points where the Gaussian function is evaluated.
x_range (tuple): Range of the target plane.
num_gauss_points (int): Number of Gaussian points.
sigma (float): Standard deviation of the Gaussian function.
include_boundary (bool): Whether to include the boundary points.
Returns:
torch.Tensor: Evaluated Gaussian function.
"""
device = eval_points.device
dtype = eval_points.dtype
eval_points = eval_points.reshape(-1)
xgrid = None
if include_boundary:
xgrid = torch.linspace(x_range[0],x_range[1],num_gauss_points,dtype=dtype,device=device)
else:
xgrid = torch.linspace(x_range[0],x_range[1],num_gauss_points+1,dtype=dtype,device=device)
dxgrid = xgrid[1]-xgrid[0]
xgrid = xgrid[:-1]
xgrid = xgrid + dxgrid*0.5
dist = (xgrid.reshape(-1,1)-eval_points.reshape(1,-1))
const = 1.0/math.sqrt((2.0*math.pi)*sigma*sigma)
multiplier = const
out = multiplier*torch.exp(-(dist**2.0/(2.0*(sigma**2.0))))
return out
[docs]
def gaussian_func_2D(eval_points: torch.Tensor,
                     x_range,
                     y_range,
                     x_grid_size: int,
                     y_grid_size: int,
                     sigma: float | torch.Tensor,
                     val_multi: torch.Tensor | None = None,
                     summed: bool = True,
                     include_boundary=False) -> torch.Tensor:
    """
    Gaussian function for 2D convolution.

    The 2D kernel is separable: it is assembled from one 1D kernel along y
    (taken from column 1 of ``eval_points``) and one along x (column 0).

    Args:
        eval_points (torch.Tensor): Points where the Gaussian function is evaluated,
            shape ``[num_points, 2]`` in local ``(x, y)`` coordinates.
        x_range (tuple): Range of the target plane in the horizontal direction.
        y_range (tuple): Range of the target plane in the vertical direction.
        x_grid_size (int): Number of Gaussian points in the horizontal direction.
        y_grid_size (int): Number of Gaussian points in the vertical direction.
        sigma (float | torch.Tensor): Standard deviation of the Gaussian function.
        val_multi (torch.Tensor | None): Optional per-point multiplier, flattened to
            one value per point.
        summed (bool): If not ``False``, sum the per-point kernels into one map.
            ``False`` is a slow debugging mode.
        include_boundary (bool): Whether the grid nodes include the range boundaries.

    Returns:
        torch.Tensor: ``[y_grid_size, x_grid_size]`` map when summed, otherwise a
        ``[y_grid_size, x_grid_size, num_points]`` tensor of per-point kernels.

    Raises:
        RuntimeError: If the last dimension of ``eval_points`` is not 2.
    """
    if eval_points.shape[-1] != 2:
        raise RuntimeError("points need to be in local coordinates and shape [numraysx2]")
    kernel_y = gaussian_func_1D(eval_points[:, 1], y_range, y_grid_size, sigma, include_boundary)
    kernel_x = gaussian_func_1D(eval_points[:, 0], x_range, x_grid_size, sigma, include_boundary)
    if val_multi is not None:
        # Weight each point once by folding the multiplier into the y kernel only.
        kernel_y = kernel_y * val_multi.reshape(1, -1)
    if summed is not False:
        # [y, N] @ [N, x] -> [y, x]: sum over points of the outer products.
        return kernel_y @ kernel_x.T
    print("summed=False should only be used when debugging!! Its very very slow.")
    return kernel_y.reshape(y_grid_size, 1, -1) * kernel_x.reshape(1, x_grid_size, -1)
[docs]
def calc_smooth_desired_irradiance(desired_irradiance_fun: Callable,
                                   x_range: List[float],
                                   y_range: List[float],
                                   x_grid_size: int,
                                   y_grid_size: int,
                                   sigma: float,
                                   num_integration_points: int,
                                   num_splits=5,
                                   dtype=torch.get_default_dtype(),
                                   device=torch.get_default_device()) -> torch.Tensor:
    """
    Calculates the smoothed desired irradiance using Gaussian convolution.

    The desired irradiance is integrated against the Gaussian kernel of every
    pixel with a quasi-Monte-Carlo rule (``Cube.sample(..., "sobol_pow2")``).
    The samples are processed in ``num_splits`` chunks to bound memory. Every
    sample the sampler returns is used, even when that count is not a multiple
    of ``num_splits``.

    Args:
        desired_irradiance_fun (Callable): Function that computes the desired irradiance at given points.
        x_range (List[float]): Range of the target plane in the x direction [min, max].
        y_range (List[float]): Range of the target plane in the y direction [min, max].
        x_grid_size (int): Number of pixels in the x direction.
        y_grid_size (int): Number of pixels in the y direction.
        sigma (float): Standard deviation of the Gaussian kernel.
        num_integration_points (int): Number of integration points requested from the sampler.
        num_splits (int, optional): Number of chunks for the integration to reduce memory usage. Defaults to 5.
        dtype (torch.dtype, optional): Data type for tensors. Defaults to torch.get_default_dtype()
            as evaluated when this module was imported.
        device (torch.device, optional): Device for computation. Defaults to torch.get_default_device()
            as evaluated when this module was imported.

    Returns:
        torch.Tensor: Smoothed desired irradiance map of shape ``[y_grid_size, x_grid_size]``.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")
    gc.collect()
    integrator = Cube([x_range, y_range])
    points, weights = integrator.sample(num_integration_points, "sobol_pow2")
    # Derive the chunk size from the number of samples actually returned and
    # round up, so every sample falls into exactly one chunk (the previous
    # floor-division + fixed loop count silently dropped the remainder).
    chunk_size = max(1, math.ceil(points.shape[0] / num_splits))
    out = None
    with torch.no_grad():
        for split_points, split_weights in zip(torch.split(points, chunk_size),
                                               torch.split(weights, chunk_size)):
            split_points = split_points.to(device=device, dtype=dtype)
            split_weights = split_weights.to(device=device, dtype=dtype)
            contribution = gaussian_func_2D(split_points, x_range, y_range,
                                            x_grid_size, y_grid_size, sigma=sigma,
                                            val_multi=desired_irradiance_fun(split_points) * split_weights).detach()
            # Running sum instead of stacking all chunks keeps peak memory at one map.
            out = contribution if out is None else out + contribution
            gc.collect()
    del points, weights
    gc.collect()
    if out is None:
        # No samples at all: the convolution is identically zero.
        out = torch.zeros(y_grid_size, x_grid_size, dtype=dtype, device=device)
    return out
[docs]
class GaussianSmoother:
    r"""
    The GaussianSmoother class implements gaussian measurement functions but also computes smoothed desired irradiance distributions. For more information on this class please refer to the examples.

    On construction the desired irradiance is sampled at the pixel centers of
    ``grid`` and divided by its midpoint-rule integral over the target plane,
    so the discrete target integrates to 1. The continuous function used for
    the smoothed target is divided by the same constant.

    Args:
        x_range (list): Range of the target plane in the x direction [min, max].
        y_range (list): Range of the target plane in the y direction [min, max].
        x_grid_size (int): Number of pixels in the x direction.
        y_grid_size (int): Number of pixels in the y direction.
        sigma (float): Standard deviation of the Gaussian kernel.
        desired_irradiance_fun (Callable): Function that computes the desired irradiance at given points.
        smoothed_num_integration_points (int): Number of integration points for smoothing.
        smoothed_num_splits (int): Number of splits for integration to reduce memory usage.
        dtype (torch.dtype, optional): Data type of the smoothed desired irradiance. Defaults to
            torch.get_default_dtype() as evaluated when this module was imported.
        device (torch.device, optional): Device of the smoothed desired irradiance. Defaults to
            torch.get_default_device() as evaluated when this module was imported.

    Attributes:
        x_grid_size (int): Number of pixels in the x direction.
        y_grid_size (int): Number of pixels in the y direction.
        sigma (float): Standard deviation of the Gaussian kernel.
        include_boundary (bool): Whether to include boundary points in the grid (always False here).
        x_range (list): Range of the target plane in the x direction.
        y_range (list): Range of the target plane in the y direction.
        grid (Grid): Grid object for pixel centers.
        discrete_desired_irradiance (torch.Tensor): Normalized desired irradiance at pixel centers,
            shape ``[y_grid_size, x_grid_size]``.
        new_desired_irradiance_fun (Callable): ``desired_irradiance_fun`` divided by the same
            normalization constant.
        smoothed_desired_irradiance (torch.Tensor): Smoothed desired irradiance map.
        last_smoothed_irradiance (torch.Tensor | None): Written by the merit function returned
            from ``make_merit_function``.
        last_raycounting (torch.Tensor | None): Written by the function returned from
            ``make_evaluation_function``.
        smoother_rect_special (list | None): Adaptive evaluation window written by the merit
            function when ``T_margin`` is used.

    Raises:
        ValueError: If the desired irradiance integrates to zero or a non-finite value,
            which would make the normalization produce inf/NaN targets.
    """
    def __init__(self,
                 x_range: list,
                 y_range: list,
                 x_grid_size: int,
                 y_grid_size: int,
                 sigma: float,
                 desired_irradiance_fun: Callable,
                 smoothed_num_integration_points: int,
                 smoothed_num_splits: int,
                 dtype=torch.get_default_dtype(),
                 device=torch.get_default_device()):
        self.x_grid_size, self.y_grid_size = x_grid_size, y_grid_size
        self.sigma = sigma
        self.include_boundary = False
        self.x_range, self.y_range = x_range, y_range
        self.grid = Grid(x_range, y_range, x_grid_size, y_grid_size)
        centers = self.grid.get_pixel_centers().reshape(-1, 2)
        self.discrete_desired_irradiance: torch.Tensor = desired_irradiance_fun(centers).reshape(y_grid_size, x_grid_size)
        integrated_desired_irradiance = self.integrate_values(self.discrete_desired_irradiance)
        total = float(integrated_desired_irradiance)
        if total == 0.0 or not math.isfinite(total):
            raise ValueError(
                f"desired_irradiance_fun integrates to {total} over the target plane; "
                "expected a finite, non-zero value so it can be normalized."
            )
        # Normalize so the discrete target integrates to 1 over the target plane.
        self.discrete_desired_irradiance = self.discrete_desired_irradiance / integrated_desired_irradiance
        new_desired_irradiance_fun = lambda points: desired_irradiance_fun(points) / integrated_desired_irradiance
        self.new_desired_irradiance_fun = new_desired_irradiance_fun
        self.smoothed_num_integration_points = smoothed_num_integration_points
        self.smoothed_num_splits = smoothed_num_splits
        self.dtype = dtype
        self.device = device
        # Diagnostics written later by the evaluation / merit functions.
        self.last_smoothed_irradiance = None
        self.last_raycounting = None
        self.smoother_rect_special = None
        self.smoothed_desired_irradiance: torch.Tensor = calc_smooth_desired_irradiance(new_desired_irradiance_fun,
                                                                                       x_range, y_range,
                                                                                       x_grid_size,
                                                                                       y_grid_size,
                                                                                       sigma=sigma,
                                                                                       num_integration_points=smoothed_num_integration_points,
                                                                                       num_splits=smoothed_num_splits,
                                                                                       dtype=dtype,
                                                                                       device=device)

    def smoothed_irradiance(self, points: torch.Tensor, ray_multi: torch.Tensor, x_range=None, y_range=None) -> torch.Tensor:
        """
        Computes the Gaussian-smoothed irradiance map produced by weighted points.

        Args:
            points (torch.Tensor): Ray hit points on the target plane, shape [N, 2].
            ray_multi (torch.Tensor): Multiplicative weights for each point, e.g., ray flux.
            x_range (tuple, optional): Range of the target plane in the x direction. Defaults to ``self.x_range``.
            y_range (tuple, optional): Range of the target plane in the y direction. Defaults to ``self.y_range``.

        Returns:
            torch.Tensor: Smoothed irradiance map on the pixel grid, shape ``[y_grid_size, x_grid_size]``.
        """
        if x_range is None:
            x_range = self.x_range
        if y_range is None:
            y_range = self.y_range
        return gaussian_func_2D(points, x_range, y_range, self.x_grid_size, self.y_grid_size,
                                self.sigma, ray_multi, include_boundary=self.include_boundary)

    def none_smoothed_irradiance(self, points: torch.Tensor, ray_multi: torch.Tensor) -> torch.Tensor:
        """
        Computes the non-smoothed (binned) irradiance by summing ray contributions in each grid cell.

        Args:
            points (torch.Tensor): Ray hit points on the target plane, shape [N, 2].
            ray_multi (torch.Tensor): Multiplicative weights for each point, e.g., ray flux.

        Returns:
            torch.Tensor: Per-cell sums from ``self.grid.sum`` divided by the pixel area.
        """
        return self.grid.sum(points, ray_multi) / self.grid.get_pixel_area()

    def integrate_values(self, vals: torch.Tensor, x_range=None, y_range=None) -> torch.Tensor:
        """
        Integrates the provided values over the grid using midpoint rule.

        Args:
            vals (torch.Tensor): Values at the pixel centers, typically irradiance or squared residuals.
                Flattened before weighting; presumably it must hold x_grid_size * y_grid_size values.
            x_range (tuple, optional): Integration range in x. Defaults to ``self.x_range``.
            y_range (tuple, optional): Integration range in y. Defaults to ``self.y_range``.

        Returns:
            torch.Tensor: Scalar sum of ``vals`` times the midpoint quadrature weights.
        """
        if x_range is None:
            x_range = self.x_range
        if y_range is None:
            y_range = self.y_range
        integrator = Cube([x_range, y_range])
        _, weights = integrator.sample([self.x_grid_size, self.y_grid_size], "midpoint")
        weights = weights.to(device=vals.device, dtype=vals.dtype)
        # NOTE(review): vals is flattened from [y, x] order while the sampler is
        # called with [x_grid_size, y_grid_size]; this pairing is only
        # order-independent if the midpoint weights are uniform — confirm.
        return (vals.reshape(-1) * weights).sum()
[docs]
class GaussianSmootherSquare(GaussianSmoother):
    r"""
    Specialization of GaussianSmoother for a square target plane centered at the origin.

    The target plane spans ``[-aperture_radius, aperture_radius]`` in both the x and
    y directions and uses the same number of pixels along both axes.

    Args:
        aperture_radius (float): Half the side length of the square target plane.
        grid_size (int): Number of pixels in both x and y directions.
        sigma (float): Standard deviation of the Gaussian kernel.
        desired_irradiance_fun (Callable): Function that computes the desired irradiance at given points.
        smoothed_num_integration_points (int): Number of integration points for smoothing.
        smoothed_num_splits (int): Number of splits for integration to reduce memory usage.
        dtype (torch.dtype, optional): Data type for tensors. Defaults to torch.get_default_dtype().
        device (torch.device, optional): Device for computation. Defaults to torch.get_default_device().

    Raises:
        ValueError: If ``aperture_radius`` is not positive (the target plane would be
            empty or inverted).
    """
    def __init__(self,
                 aperture_radius: float,
                 grid_size: int,
                 sigma: float,
                 desired_irradiance_fun: Callable,
                 smoothed_num_integration_points: int,
                 smoothed_num_splits: int,
                 dtype=torch.get_default_dtype(),
                 device=torch.get_default_device()):
        # "not > 0" also rejects NaN.
        if not aperture_radius > 0:
            raise ValueError(f"aperture_radius must be positive, got {aperture_radius}")
        super().__init__(x_range=[-aperture_radius, aperture_radius],
                         y_range=[-aperture_radius, aperture_radius],
                         x_grid_size=grid_size,
                         y_grid_size=grid_size,
                         sigma=sigma,
                         desired_irradiance_fun=desired_irradiance_fun,
                         smoothed_num_integration_points=smoothed_num_integration_points,
                         smoothed_num_splits=smoothed_num_splits,
                         dtype=dtype,
                         device=device)
[docs]
def make_evaluation_function(optical_system: SequentialOpticalSystem,
                             sequence: List,
                             source: LightSource,
                             detector,
                             smoother: GaussianSmoother,
                             num_splits: int = 10,
                             num_rays_per_split: int = 100000,
                             method_ray_tracing="monte_carlo",
                             device=torch.get_default_device()) -> Callable:
    """
    Creates an evaluation function for comparing simulated and desired irradiance.

    The returned closure traces ``num_splits`` batches of ``num_rays_per_split`` rays.
    Each batch is binned on ``smoother.grid`` via ``binned_irradiance``, and the batch
    results are averaged on the CPU. The averaged map is stored in
    ``smoother.last_raycounting`` and compared with
    ``smoother.discrete_desired_irradiance``.

    Args:
        optical_system (SequentialOpticalSystem): The optical system to be used for ray tracing.
        sequence: The sequence of optical elements.
        source (LightSource): The light source for the simulation.
        detector: The detector object.
        smoother (GaussianSmoother): Smoother object for irradiance comparison.
        num_splits (int, optional): Number of ray-tracing batches to reduce memory usage. Defaults to 10.
        num_rays_per_split (int, optional): Number of rays per batch. Defaults to 100,000.
        method_ray_tracing (str, optional): Ray tracing method ('monte_carlo', etc.). Defaults to "monte_carlo".
        device (torch.device, optional): Device for computation. Defaults to torch.get_default_device().

    Returns:
        Callable: A zero-argument function returning ``(L2_error, rmse, ssim_error)``:

        - ``L2_error``: square root of the midpoint-rule integral of the squared residual
          over the target plane.
        - ``rmse``: root mean square of the per-pixel residual.
        - ``ssim_error``: structural similarity index between the simulated and desired
          maps. Despite the name, this is a similarity, not an error.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")
    x_len = smoother.x_range[1] - smoother.x_range[0]
    y_len = smoother.y_range[1] - smoother.y_range[0]
    # GaussianSmoother normalizes the desired irradiance to integrate to 1, so its
    # mean value is 1/area; 10x that is used as a heuristic SSIM dynamic range.
    maxirr_est = 10.0 / (x_len * y_len)
    ssim = StructuralSimilarityIndexMeasure(data_range=maxirr_est)
    # Irradiance maps are [y_grid_size, x_grid_size]; SSIM wants (N, C, H, W).
    image_shape = (1, 1, smoother.y_grid_size, smoother.x_grid_size)

    def evaluate():
        batches = [
            binned_irradiance(optical_system=optical_system, sequence=sequence, source=source,
                              detector=detector, grid=smoother.grid, num_rays=num_rays_per_split,
                              method_ray_tracing=method_ray_tracing, device=device).detach().cpu()
            for _ in range(num_splits)
        ]
        raycounting = torch.mean(torch.stack(batches), dim=0)
        smoother.last_raycounting = raycounting
        desired = smoother.discrete_desired_irradiance.cpu()
        residual = raycounting.reshape(-1) - desired.reshape(-1)
        rmse = torch.sqrt(torch.mean(residual ** 2.0))
        ssim_error = ssim(raycounting.reshape(image_shape), desired.reshape(image_shape))
        L2_error = torch.sqrt(smoother.integrate_values(residual ** 2))
        return L2_error, rmse, ssim_error

    return evaluate
"""
def make_evaluation_function_rmse(optical_system:SequentialOpticalSystem,
sequence:List,
source:LightSource,
detector,
smoother:GaussianSmoother,
num_splits:int=10,
num_rays_per_split:int=100000,
method_ray_tracing="monte_carlo",
device=torch.get_default_device())->Callable:
def evaluate():
raycounting_list = []
for k in (range(num_splits)):
tmp = binned_irradiance(optical_system=optical_system,sequence=sequence,source=source,detector=detector,grid=smoother.grid,num_rays=num_rays_per_split,method_ray_tracing=method_ray_tracing,device=device)
tmp = tmp.detach().cpu()
raycounting_list.append(tmp)
raycounting = torch.mean(torch.stack(raycounting_list),dim=0).detach().cpu()
smoother.last_raycounting = raycounting.detach().cpu()
#residual = raycounting.cpu().reshape(-1)-smoother.discrete_desired_irradiance.cpu().reshape(-1)
L2_error = torch.sqrt(torch.mean((raycounting.cpu().reshape(-1)-smoother.discrete_desired_irradiance.cpu().reshape(-1))**2.0))
#L2_error = torch.sqrt(smoother.integrate_values(residual**2))
#RMSE = torch.sum((residual**2))
return L2_error
return evaluate
def make_evaluation_function_ssim(optical_system:SequentialOpticalSystem,
sequence:List,
source:LightSource,
detector,
smoother:GaussianSmoother,
num_splits:int=10,
num_rays_per_split:int=100000,
method_ray_tracing="monte_carlo",
device=torch.get_default_device())->Callable:
smoother.x_range
L = smoother.x_range[1]-smoother.x_range[0]
maxirr_est = (1/(L**2))*10
ssim = StructuralSimilarityIndexMeasure(data_range=maxirr_est)
def evaluate():
raycounting_list = []
for k in (range(num_splits)):
tmp = binned_irradiance(optical_system=optical_system,sequence=sequence,source=source,detector=detector,grid=smoother.grid,num_rays=num_rays_per_split,method_ray_tracing=method_ray_tracing,device=device)
tmp = tmp.detach().cpu()
raycounting_list.append(tmp)
raycounting = torch.mean(torch.stack(raycounting_list),dim=0).detach().cpu()
smoother.last_raycounting = raycounting.detach().cpu()
#residual = raycounting.cpu().reshape(-1)-smoother.discrete_desired_irradiance.cpu().reshape(-1)
L2_error = ssim(raycounting.cpu().reshape(-1),smoother.discrete_desired_irradiance.cpu().reshape(-1))
#L2_error = torch.sqrt(smoother.integrate_values(residual**2))
#RMSE = torch.sum((residual**2))
return L2_error
return evaluate
"""
[docs]
def make_merit_function(optical_system: SequentialOpticalSystem,
                        sequence: List,
                        source: LightSource,
                        detector,
                        smoother: GaussianSmoother,
                        num_rays: int,
                        method_ray_tracing="sobol_pow2",
                        use_desired_irradiance_smoothing=True,
                        device=torch.get_default_device(),
                        T_margin=None) -> Callable:
    """
    Creates a merit function to obtain a desired irradiance distribution for the given optical system, source, and detector.

    The merit value is the L2 norm of the residual: the square root of the
    midpoint-rule integral of the squared difference. The residual is taken between
    the Gaussian-smoothed simulated irradiance and a target. That target is the
    smoothed desired irradiance when ``use_desired_irradiance_smoothing`` is True,
    and otherwise the discrete desired irradiance at the pixel centers.

    Args:
        optical_system (SequentialOpticalSystem): The optical system to be used.
        sequence: The sequence of elements in the optical system.
        source (LightSource): The light source to be used.
        detector: The detector to be used.
        smoother (GaussianSmoother): The smoother object for merit function calculation.
        num_rays (int): Number of rays to be traced per evaluation.
        method_ray_tracing (str): Sampling method forwarded to ``trace_to_detector``.
            Defaults to "sobol_pow2".
        use_desired_irradiance_smoothing (bool): Whether to compare against the smoothed
            desired irradiance instead of the discrete one.
        device: The device to be used for calculations.
        T_margin (float | None): If not None, every evaluation re-fits the evaluation window
            to the bounding box of the detector hits, enlarged by ``T_margin`` on each side.
            It then recomputes the smoothed desired irradiance for that window, which
            overwrites ``smoother.smoothed_desired_irradiance``. Requires
            ``use_desired_irradiance_smoothing=True``. If None, the window is fixed.

    Returns:
        Callable: A zero-argument function that computes the merit value (scalar tensor).

    Raises:
        RuntimeError: If ``T_margin`` is given while ``use_desired_irradiance_smoothing`` is False.
    """
    if T_margin is not None and not use_desired_irradiance_smoothing:
        raise RuntimeError("T_margin can only be used when use_desired_irradiance_smoothing is True.")

    def merit_function() -> torch.Tensor:
        """Trace ``num_rays`` rays and return the L2 merit value described in make_merit_function."""
        x, weights, y, _ = trace_to_detector(optical_system, sequence, source, detector, num_rays, device,
                                             method_ray_tracing=method_ray_tracing)
        ray_flux = source.get_flux(x) * weights
        if use_desired_irradiance_smoothing and smoother.smoothed_desired_irradiance is None:
            raise RuntimeError("Using desired irradiance smoothing but smoothed_desired_irradiance was not provided!--calc_smooth_desired_irradiance")
        smoother.smoother_rect_special = None

        if T_margin is not None and use_desired_irradiance_smoothing:
            if y.shape[0] == 0:
                raise RuntimeError("No rays reached the detector; cannot fit the evaluation window for T_margin.")
            # Window = bounding box of the hits, grown by T_margin; treated as a
            # constant (.item()), so no gradient flows through the window bounds.
            x_range = [y[:, 0].min().item() - T_margin, y[:, 0].max().item() + T_margin]
            y_range = [y[:, 1].min().item() - T_margin, y[:, 1].max().item() + T_margin]
            smoother.smoother_rect_special = [x_range, y_range]
            smoothed_irradiance = smoother.smoothed_irradiance(y, ray_flux, x_range, y_range)
            smoother.smoothed_desired_irradiance = calc_smooth_desired_irradiance(smoother.new_desired_irradiance_fun,
                                                                                  x_range, y_range,
                                                                                  smoother.x_grid_size,
                                                                                  smoother.y_grid_size,
                                                                                  sigma=smoother.sigma,
                                                                                  num_integration_points=smoother.smoothed_num_integration_points,
                                                                                  num_splits=smoother.smoothed_num_splits,
                                                                                  dtype=smoother.dtype,
                                                                                  device=smoother.device).detach().to(device=device)
            residual = smoothed_irradiance.reshape(-1) - smoother.smoothed_desired_irradiance.reshape(-1)
            smoother.last_smoothed_irradiance = smoothed_irradiance.detach().cpu()
            return torch.sqrt(smoother.integrate_values(residual ** 2, x_range, y_range))

        smoothed_irradiance = smoother.smoothed_irradiance(y, ray_flux)
        smoother.last_smoothed_irradiance = smoothed_irradiance.detach().cpu()
        # Move only the target actually used: the other one may legitimately be None.
        if use_desired_irradiance_smoothing:
            smoother.smoothed_desired_irradiance = smoother.smoothed_desired_irradiance.to(device=device)
            target = smoother.smoothed_desired_irradiance
        else:
            smoother.discrete_desired_irradiance = smoother.discrete_desired_irradiance.to(device=device)
            target = smoother.discrete_desired_irradiance
        residual = smoothed_irradiance.reshape(-1) - target.reshape(-1)
        return torch.sqrt(smoother.integrate_values(residual ** 2))

    return merit_function