Source code for diffinytrace.intersection

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

# Public API of this module, grouped by role.
__all__ = [
    # Composable surface building blocks
    "SemiFunctionalModule",
    "cat_semi_functionals",
    "get_functional_param_args",
    # Surface value + normal evaluators
    "construct_surface_and_normal_func",
    "construct_surface_and_normal_func_with_params",
    # Differentiable ray/surface intersection
    "CustomAutogradRule_t",
    "get_ray_intersection_length",
]

import torch
import torch.nn as nn
from .utils.autograd import grad
from .config import get_max_iterations,get_tolerance,get_damping_factor,get_show_iteration_count
from typing import List,Tuple,Callable,Optional

class SemiFunctionalModule(nn.Module):
    r"""
    Abstract base class for semi-functional surface modules.

    These modules define a static method `functional` that computes a functional
    transformation on inputs and parameters, and a method to list their
    functional parameters for optimization purposes.

    Subclasses must override both :meth:`functional` and
    :meth:`get_functional_param_args`; the base implementations raise
    ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def functional(O: torch.Tensor, *params):
        r"""
        This method provides the implicit surface description. It is a static method.

        Diffinytrace constructs a function `s(R, p)` on the fly to describe the
        surface, allowing better control over derivative calculations.

        Args:
            O (torch.Tensor): Input positions.
            *params: The module's functional parameters, in the order returned by
                :meth:`get_functional_param_args`.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("functional not implemented")

    def get_functional_param_args(self):
        r"""
        Return the parameters that :meth:`functional` consumes, in call order.

        Composition helpers use the length of this list to slice a flat
        parameter list between modules.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # The message names this method; it previously referred to a
        # nonexistent "params_list".
        raise NotImplementedError("get_functional_param_args not implemented")
def cat_semi_functionals(functional_modules: List[SemiFunctionalModule]) -> Callable:
    r"""
    Chains a list of `SemiFunctionalModule`s into a single composite function.

    Each module's `functional()` method is applied in sequence, receiving the
    output of the previous module and its own consecutive slice of the flat
    parameter list. The slice length is the length of that module's
    `get_functional_param_args()`. Any parameters left over after the last
    module are ignored. An empty module list yields the identity on `O`.

    Args:
        functional_modules (list[SemiFunctionalModule]): List of functional modules.

    Returns:
        Callable: A function f(O, *params) that applies all modules in sequence.
    """
    # Resolve every stage (its functional and its parameter count) once, here.
    # The previous recursive version rebuilt the remaining chain on every call
    # of the composite, which runs on every Newton iteration.
    stages = [
        (module.functional, len(module.get_functional_param_args()))
        for module in functional_modules
    ]

    def fun_out(O, *params):
        offset = 0
        for functional, num_params in stages:
            O = functional(O, *params[offset:offset + num_params])
            offset += num_params
        return O

    return fun_out
def get_functional_param_args(semi_functional_module_list: List[SemiFunctionalModule]) -> List:
    r"""
    Flattens the functional parameters of several semi-functional modules.

    Modules are visited in list order, and each module's parameters keep the
    order of its own `get_functional_param_args()`.

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): List of modules.

    Returns:
        list[torch.nn.Parameter]: Flattened list of all parameters.
    """
    return [
        param
        for module in semi_functional_module_list
        for param in module.get_functional_param_args()
    ]
def construct_surface_and_normal_func(semi_functional_module_list: List[SemiFunctionalModule]) -> Callable:
    r"""
    Constructs a function to evaluate both the surface value and its gradient
    (normal direction) with respect to the ray origin `O`.

    The surface is defined by composing the provided semi-functional modules.

    Returns a callable:

    .. math::
        (O, p_1, ..., p_n) \mapsto ( s(O), \frac{\partial s}{\partial O} )

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): List of modules.

    Returns:
        Callable: A function `s_dsd(O, *params, only_s=False)`. It returns
        `(s, ds/dO)`, or only `s` when `only_s=True`.
    """
    s = cat_semi_functionals(semi_functional_module_list)

    def s_dsd(O, *params, only_s=False):
        with torch.enable_grad():
            if not O.requires_grad:
                # Differentiate with respect to a graph-free alias of O instead of
                # setting requires_grad on the caller's tensor. The old in-place
                # flag flip leaked into every later use of that tensor. detach()
                # shares storage, and O had no autograd history to lose.
                O = O.detach().requires_grad_(True)
            sval = s(O, *params)
            if only_s:
                return sval
            # NOTE(review): `grad` comes from .utils.autograd; assumed to follow
            # the torch.autograd.grad calling convention and return a sequence.
            dsdval = grad(sval, inputs=O, grad_outputs=torch.ones_like(sval))[0]
            return sval, dsdval

    return s_dsd
def construct_surface_and_normal_func_with_params(semi_functional_module_list: List[SemiFunctionalModule]) -> Tuple[Callable, List]:
    r"""
    Builds the surface/normal evaluator together with its flat parameter list.

    This is a convenience for optimization workflows that need to track the
    surface parameters alongside the function that consumes them.

    Args:
        semi_functional_module_list (list[SemiFunctionalModule]): List of modules.

    Returns:
        tuple:
            Callable: A function computing surface and its gradient.
            list[torch.nn.Parameter]: The list of parameters for the surface.
    """
    # Tuple elements are evaluated left to right, so the evaluator is built
    # before the parameters are collected, as before.
    return (
        construct_surface_and_normal_func(semi_functional_module_list),
        get_functional_param_args(semi_functional_module_list),
    )
class CustomAutogradRule_t(torch.autograd.Function):
    """
    Custom PyTorch autograd rule for ray-surface intersection.

    Computes a differentiable intersection length `t` such that:

    .. math::
        s(O + t D) = 0

    where `O` is the ray origin, `D` is the direction, and `s` is the surface function.
    This rule enables backpropagation through `t` with respect to `O`, `D`, and surface parameters.

    The forward pass does not solve anything; it passes through a `t` that was
    already found by an external iterative solver. The backward pass applies the
    implicit function theorem at R = O + t D:

    .. math::
        \\frac{dt}{d\\theta} = -\\frac{\\partial s / \\partial \\theta}{(\\partial s / \\partial R) \\cdot D}
    """

    @staticmethod
    def forward(ctx, O:torch.Tensor, D:torch.Tensor, surface_and_normal_func:Callable, t_detached:torch.Tensor, *param_args) -> torch.Tensor:
        """
        Stores inputs for backward pass and returns precomputed `t`.

        Args:
            O (torch.Tensor): Ray origin of shape (N, 3).
            D (torch.Tensor): Ray direction of shape (N, 3).
            surface_and_normal_func (Callable): Surface function returning (s, ds/dR).
            t_detached (torch.Tensor): Estimated intersection length (detached).
            *param_args: Surface parameters.

        Returns:
            torch.Tensor: Intersection length `t`.
        """
        # NOTE(review): save_for_backward accepts only tensors (or None), so
        # non-tensor entries in param_args would fail here. Confirm that all
        # surface parameters are tensors.
        ctx.save_for_backward(O,D,t_detached,*param_args)
        # The callable cannot go through save_for_backward; keep it on ctx.
        ctx.surface_and_normal_func = surface_and_normal_func
        return t_detached

    @staticmethod
    def backward(ctx, grad_outputs:torch.Tensor)->Tuple:
        """
        Computes gradients of intersection length `t` with respect to:
        - ray origin `O`
        - ray direction `D`
        - surface parameters

        Args:
            grad_outputs (torch.Tensor): Gradient of the loss w.r.t. output `t`.

        Returns:
            tuple: Gradients with respect to inputs (O, D, None, None, *param_args).
            The callable and the detached `t` receive no gradient.
        """
        saved_tensors = ctx.saved_tensors
        O = saved_tensors[0]
        D = saved_tensors[1]
        t_detached = saved_tensors[2]
        param_args = saved_tensors[3:]
        surface_and_normal_func = ctx.surface_and_normal_func
        # Re-apply this rule so that `t` (and every gradient built from it below)
        # stays differentiable when the backward graph is itself differentiated.
        t = CustomAutogradRule_t.apply(O,D,surface_and_normal_func,t_detached,*param_args)
        R = O+t*D
        # Evaluate the surface on fresh clones of the tensor parameters. The
        # gradients w.r.t. the parameters are then taken against these clones.
        param_args_clone = []
        for elem in param_args:
            if torch.is_tensor(elem):
                elem = elem.clone()
            param_args_clone.append(elem)
        s_val,dsdR_val = surface_and_normal_func(R,*param_args_clone)
        # Per-ray directional derivative (ds/dR) . D, the denominator of the
        # implicit-function-theorem formula. It is zero for rays tangent to the surface.
        dsdR_T_D = torch.sum(dsdR_val*D,axis=-1)
        # v1 = -dL/dt / ((ds/dR) . D); multiplying by ds/d(anything) gives dL/d(anything).
        v1 = -grad_outputs.reshape(-1)/dsdR_T_D.reshape(-1)
        jact_dtdp = None
        with torch.enable_grad():
            s_val = [s_val.reshape(-1)]
            # Vector-Jacobian product v1^T ds/dp for all parameters in one call;
            # create_graph keeps it differentiable for higher-order gradients.
            # NOTE(review): `grad` is the project wrapper from .utils.autograd,
            # assumed to mirror torch.autograd.grad. Confirm how it handles
            # parameters that do not require grad or an empty parameter list.
            jact_dtdp = grad(s_val,[*param_args_clone], grad_outputs=v1,create_graph=True,retain_graph=True)
        # dt/dO = -(ds/dR) / ((ds/dR) . D), scaled by the incoming gradient.
        jact_dtdO = v1.reshape(-1,1)*dsdR_val
        # R = O + t D, so ds/dD = t * ds/dR and therefore dt/dD = t * dt/dO.
        jact_dtdD = jact_dtdO*t.reshape(-1,1)
        return jact_dtdO,jact_dtdD,None,None,*jact_dtdp
def get_ray_intersection_length(O: torch.Tensor, D: torch.Tensor, surface_and_normal_func: Callable, param_args: List, t_init: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Solves for the intersection length `t` such that:

    .. math::
        s(O + t D) = 0

    using a Newton-style iteration with damping.

    This function finds the length `t` at which each ray intersects a
    parametric surface. The surface is given by a composed function that also
    returns normal information.

    Args:
        O (torch.Tensor): Ray origins of shape (N, 3).
        D (torch.Tensor): Ray directions of shape (N, 3).
        surface_and_normal_func (Callable): A function returning (s, ds/dR).
        param_args (list): List of surface parameters.
        t_init (torch.Tensor, optional): Initial guess for `t`. If None, starts from zero.

    Returns:
        torch.Tensor: Estimated intersection lengths `t` of shape (N, 1), with
        autograd support through `CustomAutogradRule_t`.

    Note:
        No exception is raised on failure. If the residuals do not all drop
        below the configured tolerance within the configured maximum number of
        iterations, a message is printed and the last iterate is returned.
    """
    tolerance = get_tolerance()
    max_iter = get_max_iterations()
    damping = get_damping_factor()
    device = O.device
    dtype = O.dtype
    N = O.shape[0]

    # Starting value for the iteration: the caller's guess, or t = 0.
    if t_init is not None:
        t_detached = t_init.detach().reshape(N, 1)
    else:
        t_detached = torch.zeros((N, 1), device=device, dtype=dtype)

    # The solver loop itself is not differentiated; gradients are supplied by
    # CustomAutogradRule_t at the converged point.
    O_detached = O.detach()
    D_detached = D.detach()

    # With no rays there is nothing to solve. Skipping the loop also avoids
    # torch.max over an empty tensor, which raises.
    converged = N == 0
    smax_vals = []
    for k in range(0 if converged else max_iter):
        R_detached = O_detached + t_detached * D_detached
        s_val, dsdR_val = surface_and_normal_func(R_detached, *param_args)
        s_val, dsdR_val = s_val.detach(), dsdR_val.detach()
        # Damped Newton step along each ray: t <- t - damping * s / ((ds/dR) . D).
        directional = torch.sum(dsdR_val * D_detached, dim=-1).reshape(-1, 1)
        t_detached = (t_detached - damping * s_val.reshape(-1, 1) / directional).detach()
        s_abs = torch.abs(s_val)
        smax_vals.append(torch.max(s_abs))
        # Converge on the residual's magnitude. Comparing the signed value
        # would accept any negative residual, however large.
        if (s_abs < tolerance).all():
            converged = True
            if get_show_iteration_count():
                print(f"Ray intersection with surface completed in {k} iterations.")
            break

    if not converged:
        print(f"Ray intersection FAILED to converge after {max_iter} iterations!\nThis is totally normal during optimization when a bad parameterset is chosen. " + "maximum svals are: " + str(smax_vals))

    t_out = CustomAutogradRule_t.apply(O, D, surface_and_normal_func, t_detached, *param_args)
    return t_out