Element
- diffinytrace.element.is_valid_square_circle(transform: Transform, O: Tensor, aperture_radius: float, is_square: bool) → Tensor
Checks whether points lie within a circular or square aperture after transformation.
- Parameters:
transform (Transform) – Transformation object to convert global to local coordinates.
O (torch.Tensor) – Points in global coordinates of shape (N, 3).
aperture_radius (float or torch.Tensor) – Radius of the circular aperture, or half-width of the square aperture.
is_square (bool) – If True, aperture is square; if False, circular.
- Returns:
Boolean tensor of shape (N,) indicating whether each point lies within the aperture.
- Return type:
torch.Tensor
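Once the points are in local coordinates, the aperture test reduces to a simple comparison. The plain-Python sketch below assumes the transform step has already been applied, that the test uses only the local x and y coordinates, and that aperture_radius is the half-width for a square aperture; is_inside_aperture is a hypothetical helper, not part of diffinytrace.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def is_inside_aperture(points_local: Sequence[Point], aperture_radius: float,
                       is_square: bool) -> List[bool]:
    """Hypothetical aperture test on points already given in local coordinates."""
    inside = []
    for x, y, _z in points_local:
        if is_square:
            # Square: both |x| and |y| must stay within the half-width.
            inside.append(max(abs(x), abs(y)) <= aperture_radius)
        else:
            # Circle: radial distance in the local x-y plane within the radius.
            inside.append(math.hypot(x, y) <= aperture_radius)
    return inside


pts = [(0.0, 0.0, 0.0), (0.9, 0.9, 0.0), (1.1, 0.0, 0.0)]
print(is_inside_aperture(pts, 1.0, is_square=True))   # [True, True, False]
print(is_inside_aperture(pts, 1.0, is_square=False))  # [True, False, False]
```

The corner point (0.9, 0.9) illustrates the difference: it lies inside the square but outside the inscribed circle of the same radius.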
- class diffinytrace.element.OpticalSystem(modules_dict: Dict)
Bases: Module, Plotable
Base class for optical systems composed of multiple optical modules.
This class serves as a container for modules such as lenses, mirrors, and detectors. It supports visualization and modular organization.
- modules_dict
Dictionary of named optical modules.
- Type:
nn.ModuleDict
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_plotable_childs()
Returns a list of all plotable child objects of this object. Each child is represented as a list containing the child object and its name.
- class diffinytrace.element.SequentialOpticalSystem(modules_dict: Dict, n_func_enviroment=RefractiveIndex())
Bases: OpticalSystem
Optical system that processes rays in a defined sequence.
Useful for simulating light propagation through a sequence of elements, e.g., source → lens → detector.
- n_func_enviroment
Function returning refractive index of the surrounding medium.
- Type:
Callable
- forward(x, mapping_sequence: List[str])
Propagates rays through the defined sequence of modules.
- Parameters:
x (Any) – Input rays or sampling data.
mapping_sequence (list[str]) – Ordered list of module names defining propagation sequence.
- Returns:
Output of the final module in the sequence.
- Return type:
Any
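Sequential propagation amounts to threading one ray state through the named modules in order. The sketch below assumes, as the forward methods of the surface classes document, that each module takes and returns the tuple (origins, directions, wavelengths, environment function, metadata); propagate_sequence and make_recording_module are hypothetical names, and the toy modules only record their name instead of tracing rays.

```python
from typing import Any, Callable, Dict, List, Tuple

# A hypothetical module maps the ray state (O, D, wl, n_func_enviroment, meta_data)
# to an updated state tuple of the same form.
Module = Callable[..., Tuple[Any, ...]]


def propagate_sequence(modules: Dict[str, Module], state: Tuple[Any, ...],
                       mapping_sequence: List[str]) -> Tuple[Any, ...]:
    """Feed the output of each named module into the next one, in order."""
    for name in mapping_sequence:
        state = modules[name](*state)
    return state


def make_recording_module(name: str) -> Module:
    """Toy module: leaves the rays unchanged and appends its name to meta_data."""
    def module(O, D, wl, n_func_enviroment, meta_data):
        meta_data = {**meta_data, "path": meta_data.get("path", []) + [name]}
        return O, D, wl, n_func_enviroment, meta_data
    return module


names = ["lens1", "mirror", "detector"]
modules = {n: make_recording_module(n) for n in names}
out = propagate_sequence(modules, (None, None, 550e-9, None, {}), names)
print(out[-1]["path"])  # ['lens1', 'mirror', 'detector']
```

In this sketch the order is supplied at call time rather than stored with the modules, so the same dictionary can be traversed in different orders.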
- class diffinytrace.element.OpticalElement(fill_color='white', outline_color='black', is_volume=False)
Bases: PhysicalObject, Plotable
Abstract base class for optical elements such as lenses, mirrors, and detectors.
Provides an interface for geometric transformation and ray propagation.
- forward(O2: Tensor, D2: Tensor, wl: Tensor, n_func_enviroment, meta_data)
Propagates rays through the optical element.
- Parameters:
O2 (torch.Tensor) – Ray origins.
D2 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment (Callable) – Function returning environmental refractive index.
meta_data (dict) – Dictionary with path length and validity information.
- Raises:
NotImplementedError – Must be overridden by subclasses.
- class diffinytrace.element.OpticalSurface(transform: Transform, surface, aperture_radius: float, is_square: bool = False, fill_color: str = 'white', outline_color: str = 'black')
Bases: OpticalElement, PhysicalSurface
Represents a surface in 3D space with a defined aperture and transformation.
Supports both square and circular apertures, and provides methods for parametric sampling, CAD conversion, ray intersection, and plotting.
- surface
Object with a method explicit(parametric_pos) returning z-values.
- Type:
object
- aperture_radius
Radius of the circular aperture, or half-width of the square aperture.
- Type:
float
- is_square
Whether the aperture is square-shaped.
- Type:
bool
- integrator
Integration object (Disc or Cube) for parametric sampling.
- get_constraint_funs_leq_zero()
Returns constraint functions used for integration and optimization over the surface.
- Returns:
List of functions f such that f(param_pos) <= 0 on the valid parametric region.
- Return type:
list[Callable]
- Raises:
RuntimeError – If is_square is True (not yet implemented).
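For a circular aperture, a natural constraint of the form f(param_pos) <= 0 is u² + v² − r² <= 0. The sketch below assumes that form and that parametric positions are (u, v) pairs; disc_constraints and is_feasible are hypothetical names, and the constraints the library actually returns may be formulated differently.

```python
from typing import Callable, List, Tuple

Param = Tuple[float, float]
Constraint = Callable[[Param], float]


def disc_constraints(aperture_radius: float) -> List[Constraint]:
    """Constraint functions f with f(p) <= 0 exactly when p lies in the disc."""
    def inside_disc(p: Param) -> float:
        u, v = p
        return u * u + v * v - aperture_radius ** 2
    return [inside_disc]


def is_feasible(p: Param, constraints: List[Constraint]) -> bool:
    return all(f(p) <= 0.0 for f in constraints)


cons = disc_constraints(1.0)
print(is_feasible((0.5, 0.5), cons), is_feasible((0.8, 0.8), cons))  # True False
```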
- get_plot_points_2D(resolution: int) → List[Tuple[Tensor]]
Returns 2D slices through the surface (z-y plane) for plotting.
- Parameters:
resolution (int) – Number of sample points along the y-axis.
- Returns:
List of (z, y) coordinate tuples.
- Return type:
List[Tuple[torch.Tensor]]
- get_plot_points_3D(resolution: int) → List[Tuple[Tensor]]
Returns 3D grid of surface points for visualization.
- Parameters:
resolution (int) – Grid resolution in x and y.
- Returns:
List of (x, y, z) meshgrids as torch tensors.
- Return type:
List[Tuple[torch.Tensor]]
- get_CAD_points(resolution: int) → List[Tuple[Tensor]]
Generates a 3D surface point grid for CAD conversion.
- Parameters:
resolution (int) – Sampling resolution.
- Returns:
(x, y, z) coordinate grids for CAD modeling.
- Return type:
Tuple[torch.Tensor]
- get_CAD_face(resolution: int, tol: float = 0.001, smoothing=None, minDeg: int = 1, maxDeg: int = 3)
Converts the surface into a CAD face using B-spline approximation.
- Parameters:
resolution (int) – Sampling resolution.
tol (float, optional) – Approximation tolerance. Defaults to 0.001.
smoothing (Optional[int]) – Smoothing value for fitting.
minDeg (int) – Minimum degree of the spline.
maxDeg (int) – Maximum degree of the spline.
- Returns:
CAD face object.
- Return type:
cadquery.Face
- parametric_sample(num_points: int, method: str = 'sobol') → tuple[Tensor, Tensor]
Samples parametric positions on the aperture using the integrator.
- Parameters:
num_points (int) – Number of sample points.
method (str) – Sampling method. Options: “sobol”, “monte_carlo”, “midpoint”, etc.
- Returns:
Sampled positions and integration weights.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
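Positions and integration weights are typically used together so that the weighted sum of a quantity over the samples approximates its integral over the aperture. The sketch below illustrates this for the "midpoint" option on a square aperture, assuming that is what the weights mean; midpoint_sample_square is a hypothetical helper that takes the number of cells per axis rather than num_points.

```python
from typing import List, Tuple


def midpoint_sample_square(n_per_axis: int, half_width: float
                           ) -> Tuple[List[Tuple[float, float]], List[float]]:
    """Midpoint-rule positions and weights on [-h, h] x [-h, h]."""
    step = 2.0 * half_width / n_per_axis
    centers = [-half_width + (i + 0.5) * step for i in range(n_per_axis)]
    positions = [(u, v) for u in centers for v in centers]
    # Every cell has the same area, so every weight equals step**2.
    weights = [step * step] * len(positions)
    return positions, weights


pos, w = midpoint_sample_square(4, 1.0)
area = sum(w)  # total aperture area (2 * 1.0) ** 2
# Weighted sum approximating the integral of u**2 over the square (exact: 4/3).
integral_u2 = sum(wi * u * u for (u, _v), wi in zip(pos, w))
```

Increasing n_per_axis drives integral_u2 toward 4/3; quasi-random schemes such as Sobol trade the regular grid for better behaviour in higher dimensions.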
- parametric_surface(parametric_pos: Tensor) → Tensor
Maps 2D parametric coordinates to 3D global coordinates using the surface height and transform.
- Parameters:
parametric_pos (torch.Tensor) – 2D parametric positions of shape (N, 2).
- Returns:
3D positions of shape (N, 3) in global space.
- Return type:
torch.Tensor
- Raises:
RuntimeError – If input does not have shape […, 2].
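The mapping described above lifts a parametric point onto the explicit surface and then moves it into global space. The sketch below assumes the transform is a rigid map x → R x + t (a homogeneous matrix M encodes the same rotation R and translation t); parametric_to_global and paraboloid are hypothetical, and the library applies its Transform object to tensors instead.

```python
from typing import Callable, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def parametric_to_global(uv: Tuple[float, float],
                         height: Callable[[float, float], float],
                         rotation: Sequence[Sequence[float]],
                         translation: Vec3) -> Vec3:
    """Lift (u, v) to (u, v, height(u, v)), then apply x -> R x + t."""
    u, v = uv
    local = (u, v, height(u, v))
    x, y, z = (
        sum(rotation[i][j] * local[j] for j in range(3)) + translation[i]
        for i in range(3)
    )
    return (x, y, z)


def paraboloid(u: float, v: float) -> float:
    return 0.1 * (u * u + v * v)


identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
rot_z_90 = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
p_shifted = parametric_to_global((1.0, 2.0), paraboloid, identity, (0.0, 0.0, 5.0))
p_rotated = parametric_to_global((1.0, 2.0), paraboloid, rot_z_90, (0.0, 0.0, 0.0))
```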
- get_surface_and_normal_func_with_params()
Constructs a callable for surface position and normal computation with parameter tracking.
- Returns:
The callable computes (position, normal); the list contains the parameters to be optimized.
- Return type:
Tuple[Callable, List]
- get_ray_intersect_length(O, D) → Tensor
Computes the distance along each ray to its intersection with the surface.
- Parameters:
O (torch.Tensor) – Ray origins of shape (N, 3).
D (torch.Tensor) – Ray directions of shape (N, 3).
- Returns:
Intersection distances t such that O + t*D lies on the surface.
- Return type:
torch.Tensor
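For an explicit surface z = s(x, y), the intersection distance is a root of g(t) = O_z + t·D_z − s(O_x + t·D_x, O_y + t·D_y). A common way to find it is Newton's method; the sketch below uses that technique with a finite-difference derivative, but which solver DiffInyTrace actually uses is not stated here. intersect_explicit_surface and paraboloid are hypothetical, and the ray is assumed to be expressed in the surface's local frame.

```python
from typing import Callable, Tuple

Vec3 = Tuple[float, float, float]


def intersect_explicit_surface(O: Vec3, D: Vec3,
                               height: Callable[[float, float], float],
                               t0: float = 0.0, tol: float = 1e-12,
                               max_iter: int = 50, h: float = 1e-6) -> float:
    """Solve g(t) = O_z + t*D_z - height(O_x + t*D_x, O_y + t*D_y) = 0 by Newton."""

    def g(t: float) -> float:
        x, y, z = (O[k] + t * D[k] for k in range(3))
        return z - height(x, y)

    t = t0
    for _ in range(max_iter):
        gt = g(t)
        if abs(gt) < tol:
            break
        # Central finite difference stands in for an analytic derivative.
        dg = (g(t + h) - g(t - h)) / (2.0 * h)
        if dg == 0.0:
            raise ValueError("ray is tangent to the surface; Newton step undefined")
        t -= gt / dg
    return t


def paraboloid(x: float, y: float) -> float:
    return 0.1 * (x * x + y * y)


O, D = (0.0, 0.0, -1.0), (0.6, 0.0, 0.8)
t = intersect_explicit_surface(O, D, paraboloid)
```

Starting from t0 = 0, Newton converges here to the nearer of the two roots of 0.036 t² − 0.8 t + 1 = 0, which is the physically relevant first hit.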
- get_new_is_valid(O, valid) → Tensor
Updates a boolean mask indicating which rays are still valid after hitting the aperture.
- Parameters:
O (torch.Tensor) – Ray intersection points.
valid (torch.Tensor) – Previous boolean validity mask.
- Returns:
Updated validity mask.
- Return type:
torch.Tensor
- class diffinytrace.element.LensSurfaceTransmissionEnter(transform: Transform, surface, aperture_radius: float, n_func, is_square: bool = False)
Bases: OpticalSurface
- forward(O1, D1, wl, n_func_enviroment, meta_data)
Propagates rays through the lens entry surface.
- Parameters:
O1 (torch.Tensor) – Ray origins.
D1 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning environmental refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, wavelengths, environment function, and metadata.
- Return type:
Tuple
- class diffinytrace.element.LensSurfaceTransmissionLeave(transform: Transform, surface, aperture_radius: float, n_func, is_square: bool = False)
Bases: OpticalSurface
- forward(O2, D2, wl, n_func_enviroment, meta_data)
Propagates rays through the lens exit surface.
- Parameters:
O2 (torch.Tensor) – Ray origins.
D2 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning environmental refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, wavelengths, environment function, and metadata.
- Return type:
Tuple
- class diffinytrace.element.LensSurfaceSide(surface1: PhysicalSurface, surface2: PhysicalSurface, aperture_radius, is_square: bool)
Bases: PhysicalSurface, Plotable
Non-optical surface connecting two curved lens surfaces for visualization.
Used to render the full 3D body of the lens.
- surface1
First lens surface.
- Type:
PhysicalSurface
- surface2
Second lens surface.
- Type:
PhysicalSurface
- aperture_radius
Radius or half-width of aperture.
- Type:
float
- is_square
Whether aperture is square.
- Type:
bool
- parametric_sample(num_points: int, method: str = 'sobol')
Samples parametric positions on the lens side surface.
- Parameters:
num_points (int) – Number of sample points.
method (str) – Sampling method (“sobol”, “monte_carlo”, “midpoint”).
- Returns:
Sampled positions and integration weights.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- Raises:
RuntimeError – If unsupported method is provided.
- parametric_surface(parametric_pos: Tensor) → Tensor
Maps parametric coordinates to 3D global coordinates for the lens side.
- Parameters:
parametric_pos (torch.Tensor) – Parametric positions of shape (N, 2).
- Returns:
3D positions of shape (N, 3).
- Return type:
torch.Tensor
- get_plot_points_2D(resolution: int) → List[Tuple[Tensor]]
Returns 2D slices through the surface (z-y plane) for plotting.
- Parameters:
resolution (int) – Number of sample points along the y-axis.
- Returns:
List of (z, y) coordinate tuples.
- Return type:
List[Tuple[torch.Tensor]]
- class diffinytrace.element.Lens(transform: Transform, lens_thickness: float, surface1, surface2, n_func, aperture_radius: float, is_square=False)
Bases: OpticalElement
Represents a transmissive lens consisting of two refractive surfaces.
The lens is modeled as a sequence of:
- Entry surface (refraction from the external medium into the lens)
- Exit surface (refraction from the lens into the external medium)
- Side surface (purely for visualization)
In our implementation, lenses consist of two explicit surfaces, a transformation matrix \(M\), a lens thickness, an aperture radius, and a material. When the lens is initialized, one can also optionally specify whether the lens is round or square. If the keyword is_square is not specified, the lens will default to being round.
Example
Below is an example of initializing a square lens:
>>> import diffinytrace as dit
>>> aperture_half = 30.
>>> lens_thickness = 8.
>>> material = dit.materials["NBK7"]
>>> transform = dit.transforms.Identity()
>>> bspline = dit.Bspline(aperture_half, [3, 3], [8, 8])
>>> plane = dit.Plane()
>>> lens = dit.Lens(transform, lens_thickness,
...                 bspline, plane,
...                 material, aperture_half, is_square=True)
- n_func
Function mapping wavelength to refractive index of the lens material.
- Type:
Callable
- lens_thickness
Learnable thickness of the lens.
- Type:
torch.nn.Parameter
- surface1
Entry surface.
- surface2
Exit surface.
- lens_surface_side
Side surface (for 3D rendering).
- Type:
LensSurfaceSide
- aperture_radius
Radius (or half-width) of aperture.
- Type:
float
- is_square
Whether the aperture is square.
- Type:
bool
- get_plot_points_2D(resolution: int) → List[Tuple[Tensor]]
Returns 2D slices through the lens for plotting.
- Parameters:
resolution (int) – Number of sample points.
- Returns:
List of (z, y) coordinate tuples.
- Return type:
List[Tuple[torch.Tensor]]
- get_plot_points_3D(resolution: int) → List[Tuple[Tensor]]
Returns 3D grid of lens surface points for visualization.
- Parameters:
resolution (int) – Grid resolution.
- Returns:
List of (x, y, z) meshgrids.
- Return type:
List[Tuple[torch.Tensor]]
- get_plotly_color_scale() → List
Returns color scale for plotly visualization.
- Returns:
Color scale values.
- Return type:
List
- get_plotable_childs() → List
Returns plotable child elements.
- Returns:
List of child elements.
- Return type:
List
- forward(O1: Tensor, D1: Tensor, wl: Tensor, n_func_enviroment, meta_data)
Simulates light passing through the lens.
- Parameters:
O1 (torch.Tensor) – Ray origin positions.
D1 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment (Callable) – Function returning external medium refractive index.
meta_data (dict) – Ray metadata (PL, OPL, paths, valid).
- Returns:
Updated ray origins, directions, etc.
- Return type:
Tuple[torch.Tensor]
- class diffinytrace.element.Mirror(transform, surface, aperture_radius, is_square=False)
Bases: OpticalSurface
Reflective optical element that reflects rays according to the law of reflection.
Visualization is colored in a warm gold tone.
- Inherits:
OpticalSurface: Full support for surface transformation and intersection.
- forward(O1, D1, wl, n_func_enviroment, meta_data)
Propagates rays through the mirror surface.
- Parameters:
O1 (torch.Tensor) – Ray origins.
D1 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning environmental refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, wavelengths, environment function, and metadata.
- Return type:
Tuple
- class diffinytrace.element.Detector(transform, surface, aperture_radius, is_square=True)
Bases: OpticalSurface
Represents a terminal optical element that collects ray data.
Detectors consist of an explicit surface, a transformation matrix \(M\), and an aperture radius. The detector class represents a target surface used to track the rays that hit it. When the detector is initialized, one can also optionally specify whether the detector is round or square. If the keyword is_square is not specified, the detector defaults to being square.
Example
Below is an example of how to initialize a detector:
>>> import diffinytrace as dit
>>> aperture_half = 30.
>>> transform = dit.transforms.Identity()
>>> plane = dit.Plane()
>>> detector = dit.Detector(transform, plane,
...                         aperture_half, is_square=False)
- forward(O1, D1, wl, n_func_enviroment, meta_data)
Captures the final ray interaction without altering its direction.
- Parameters:
O1 (torch.Tensor) – Ray origin.
D1 (torch.Tensor) – Ray direction.
wl (torch.Tensor) – Wavelength.
n_func_enviroment (Callable) – Function for surrounding medium.
meta_data (dict) – Ray tracing metadata.
- Returns:
Final ray data.
- Return type:
Tuple[torch.Tensor]
- diffinytrace.element.trace_to_detector(optical_system: SequentialOpticalSystem, sequence: List, source, detector: Detector, num_rays: int, device=device(type='cpu'), method_ray_tracing: str = 'sobol_pow2')
Traces rays through a system to a detector and returns the impact coordinates.
- Parameters:
optical_system (SequentialOpticalSystem) – Ray-tracing pipeline.
sequence (list[str]) – Ordered names of system modules.
source – Source object with a .sample() method.
detector (Detector) – Final surface to collect rays.
num_rays (int) – Number of rays to simulate.
device – Torch device (CPU/GPU).
method_ray_tracing (str) – Sampling method for source rays.
- Returns:
(input samples, weights, detector plane hits, wavelengths)
- Return type:
Tuple[torch.Tensor]
- diffinytrace.element.set_unused_params_to_zero(optical_system: SequentialOpticalSystem, sequence, source, params, num_rays=200000, method_ray_tracing='sobol')
Sets unused parameters (those with zero gradient across ray paths) to zero.
- Parameters:
optical_system (SequentialOpticalSystem) – Full system.
sequence (list) – Ordered module names.
source – Ray source.
params (list[torch.nn.Parameter] or torch.nn.Parameter) – Parameters to clean.
num_rays (int) – Ray sample count.
method_ray_tracing (str) – Sampling method.
- diffinytrace.element.get_unused_params_mask(optical_system: SequentialOpticalSystem, sequence: List[str], source, params, num_rays: int = 100000, method_ray_tracing='sobol') → List[BoolTensor]
Returns boolean masks identifying which entries of each parameter are unused in the ray tracing process.
- Parameters:
optical_system (SequentialOpticalSystem) – Full system.
sequence (list) – Ordered module names.
source – Ray source.
params (list[torch.nn.Parameter]) – Parameter list.
num_rays (int) – Number of rays to test.
method_ray_tracing (str) – Sampling method.
- Returns:
Masks of the same shape as each parameter.
- Return type:
list[torch.BoolTensor]
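The notion of "unused" here, zero gradient across the traced ray paths, is easiest to see with a locally supported basis. The sketch below swaps the library's B-spline surface for a toy 1D degree-1 B-spline (hat functions) and assumes rays sample only part of the domain; hat and unused_coeff_mask are hypothetical. Because the summed output is linear in the coefficients, its gradient with respect to c_i is the sum of basis values at the samples, so a coefficient is unused exactly when no sample falls inside its support.

```python
from typing import List


def hat(x: float, center: float, width: float) -> float:
    """Degree-1 B-spline basis on a uniform knot grid (local support)."""
    return max(0.0, 1.0 - abs(x - center) / width)


def unused_coeff_mask(centers: List[float], width: float,
                      samples: List[float]) -> List[bool]:
    """True where d/dc_i of sum_k f(x_k) is zero, with f(x) = sum_i c_i hat_i(x)."""
    grads = [sum(hat(x, c, width) for x in samples) for c in centers]
    return [g == 0.0 for g in grads]


centers = [0.0, 0.25, 0.5, 0.75, 1.0]
samples = [0.05 * k for k in range(7)]  # rays only cover x in [0, 0.3]
mask = unused_coeff_mask(centers, 0.25, samples)
print(mask)  # [False, False, False, True, True]
```

With autograd the same mask would come from comparing each parameter's gradient with zero after a backward pass over the traced rays.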
- diffinytrace.element.set_used_params_bounds_to_constant(optical_system, sequence, source, params, bounds_attr_name_new, bounds_attr_name_old='bounds', num_rays=100000, method_ray_tracing='sobol')
Locks unused parameters by copying their current value as bounds, making them constant.
- Parameters:
bounds_attr_name_new (str) – Name of the new bounds attribute to write.
bounds_attr_name_old (str) – Name of the original bounds attribute.
- class diffinytrace.element.FresnelOpticalSurface(transform, surface, aperture_radius, surface_derivative_x, surface_derivative_y, is_square=False)
Bases: OpticalSurface
- class diffinytrace.element.FresnelVirtualLensSurfaceTransmissionEnter(transform, surface, aperture_radius, n_func, surface_derivative_x, surface_derivative_y, is_square=False)
Bases: FresnelOpticalSurface
- forward(O1, D1, wl, n_func_enviroment, meta_data)
Propagates rays through the Fresnel lens entry surface.
- Parameters:
O1 (torch.Tensor) – Ray origins.
D1 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning environmental refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, wavelengths, environment function, and metadata.
- Return type:
Tuple
- class diffinytrace.element.FresnelVirtualLensSurfaceTransmissionLeave(transform, surface, aperture_radius, n_func, surface_derivative_x, surface_derivative_y, is_square=False)
Bases: FresnelOpticalSurface
- forward(O2, D2, wl, n_func_enviroment, meta_data)
Propagates rays through the Fresnel lens exit surface.
- Parameters:
O2 (torch.Tensor) – Ray origins.
D2 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning environmental refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, wavelengths, environment function, and metadata.
- Return type:
Tuple
- class diffinytrace.element.FresnelVirtualLens(transform, lens_thickness, surface1, surface2, n_func, aperture_radius, surface1_derivative_x=None, surface1_derivative_y=None, surface2_derivative_x=None, surface2_derivative_y=None, is_square=False)
Bases: OpticalElement
- get_plot_points_2D(resolution: int)
Returns a list of 2D plot points for the object.
- Parameters:
resolution (int) – The resolution for the plot points.
- Returns:
A list of 2D plot points.
- Return type:
list
- get_plot_points_3D(resolution)
Returns 3D grid of Fresnel lens surface points for visualization.
- Parameters:
resolution (int) – Grid resolution.
- Returns:
List of (x, y, z) meshgrids.
- Return type:
List[Tuple[torch.Tensor]]
- get_plotly_color_scale()
Returns color scale for plotly visualization.
- Returns:
Color scale values.
- Return type:
List
- get_plotable_childs() → List
Returns plotable child elements.
- Returns:
List of child elements.
- Return type:
List
- forward(O1: Tensor, D1: Tensor, wl: Tensor, n_func_enviroment, meta_data) → Tensor
Simulates light passing through the Fresnel lens.
- Parameters:
O1 (torch.Tensor) – Ray origin positions.
D1 (torch.Tensor) – Ray directions.
wl (torch.Tensor) – Wavelengths.
n_func_enviroment – Function returning external medium refractive index.
meta_data (dict) – Ray metadata.
- Returns:
Updated ray origins, directions, etc.
- Return type:
Tuple
- diffinytrace.element.compute_reflected_directions(D: Tensor, N: Tensor) → Tensor
Computes reflected ray directions using the reflection law.
- Parameters:
D (torch.Tensor) – Incident directions of shape (M, 3), normalized.
N (torch.Tensor) – Surface normals at points of incidence, shape (M, 3).
- Returns:
Reflected directions of shape (M, 3).
- Return type:
torch.Tensor
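The standard vector form of the law of reflection for a unit normal is D' = D − 2(D·N)N. The plain-Python sketch below implements that textbook formula on 3-tuples; reflect and dot are hypothetical names, and whether compute_reflected_directions normalizes N internally is not stated here, so the sketch assumes N is already unit length.

```python
from typing import Tuple

Vec3 = Tuple[float, float, float]


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def reflect(D: Vec3, N: Vec3) -> Vec3:
    """Law of reflection D' = D - 2 (D . N) N, assuming N has unit length."""
    c = dot(D, N)
    return tuple(d - 2.0 * c * n for d, n in zip(D, N))


# A ray travelling down and to the right hits a horizontal mirror (normal +z).
r_up = reflect((0.6, 0.0, -0.8), (0.0, 0.0, 1.0))     # -> (0.6, 0.0, 0.8)
r_down = reflect((0.6, 0.0, -0.8), (0.0, 0.0, -1.0))  # same result
```

Flipping N leaves the result unchanged, since the two sign changes in (D·N)N cancel, so the mirror does not need a consistently oriented normal.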
- diffinytrace.element.get_refracted_directions(D: Tensor, N: Tensor, n1: Tensor | float, n2: Tensor | float) → Tensor
Computes refracted ray directions using Snell’s law.
At material interfaces, the transmitted direction \(\mathbf{D'}\) is computed from the surface normal \(\mathbf{N} = \nabla s / \|\nabla s\|\) and the incident direction \(\mathbf{D}\) using Snell’s law (see [WCH22]):
\[\mathbf{D'} = \mathbf{N} \sqrt{1 - (1 - \cos^2 \psi_i) \eta^2} + \eta (\mathbf{D} - \mathbf{N} \cos \psi_i),\]
where \(\cos \psi_i = \mathbf{D} \cdot \mathbf{N}\) and \(\eta = n / n'\) is the ratio of the refractive index \(n\) of the incident medium (n1) to the refractive index \(n'\) of the transmission medium (n2).
- Parameters:
D (torch.Tensor) – Incident directions of shape (M, 3), normalized.
N (torch.Tensor) – Surface normals at points of incidence, shape (M, 3).
n1 (float or torch.Tensor) – Refractive index of the incident medium.
n2 (float or torch.Tensor) – Refractive index of the transmission medium.
- Returns:
Refracted directions of shape (M, 3).
- Return type:
torch.Tensor
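The formula above translates directly into code. As written, it requires unit vectors with N oriented so that cos ψ_i = D·N ≥ 0 (N pointing into the transmission medium), and a negative radicand signals total internal reflection; how get_refracted_directions handles those cases is not stated here, so this plain-Python sketch simply returns None for total internal reflection. refract and dot are hypothetical names.

```python
import math
from typing import Optional, Tuple

Vec3 = Tuple[float, float, float]


def dot(a: Vec3, b: Vec3) -> float:
    return sum(x * y for x, y in zip(a, b))


def refract(D: Vec3, N: Vec3, n1: float, n2: float) -> Optional[Vec3]:
    """D' = N sqrt(1 - (1 - cos^2 psi_i) eta^2) + eta (D - N cos psi_i).

    Assumes unit D and N with cos psi_i = D . N >= 0; returns None when the
    radicand is negative (total internal reflection).
    """
    eta = n1 / n2
    cos_i = dot(D, N)
    radicand = 1.0 - (1.0 - cos_i * cos_i) * eta * eta
    if radicand < 0.0:
        return None
    root = math.sqrt(radicand)
    return tuple(n * root + eta * (d - n * cos_i) for d, n in zip(D, N))


N = (0.0, 0.0, 1.0)
D = (0.5, 0.0, math.sqrt(3.0) / 2.0)         # 30 degrees from the normal
inside = refract(D, N, 1.0, 1.5)              # air -> glass
back_out = refract(inside, N, 1.5, 1.0)       # glass -> air
tir = refract((0.8, 0.0, 0.6), N, 1.5, 1.0)   # steep ray inside glass
```

The tangential component of inside is 1/3, matching n1 sin ψ_i = n2 sin ψ_t; back_out recovers D, which is the plane-parallel case of the entry/exit pair a Lens applies; and tir is None.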
- diffinytrace.element.set_unused_bspline_coeff_to_nearest(optical_system, sequence: list[str], source, bspline_surface, num_rays=100000, method_ray_tracing='sobol')
Fills only the unused B-spline coefficients with the nearest used value.
This function identifies B-spline coefficients that have no influence on the ray paths (i.e., gradients are zero), and updates only those by copying the value from the closest neighboring coefficient that is used. Used coefficients remain unchanged.
This is useful for obtaining geometry that is simpler to manufacture without tampering with the overall performance.
- Parameters:
optical_system (SequentialOpticalSystem) – The optical system used for tracing.
sequence (list[str]) – Ordered list of module names for ray propagation.
source – Ray source with a .sample() method.
bspline_surface – Surface object with a .coeff tensor.
num_rays (int, optional) – Number of rays used to detect unused coefficients. Default is 100000.
method_ray_tracing (str, optional) – Sampling method (e.g., “sobol”). Default is “sobol”.
- Raises:
RuntimeError – If all coefficients are unused, which most likely indicates insufficient ray coverage.
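The fill step can be sketched on a 2D coefficient grid once an "unused" mask is known. The sketch below assumes "nearest" means Euclidean distance in (row, column) index space with ties broken in row-major order; the library's actual distance measure and tie-breaking are not stated here, and fill_unused_with_nearest is a hypothetical helper operating on nested lists rather than the surface's .coeff tensor.

```python
from typing import List


def fill_unused_with_nearest(coeff: List[List[float]],
                             unused: List[List[bool]]) -> List[List[float]]:
    """Copy into each unused cell the value of the nearest used cell.

    Distance is Euclidean in (row, col) index space; ties go to the first used
    cell in row-major order. Used cells keep their values; the input is not modified.
    """
    used = [(i, j) for i, row in enumerate(unused)
            for j, flag in enumerate(row) if not flag]
    if not used:
        raise RuntimeError("All coefficients are unused; increase ray coverage.")
    out = [row[:] for row in coeff]
    for i, row in enumerate(unused):
        for j, flag in enumerate(row):
            if flag:
                ni, nj = min(used, key=lambda p: (p[0] - i) ** 2 + (p[1] - j) ** 2)
                out[i][j] = coeff[ni][nj]
    return out


coeff = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
ring_unused = [[True, True, True], [True, False, True], [True, True, True]]
filled = fill_unused_with_nearest(coeff, ring_unused)  # every entry becomes 5.0
```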