Element

diffinytrace.element.is_valid_square_circle(transform: Transform, O: Tensor, aperture_radius: float, is_square: bool) Tensor[source]

Checks whether points lie within a circular or square aperture after transformation.

Parameters:
  • transform (Transform) – Transformation object to convert global to local coordinates.

  • O (torch.Tensor) – Points in global coordinates of shape (N, 3).

  • aperture_radius (float or torch.Tensor) – Radius of the circular aperture, or half-width of the square aperture.

  • is_square (bool) – If True, aperture is square; if False, circular.

Returns:

Boolean tensor of shape (N,) indicating whether each point lies within the aperture.

Return type:

torch.Tensor

Note

For a square aperture, checks \(|x| < r\) and \(|y| < r\). For a circular aperture, checks \(\sqrt{x^2 + y^2} < r\).
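
The two tests in the note can be sketched in plain Python for a single point. The helper `inside_aperture` is hypothetical and works on floats in local coordinates; the library function operates on tensors and applies the transform first.

```python
import math


def inside_aperture(x: float, y: float, r: float, is_square: bool) -> bool:
    """Return True if the local point (x, y) lies strictly inside the aperture."""
    if is_square:
        # Square aperture: r acts as the half-width along both axes.
        return abs(x) < r and abs(y) < r
    # Circular aperture: r is the radius.
    return math.hypot(x, y) < r


# A point near the corner is inside the square but outside the inscribed circle.
in_square = inside_aperture(0.9, 0.9, 1.0, is_square=True)
in_circle = inside_aperture(0.9, 0.9, 1.0, is_square=False)
```

Both comparisons are strict, so a point exactly on the boundary counts as outside.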

class diffinytrace.element.OpticalSystem(modules_dict: Dict)[source]

Bases: Module, Plotable

Base class for optical systems composed of multiple optical modules.

This class serves as a container for modules such as lenses, mirrors, and detectors. It supports visualization and modular organization.

modules_dict

Dictionary of named optical modules.

Type:

nn.ModuleDict

forward()[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_plotable_childs()[source]

Returns a list of all plotable child objects of this object. Each child is represented as a list containing the child object and its name.

get_plot_points_2D(resolution: int)[source]

Returns a list of 2D plot points for the object.

Parameters:

resolution (int) – The resolution for the plot points.

Returns:

A list of 2D plot points.

Return type:

list

get_plot_points_3D(resolution: int)[source]

Returns a list of 3D plot points for the object.

Parameters:

resolution (int) – The resolution for the plot points.

Returns:

A list of 3D plot points.

Return type:

list

class diffinytrace.element.SequentialOpticalSystem(modules_dict: Dict, n_func_enviroment=RefractiveIndex())[source]

Bases: OpticalSystem

Optical system that processes rays in a defined sequence.

Useful for simulating light propagation through a sequence of elements, e.g., source → lens → detector.

n_func_enviroment

Function returning refractive index of the surrounding medium.

Type:

Callable

forward(x, mapping_sequence: List[str])[source]

Propagates rays through the defined sequence of modules.

Parameters:
  • x (Any) – Input rays or sampling data.

  • mapping_sequence (list[str]) – Ordered list of module names defining propagation sequence.

Returns:

Output after final module in the sequence.

Return type:

Any
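
The idea of propagating through modules in a named order can be sketched without the library. Everything below is an assumption made for illustration: `run_sequence` and the three stand-in modules are hypothetical, and real modules exchange ray tensors and metadata rather than plain lists.

```python
from typing import Any, Callable, Dict, List


def run_sequence(modules: Dict[str, Callable[[Any], Any]],
                 x: Any,
                 mapping_sequence: List[str]) -> Any:
    """Apply the named modules to x in the order given by mapping_sequence."""
    for name in mapping_sequence:
        # The output of one module becomes the input of the next.
        x = modules[name](x)
    return x


# Hypothetical stand-ins for a source, a lens, and a detector.
modules = {
    "source": lambda n: [float(i) for i in range(n)],
    "lens": lambda rays: [2.0 * r for r in rays],
    "detector": lambda rays: sum(rays),
}
result = run_sequence(modules, 4, ["source", "lens", "detector"])
```

Because the order comes from the name list rather than from the dictionary, the same container can serve several propagation paths, for example one that skips the lens.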

class diffinytrace.element.OpticalElement(fill_color='white', outline_color='black', is_volume=False)[source]

Bases: PhysicalObject, Plotable

Abstract base class for optical elements like lenses, mirrors, and detectors.

Provides interface for geometric transformation and ray propagation.

forward(O2: Tensor, D2: Tensor, wl: Tensor, n_func_enviroment, meta_data)[source]

Propagates rays through the optical element.

Parameters:
  • O2 (torch.Tensor) – Ray origins.

  • D2 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment (Callable) – Function returning environmental refractive index.

  • meta_data (dict) – Dictionary with path length and validity information.

Raises:

NotImplementedError – Must be overridden by subclasses.

get_transform()[source]

Returns the transformation associated with the surface.

Returns:

The local-to-global transformation object.

Return type:

Transform

class diffinytrace.element.OpticalSurface(transform: Transform, surface, aperture_radius: float, is_square: bool = False, fill_color: str = 'white', outline_color: str = 'black')[source]

Bases: OpticalElement, PhysicalSurface

Represents a surface in 3D space with a defined aperture and transformation.

Supports both square and circular apertures, and provides methods for parametric sampling, CAD conversion, ray intersection, and plotting.

surface

Object with a method explicit(parametric_pos) returning z-values.

Type:

object

aperture_radius

Radius of the circular aperture, or half-width of the square aperture.

Type:

float

is_square

Whether the aperture is square-shaped.

Type:

bool

transform

Local-to-global transformation.

Type:

Transform

integrator

Integration object (Disc or Cube) for parametric sampling.

Type:

Integrator

get_constraint_funs_leq_zero()[source]

Returns constraint functions used for integration and optimization over the surface.

Returns:

List of functions f(param_pos) <= 0 indicating valid parametric regions.

Return type:

list[Callable]

Raises:

RuntimeError – If is_square is True (not yet implemented).
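
A constraint of the form f(param_pos) <= 0 for a circular aperture could look like the sketch below. The names `circular_constraints` and `is_feasible` are hypothetical, and the quadratic form is an assumption; the library may use a different but equivalent expression.

```python
from typing import Callable, List, Tuple

Point = Tuple[float, float]


def circular_constraints(aperture_radius: float) -> List[Callable[[Point], float]]:
    """Return constraint functions f with f(p) <= 0 exactly on the closed disc."""
    def inside_disc(p: Point) -> float:
        x, y = p
        # Negative inside the disc, zero on the rim, positive outside.
        return x * x + y * y - aperture_radius ** 2
    return [inside_disc]


def is_feasible(p: Point, constraints: List[Callable[[Point], float]]) -> bool:
    """A point is valid when every constraint is satisfied."""
    return all(f(p) <= 0.0 for f in constraints)


constraints = circular_constraints(1.0)
```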

get_plot_points_2D(resolution: int) List[Tuple[Tensor]][source]

Returns 2D slices through the surface (z-y plane) for plotting.

Parameters:

resolution (int) – Number of sample points along the y-axis.

Returns:

List of (z, y) coordinate tuples.

Return type:

List[Tuple[torch.Tensor]]

get_plot_points_3D(resolution: int) List[Tuple[Tensor]][source]

Returns 3D grid of surface points for visualization.

Parameters:

resolution (int) – Grid resolution in x and y.

Returns:

List of (x, y, z) meshgrids as torch tensors.

Return type:

List[Tuple[torch.Tensor]]

get_CAD_points(resolution: int) List[Tuple[Tensor]][source]

Generates a 3D surface point grid for CAD conversion.

Parameters:

resolution (int) – Sampling resolution.

Returns:

(x, y, z) coordinate grids for CAD modeling.

Return type:

Tuple[torch.Tensor]

get_CAD_face(resolution: int, tol: float = 0.001, smoothing=None, minDeg: int = 1, maxDeg: int = 3)[source]

Converts the surface into a CAD face using B-spline approximation.

Parameters:
  • resolution (int) – Sampling resolution.

  • tol (float, optional) – Approximation tolerance. Defaults to 0.001.

  • smoothing (Optional[int]) – Smoothing value for fitting.

  • minDeg (int) – Minimum degree of the spline.

  • maxDeg (int) – Maximum degree of the spline.

Returns:

CAD face object.

Return type:

cadquery.Face

parametric_sample(num_points: int, method: str = 'sobol') tuple[Tensor, Tensor][source]

Samples parametric positions on the aperture using the integrator.

Parameters:
  • num_points (int) – Number of sample points.

  • method (str) – Sampling method. Options: “sobol”, “monte_carlo”, “midpoint”, etc.

Returns:

Sampled positions and integration weights.

Return type:

Tuple[torch.Tensor, torch.Tensor]
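
To illustrate what "positions and integration weights" means on a circular aperture, here is a minimal Monte Carlo sketch in plain Python. `sample_disc` is a hypothetical helper, not part of diffinytrace, and the equal weights (which sum to the disc area) are an assumption about how such weights may be defined.

```python
import math
import random
from typing import List, Tuple


def sample_disc(num_points: int, radius: float,
                seed: int = 0) -> Tuple[List[Tuple[float, float]], List[float]]:
    """Draw uniform samples on a disc and return them with quadrature weights."""
    rng = random.Random(seed)
    points = []
    for _ in range(num_points):
        # Taking the square root makes the samples uniform in area.
        rho = radius * math.sqrt(rng.random())
        phi = 2.0 * math.pi * rng.random()
        points.append((rho * math.cos(phi), rho * math.sin(phi)))
    # Equal weights that sum to the disc area.
    weight = math.pi * radius ** 2 / num_points
    return points, [weight] * num_points


points, weights = sample_disc(20000, 2.0)
# Estimate the integral of x^2 + y^2 over the disc; the exact value is pi * R^4 / 2.
estimate = sum(w * (x * x + y * y) for (x, y), w in zip(points, weights))
```

With weights defined this way, a weighted sum of function values approximates the integral of that function over the aperture.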

parametric_surface(parametric_pos: Tensor) Tensor[source]

Maps 2D parametric coordinates to 3D global coordinates using the surface height and transform.

Parameters:

parametric_pos (torch.Tensor) – 2D parametric positions of shape (N, 2).

Returns:

3D positions of shape (N, 3) in global space.

Return type:

torch.Tensor

Raises:

RuntimeError – If input does not have shape […, 2].
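
The mapping from a 2D parametric position to a 3D point can be sketched as follows. This is a simplified illustration: the surface height is given by a hypothetical `paraboloid` function, and the transform is assumed to be a pure translation, whereas a real Transform may also rotate.

```python
from typing import Callable, Tuple


def parametric_surface(p: Tuple[float, float],
                       sag: Callable[[float, float], float],
                       translation: Tuple[float, float, float]) -> Tuple[float, float, float]:
    """Lift (x, y) to the local point (x, y, sag(x, y)), then move it to global space."""
    x, y = p
    z = sag(x, y)  # explicit surface height in local coordinates
    tx, ty, tz = translation
    return (x + tx, y + ty, z + tz)


def paraboloid(x: float, y: float, curvature: float = 0.01) -> float:
    """Hypothetical explicit surface: z = c * (x^2 + y^2) / 2."""
    return 0.5 * curvature * (x * x + y * y)


point = parametric_surface((3.0, 4.0), paraboloid, (0.0, 0.0, 100.0))
```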

get_surface_and_normal_func_with_params()[source]

Constructs a callable for surface position and normal computation with parameter tracking.

Returns:

A tuple (func, params): func is a callable that computes (position, normal), and params is the list of parameters to be optimized.

Return type:

Tuple[Callable, List]

get_ray_intersect_length(O, D) Tensor[source]

Computes the distance along each ray from its origin to its intersection with the surface.

Parameters:
  • O (torch.Tensor) – Ray origins of shape (N, 3).

  • D (torch.Tensor) – Ray directions of shape (N, 3).

Returns:

Intersection distances t such that O + t*D lies on the surface.

Return type:

torch.Tensor
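
One way to find t with O + t*D on an explicit surface z = s(x, y) is Newton's method on the residual along the ray. The sketch below is an assumption about a workable approach, not a description of the library's solver; `ray_intersect_length` here is a standalone function on plain tuples, and it requires a nonzero z-component of D.

```python
from typing import Callable, Tuple

Vec3 = Tuple[float, float, float]


def ray_intersect_length(O: Vec3, D: Vec3,
                         sag: Callable[[float, float], float],
                         iterations: int = 20, eps: float = 1e-6) -> float:
    """Solve O_z + t*D_z = sag(O_x + t*D_x, O_y + t*D_y) for t with Newton steps."""
    def residual(t: float) -> float:
        x = O[0] + t * D[0]
        y = O[1] + t * D[1]
        z = O[2] + t * D[2]
        return z - sag(x, y)

    # Initial guess: intersection with the plane z = 0 (needs D_z != 0).
    t = -O[2] / D[2]
    for _ in range(iterations):
        # A central finite difference stands in for an analytic or autograd derivative.
        slope = (residual(t + eps) - residual(t - eps)) / (2.0 * eps)
        t -= residual(t) / slope
    return t


def paraboloid(x: float, y: float) -> float:
    """Hypothetical explicit surface used only for this example."""
    return 0.005 * (x * x + y * y)


t_axis = ray_intersect_length((1.0, 0.0, -10.0), (0.0, 0.0, 1.0), paraboloid)
t_tilt = ray_intersect_length((0.0, 0.0, -8.0), (0.6, 0.0, 0.8), paraboloid)
```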

get_new_is_valid(O, valid) Tensor[source]

Updates a boolean mask indicating which rays are still valid after hitting the aperture.

Parameters:
  • O (torch.Tensor) – Ray intersection points.

  • valid (torch.Tensor) – Previous boolean validity mask.

Returns:

Updated validity mask.

Return type:

torch.Tensor

get_transform() Transform[source]

Returns the transformation associated with the surface.

Returns:

The local-to-global transformation object.

Return type:

Transform

class diffinytrace.element.LensSurfaceTransmissionEnter(transform: Transform, surface, aperture_radius: float, n_func, is_square: bool = False)[source]

Bases: OpticalSurface

forward(O1, D1, wl, n_func_enviroment, meta_data)[source]

Propagates rays through the lens entry surface.

Parameters:
  • O1 (torch.Tensor) – Ray origins.

  • D1 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning environmental refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, wavelengths, environment function, and metadata.

Return type:

Tuple

class diffinytrace.element.LensSurfaceTransmissionLeave(transform: Transform, surface, aperture_radius: float, n_func, is_square: bool = False)[source]

Bases: OpticalSurface

forward(O2, D2, wl, n_func_enviroment, meta_data)[source]

Propagates rays through the lens exit surface.

Parameters:
  • O2 (torch.Tensor) – Ray origins.

  • D2 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning environmental refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, wavelengths, environment function, and metadata.

Return type:

Tuple

class diffinytrace.element.LensSurfaceSide(surface1: PhysicalSurface, surface2: PhysicalSurface, aperture_radius, is_square: bool)[source]

Bases: PhysicalSurface, Plotable

Non-optical surface connecting two curved lens surfaces for visualization.

Used to render the full 3D body of the lens.

surface1

First lens surface.

Type:

PhysicalSurface

surface2

Second lens surface.

Type:

PhysicalSurface

aperture_radius

Radius or half-width of aperture.

Type:

float

is_square

Whether aperture is square.

Type:

bool

parametric_sample(num_points: int, method: str = 'sobol')[source]

Samples parametric positions on the lens side surface.

Parameters:
  • num_points (int) – Number of sample points.

  • method (str) – Sampling method (“sobol”, “monte_carlo”, “midpoint”).

Returns:

Sampled positions and integration weights.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:

RuntimeError – If unsupported method is provided.

parametric_surface(parametric_pos: Tensor) Tensor[source]

Maps parametric coordinates to 3D global coordinates for the lens side.

Parameters:

parametric_pos (torch.Tensor) – Parametric positions of shape (N, 2).

Returns:

3D positions of shape (N, 3).

Return type:

torch.Tensor

get_plot_points_2D(resolution: int) List[Tuple[Tensor]][source]

Returns 2D slices through the surface (z-y plane) for plotting.

Parameters:

resolution (int) – Number of sample points along the y-axis.

Returns:

List of (z, y) coordinate tuples.

Return type:

List[Tuple[torch.Tensor]]

get_plot_points_3D(resolution: int) List[Tuple[Tensor]][source]

Returns 3D grid of surface points for visualization.

Parameters:

resolution (int) – Grid resolution in x and y.

Returns:

List of (x, y, z) meshgrids as torch tensors.

Return type:

List[Tuple[torch.Tensor]]

get_plotly_color_scale()[source]

Returns color scale for plotly visualization.

Returns:

Color scale values.

Return type:

List

class diffinytrace.element.Lens(transform: Transform, lens_thickness: float, surface1, surface2, n_func, aperture_radius: float, is_square=False)[source]

Bases: OpticalElement

Represents a transmissive lens consisting of two refractive surfaces.

The lens is modeled as a sequence of:
  • Entry surface (refraction from external medium into the lens)

  • Exit surface (refraction from lens into external medium)

  • Side surface (purely for visualization)

In our implementation, a lens consists of two explicit surfaces, a transformation matrix \(M\), a lens thickness, an aperture radius, and a material. The optional keyword is_square selects between a round and a square lens at initialization; if it is not specified, the lens is round.

Example

Below is an example of initializing a square lens:

>>> import diffinytrace as dit
>>> aperture_half = 30.
>>> lens_thickness = 8.
>>> material = dit.materials["NBK7"]
>>> transform = dit.transforms.Identity()
>>> bspline = dit.Bspline(aperture_half, [3, 3], [8, 8])
>>> plane = dit.Plane()
>>> lens = dit.Lens(transform, lens_thickness,
...                 bspline, plane,
...                 material, aperture_half, is_square=True)

n_func

Function mapping wavelength to refractive index of the lens material.

Type:

Callable

_transform1

Transform for the first surface.

Type:

Transform

_transform2

Transform for the second surface.

Type:

Transform

lens_thickness

Learnable thickness of the lens.

Type:

torch.nn.Parameter

surface1

Entry surface.

Type:

LensSurfaceTransmissionEnter

surface2

Exit surface.

Type:

LensSurfaceTransmissionLeave

lens_surface_side

Side surface (for 3D rendering).

Type:

LensSurfaceSide

aperture_radius

Radius (or half-width) of aperture.

Type:

float

is_square

Whether the aperture is square.

Type:

bool

get_plot_points_2D(resolution: int) List[Tuple[Tensor]][source]

Returns 2D slices through the lens for plotting.

Parameters:

resolution (int) – Number of sample points.

Returns:

List of (z, y) coordinate tuples.

Return type:

List[Tuple[torch.Tensor]]

get_plot_points_3D(resolution: int) List[Tuple[Tensor]][source]

Returns 3D grid of lens surface points for visualization.

Parameters:

resolution (int) – Grid resolution.

Returns:

List of (x, y, z) meshgrids.

Return type:

List[Tuple[torch.Tensor]]

get_plotly_color_scale() List[source]

Returns color scale for plotly visualization.

Returns:

Color scale values.

Return type:

List

get_plotable_childs() List[source]

Returns plotable child elements.

Returns:

List of child elements.

Return type:

List

forward(O1: Tensor, D1: Tensor, wl: Tensor, n_func_enviroment, meta_data)[source]

Simulates light passing through the lens.

Parameters:
  • O1 (torch.Tensor) – Ray origin positions.

  • D1 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment (Callable) – Function returning external medium refractive index.

  • meta_data (dict) – Ray metadata (PL, OPL, paths, valid).

Returns:

Updated ray origins, directions, etc.

Return type:

Tuple[torch.Tensor]

get_transform()[source]

Returns the transformation of the lens exit surface.

Returns:

The transformation object.

Return type:

Transform

class diffinytrace.element.Mirror(transform, surface, aperture_radius, is_square=False)[source]

Bases: OpticalSurface

Reflective optical element that reflects rays according to the law of reflection.

Visualization is colored in a warm gold tone.

Inherits:

OpticalSurface: Full support for surface transformation and intersection.

forward(O1, D1, wl, n_func_enviroment, meta_data)[source]

Propagates rays through the mirror surface.

Parameters:
  • O1 (torch.Tensor) – Ray origins.

  • D1 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning environmental refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, wavelengths, environment function, and metadata.

Return type:

Tuple

class diffinytrace.element.Detector(transform, surface, aperture_radius, is_square=True)[source]

Bases: OpticalSurface

Represents a terminal optical element that collects ray data.

A detector consists of an explicit surface, a transformation matrix \(M\), and an aperture radius. The detector class represents a target surface used to track the rays that hit it. The optional keyword is_square selects between a round and a square detector at initialization; if it is not specified, the detector is square.

Example

Below is an example of how to initialize a detector:

>>> import diffinytrace as dit
>>> aperture_half = 30.
>>> transform = dit.transforms.Identity()
>>> plane = dit.Plane()
>>> detector = dit.Detector(transform, plane,
...                         aperture_half, is_square=False)

forward(O1, D1, wl, n_func_enviroment, meta_data)[source]

Captures the final ray interaction without altering its direction.

Parameters:
  • O1 (torch.Tensor) – Ray origin.

  • D1 (torch.Tensor) – Ray direction.

  • wl (torch.Tensor) – Wavelength.

  • n_func_enviroment (Callable) – Function for surrounding medium.

  • meta_data (dict) – Ray tracing metadata.

Returns:

Final ray data.

Return type:

Tuple[torch.Tensor]

diffinytrace.element.trace_to_detector(optical_system: SequentialOpticalSystem, sequence: List, source, detector: Detector, num_rays: int, device=device(type='cpu'), method_ray_tracing: str = 'sobol_pow2')[source]

Traces rays through a system to a detector and returns the impact coordinates.

Parameters:
  • optical_system (SequentialOpticalSystem) – Ray-tracing pipeline.

  • sequence (list[str]) – Ordered names of system modules.

  • source – Source object with .sample() method.

  • detector (Detector) – Final surface to collect rays.

  • num_rays (int) – Number of rays to simulate.

  • device – Torch device (CPU/GPU).

  • method_ray_tracing (str) – Sampling method for source rays.

Returns:

(input samples, weights, detector plane hits, wavelengths)

Return type:

Tuple[torch.Tensor]

diffinytrace.element.set_unused_params_to_zero(optical_system: SequentialOpticalSystem, sequence, source, params, num_rays=200000, method_ray_tracing='sobol')[source]

Sets unused parameters (those with zero gradient across ray paths) to zero.

Parameters:
  • optical_system (SequentialOpticalSystem) – Full system.

  • sequence (list) – Ordered module names.

  • source – Ray source.

  • params (list[torch.nn.Parameter] or torch.nn.Parameter) – Parameters to clean.

  • num_rays (int) – Ray sample count.

  • method_ray_tracing (str) – Sampling method.

diffinytrace.element.get_unused_params_mask(optical_system: SequentialOpticalSystem, sequence: List[str], source, params, num_rays: int = 100000, method_ray_tracing='sobol') List[BoolTensor][source]

Returns a boolean mask identifying which parameters are unused in the ray tracing process.

Parameters:
  • optical_system (SequentialOpticalSystem) – Full system.

  • sequence (list) – Ordered module names.

  • source – Ray source.

  • params (list[torch.nn.Parameter]) – Parameter list.

  • num_rays (int) – Number of rays to test.

  • method_ray_tracing (str) – Sampling method.

Returns:

One boolean mask per parameter, each with the same shape as its parameter.

Return type:

list[torch.BoolTensor]
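
The notion of an "unused" parameter, one whose change never affects the traced result, can be illustrated with a toy loss. This sketch swaps in finite differences for the gradient test and uses a hypothetical `toy_loss` in place of ray tracing; it is not how the library computes its masks.

```python
from typing import Callable, List


def unused_params_mask(loss: Callable[[List[float]], float],
                       params: List[float],
                       step: float = 1e-3) -> List[bool]:
    """Mark a parameter as unused when perturbing it leaves the loss unchanged."""
    mask = []
    for k in range(len(params)):
        plus = list(params)
        minus = list(params)
        plus[k] += step
        minus[k] -= step
        # Caveat: a symmetric dependence (e.g. p**2 at p = 0) would be misclassified.
        mask.append(loss(plus) == loss(minus))
    return mask


def toy_loss(params: List[float]) -> float:
    """Hypothetical loss in which only the first three coefficients are touched."""
    return (params[0] + 2.0 * params[1]) ** 2 + 3.0 * params[2]


mask = unused_params_mask(toy_loss, [1.0, 1.0, 1.0, 1.0, 1.0])
```

In this sketch True marks an unused entry; that polarity is a choice made for the example.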

diffinytrace.element.set_used_params_bounds_to_constant(optical_system, sequence, source, params, bounds_attr_name_new, bounds_attr_name_old='bounds', num_rays=100000, method_ray_tracing='sobol')[source]

Locks unused parameters by copying their current value as bounds, making them constant.

Parameters:
  • bounds_attr_name_new (str) – Name of the new bounds attribute to write.

  • bounds_attr_name_old (str) – Name of the original bounds attribute.

class diffinytrace.element.FresnelOpticalSurface(transform, surface, aperture_radius, surface_derivative_x, surface_derivative_y, is_square=False)[source]

Bases: OpticalSurface

get_virtual_normals(O)[source]

Computes virtual surface normals for Fresnel surfaces.

Parameters:

O (torch.Tensor) – Positions.

Returns:

Virtual normals.

Return type:

torch.Tensor

class diffinytrace.element.FresnelVirtualLensSurfaceTransmissionEnter(transform, surface, aperture_radius, n_func, surface_derivative_x, surface_derivative_y, is_square=False)[source]

Bases: FresnelOpticalSurface

forward(O1, D1, wl, n_func_enviroment, meta_data)[source]

Propagates rays through the Fresnel lens entry surface.

Parameters:
  • O1 (torch.Tensor) – Ray origins.

  • D1 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning environmental refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, wavelengths, environment function, and metadata.

Return type:

Tuple

class diffinytrace.element.FresnelVirtualLensSurfaceTransmissionLeave(transform, surface, aperture_radius, n_func, surface_derivative_x, surface_derivative_y, is_square=False)[source]

Bases: FresnelOpticalSurface

forward(O2, D2, wl, n_func_enviroment, meta_data)[source]

Propagates rays through the Fresnel lens exit surface.

Parameters:
  • O2 (torch.Tensor) – Ray origins.

  • D2 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning environmental refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, wavelengths, environment function, and metadata.

Return type:

Tuple

class diffinytrace.element.FresnelVirtualLens(transform, lens_thickness, surface1, surface2, n_func, aperture_radius, surface1_derivative_x=None, surface1_derivative_y=None, surface2_derivative_x=None, surface2_derivative_y=None, is_square=False)[source]

Bases: OpticalElement

get_plot_points_2D(resolution: int)[source]

Returns a list of 2D plot points for the object.

Parameters:

resolution (int) – The resolution for the plot points.

Returns:

A list of 2D plot points.

Return type:

list

get_plot_points_3D(resolution)[source]

Returns 3D grid of Fresnel lens surface points for visualization.

Parameters:

resolution (int) – Grid resolution.

Returns:

List of (x, y, z) meshgrids.

Return type:

List[Tuple[torch.Tensor]]

get_plotly_color_scale()[source]

Returns color scale for plotly visualization.

Returns:

Color scale values.

Return type:

List

get_plotable_childs() List[source]

Returns plotable child elements.

Returns:

List of child elements.

Return type:

List

forward(O1: Tensor, D1: Tensor, wl: Tensor, n_func_enviroment, meta_data) Tensor[source]

Simulates light passing through the Fresnel lens.

Parameters:
  • O1 (torch.Tensor) – Ray origin positions.

  • D1 (torch.Tensor) – Ray directions.

  • wl (torch.Tensor) – Wavelengths.

  • n_func_enviroment – Function returning external medium refractive index.

  • meta_data (dict) – Ray metadata.

Returns:

Updated ray origins, directions, etc.

Return type:

Tuple

get_transform() Transform[source]

Returns the transformation of the Fresnel lens exit surface.

Returns:

The transformation object.

Return type:

Transform

diffinytrace.element.compute_reflected_directions(D: Tensor, N: Tensor) Tensor[source]

Computes reflected ray directions using the reflection law.

Parameters:
  • D (torch.Tensor) – Incident directions of shape (M, 3), normalized.

  • N (torch.Tensor) – Surface normals at points of incidence, shape (M, 3).

Returns:

Reflected directions of shape (M, 3).

Return type:

torch.Tensor
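
For a single ray, the law of reflection is commonly written as D' = D - 2 (D · N) N. The sketch below applies that formula to plain tuples; `reflect` is a hypothetical helper, and it assumes N has unit length.

```python
from typing import Tuple

Vec3 = Tuple[float, float, float]


def reflect(D: Vec3, N: Vec3) -> Vec3:
    """Mirror the direction D about the plane perpendicular to the unit normal N."""
    d_dot_n = sum(d * n for d, n in zip(D, N))
    # Flip the component of D along N and keep the tangential part.
    return tuple(d - 2.0 * d_dot_n * n for d, n in zip(D, N))


reflected = reflect((0.6, 0.0, 0.8), (0.0, 0.0, 1.0))
```

The result does not depend on the sign of N, since N appears twice in the formula.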

diffinytrace.element.get_refracted_directions(D: Tensor, N: Tensor, n1: Tensor | float, n2: Tensor | float) Tensor[source]

Computes refracted ray directions using Snell’s law.

At material interfaces, the transmitted direction \(\mathbf{D'}\) is computed based on the surface normal \(\mathbf{N} = \nabla s / \|\nabla s\|\) and the incident direction \(\mathbf{D}\), using Snell’s law (see [WCH22]):

\[\mathbf{D'} = \mathbf{N} \sqrt{1 - (1 - \cos^2 \psi_i) \eta^2} + \eta (\mathbf{D} - \mathbf{N} \cos \psi_i),\]

where \(\cos \psi_i = \mathbf{D} \cdot \mathbf{N}\) and \(\eta = n / n'\) is the ratio of the refractive index of the incident medium (n1) to that of the transmission medium (n2).

Parameters:
  • D (torch.Tensor) – Incident directions of shape (M, 3), normalized.

  • N (torch.Tensor) – Surface normals at points of incidence, shape (M, 3).

  • n1 (float or torch.Tensor) – Refractive index of the incident medium.

  • n2 (float or torch.Tensor) – Refractive index of the transmission medium.

Returns:

Refracted directions of shape (M, 3).

Return type:

torch.Tensor
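
The formula above can be checked for a single ray in plain Python. This sketch makes three assumptions: D and N are unit vectors, N is oriented so that D · N > 0, and total internal reflection (negative radicand) is reported by returning None. The helper name `refract` is hypothetical, and how the library handles that last case is not stated here.

```python
import math
from typing import Optional, Tuple

Vec3 = Tuple[float, float, float]


def refract(D: Vec3, N: Vec3, n1: float, n2: float) -> Optional[Vec3]:
    """Evaluate D' = N*sqrt(1 - (1 - cos^2)*eta^2) + eta*(D - N*cos) for one ray."""
    eta = n1 / n2
    cos_i = sum(d * n for d, n in zip(D, N))
    radicand = 1.0 - (1.0 - cos_i ** 2) * eta ** 2
    if radicand < 0.0:
        return None  # total internal reflection: no transmitted ray
    root = math.sqrt(radicand)
    return tuple(n * root + eta * (d - n * cos_i) for d, n in zip(D, N))


# Air to glass at 30 degrees incidence: Snell's law predicts sin(theta') = 0.5 / 1.5.
D_in = (0.5, 0.0, math.sqrt(3.0) / 2.0)
D_out = refract(D_in, (0.0, 0.0, 1.0), 1.0, 1.5)

# Glass to air at 60 degrees exceeds the critical angle.
blocked = refract((math.sqrt(3.0) / 2.0, 0.0, 0.5), (0.0, 0.0, 1.0), 1.5, 1.0)
```

The tangential component of the output equals the sine of the refraction angle, which makes the result easy to compare against Snell's law.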

diffinytrace.element.set_unused_bspline_coeff_to_nearest(optical_system, sequence: list[str], source, bspline_surface, num_rays=100000, method_ray_tracing='sobol')[source]

Fills only the unused B-spline coefficients with the nearest used value.

This function identifies B-spline coefficients that have no influence on the ray paths (i.e., gradients are zero), and updates only those by copying the value from the closest neighboring coefficient that is used. Used coefficients remain unchanged.

This is useful for obtaining geometry that is simple to manufacture without tampering with the overall performance.

Parameters:
  • optical_system (SequentialOpticalSystem) – The optical system used for tracing.

  • sequence (list[str]) – Ordered list of module names for ray propagation.

  • source – Ray source with a .sample() method.

  • bspline_surface – Surface object with a .coeff tensor.

  • num_rays (int, optional) – Number of rays used to detect unused coefficients. Default is 100000.

  • method_ray_tracing (str, optional) – Sampling method (e.g., “sobol”). Default is “sobol”.

Raises:

RuntimeError – If all coefficients are unused — likely due to insufficient ray coverage.
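
The nearest-used fill can be sketched on a small 2D grid, given a mask of which coefficients are used. The helper `fill_unused_with_nearest` is hypothetical, works on nested lists rather than a .coeff tensor, and measures distance in grid-index space, which is an assumption.

```python
from typing import List


def fill_unused_with_nearest(coeff: List[List[float]],
                             used: List[List[bool]]) -> List[List[float]]:
    """Copy the nearest used coefficient into every unused cell; keep used cells as is."""
    rows, cols = len(coeff), len(coeff[0])
    used_cells = [(i, j) for i in range(rows) for j in range(cols) if used[i][j]]
    if not used_cells:
        raise RuntimeError("all coefficients are unused")
    filled = [row[:] for row in coeff]
    for i in range(rows):
        for j in range(cols):
            if not used[i][j]:
                # Nearest used cell by squared distance between grid indices.
                ni, nj = min(used_cells,
                             key=lambda c: (c[0] - i) ** 2 + (c[1] - j) ** 2)
                filled[i][j] = coeff[ni][nj]
    return filled


coeff = [[9.0, 9.0, 9.0],
         [9.0, 1.0, 2.0],
         [9.0, 3.0, 4.0]]
used = [[False, False, False],
        [False, True, True],
        [False, True, True]]
filled = fill_unused_with_nearest(coeff, used)
```

Values are always read from the original grid, so a filled cell never propagates into another unused cell.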