Source code for diffinytrace.transforms

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = [
    "Transform",
    "Identity",
    "Compose",
    "Offset",
    "Distance",
    "Rotation",
    "rotation_matrix_x",
    "rotation_matrix_y",
    "rotation_matrix_z"
]

import torch
import torch.nn as nn
from .intersection import SemiFunctionalModule,cat_semi_functionals
import numpy as np
from .optimize import make_parameter_from_input  

class Transform(SemiFunctionalModule):
    """
    Base class for coordinate transformations.

    This class provides interfaces to transform directions and positions
    between local and global coordinate systems using homogeneous
    coordinates. Subclasses supply ``get_functional_param_args``,
    ``functional`` and ``get_transformation_matrix``; the ``to_*`` helpers
    below are built on top of those three.

    Methods:
        get_functional_param_args(): Return parameters required for the transformation.
        functional(O, *params): Apply transformation in functional style.
        get_transformation_matrix(): Return the 4x4 transformation matrix.
        to_global_dir(direction): Transform direction to global space.
        to_local_dir(direction): Transform direction to local space.
        to_global_pos(position): Transform position to global space.
        to_local_pos(position): Transform position to local space.
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        """
        Return parameters required for the transformation which constructs
        the surfaces through the functional.

        Returns:
            list: List of parameters required for the functional which
            constructs the surfaces.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("get_functional_param_args not implemented")

    @staticmethod
    def functional(O, *params) -> torch.Tensor:
        """
        Apply transformation in functional style. This is global to local.

        Args:
            O (torch.Tensor): Input tensor to be transformed.
            *params: Parameters for the transformation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("functional not implemented")

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 transformation matrix.

        Args:
            device (torch.device, optional): Device for the matrix.
            dtype (torch.dtype, optional): Data type for the matrix.

        Returns:
            torch.Tensor: 4x4 transformation matrix.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("get_transformation_matrix not implemented")

    def get_transform(self):
        """
        Returns itself.
        """
        return self

    def to_global_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform direction to global space.

        Only the upper-left 3x3 block of the transformation matrix is
        applied, so the translation column has no effect on directions.

        Args:
            direction (torch.Tensor): Direction vector in local space
                (last dimension of size 3).

        Returns:
            torch.Tensor: Direction vector in global space.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        linear = M[:3, :3]
        # Row-vector convention: v @ R^T is R @ v for each row v.
        return direction @ linear.T

    def to_local_dir(self, direction: torch.Tensor) -> torch.Tensor:
        """
        Transform direction to local space.

        Applies the inverse of the upper-left 3x3 block of the
        transformation matrix.

        Args:
            direction (torch.Tensor): Direction vector in global space
                (last dimension of size 3).

        Returns:
            torch.Tensor: Direction vector in local space.
        """
        M = self.get_transformation_matrix(direction.device, direction.dtype)
        linear_inv = torch.inverse(M[:3, :3])
        return direction @ linear_inv.T

    def to_global_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform position to global space.

        Evaluates the first three rows of ``M @ [x, y, z, 1]^T`` directly
        as ``R @ p + t``. This avoids building an explicit homogeneous
        vector and therefore works for any leading batch shape.

        Args:
            position (torch.Tensor): Position vector(s) in local space,
                shape ``(..., 3)``.

        Returns:
            torch.Tensor: Position vector(s) in global space, same shape as
            ``position``.
        """
        M = self.get_transformation_matrix(position.device, position.dtype)
        linear = M[:3, :3]
        translation = M[:3, 3]
        return position @ linear.T + translation

    def to_local_pos(self, position: torch.Tensor) -> torch.Tensor:
        """
        Transform position to local space.

        Delegates to ``functional`` with the current parameter list.

        Args:
            position (torch.Tensor): Position vector in global space.

        Returns:
            torch.Tensor: Position vector in local space.
        """
        return self.functional(position, *self.get_functional_param_args())
class Identity(Transform):
    """
    Identity transformation that returns input positions unchanged.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
    """

    def __init__(self):
        super().__init__()

    def get_functional_param_args(self):
        """Return an empty list: the identity has no parameters."""
        return []

    @staticmethod
    def functional(O: torch.Tensor) -> torch.Tensor:
        """Return ``O`` itself, untouched."""
        return O

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 identity matrix.

        Args:
            device (torch.device, optional): Device for the matrix.
            dtype (torch.dtype, optional): Data type for the matrix.

        Returns:
            torch.Tensor: 4x4 identity matrix.
        """
        return torch.eye(4, dtype=dtype, device=device)
class Compose(Transform):
    """
    Compose multiple transforms in sequence.

    Args:
        transform_list (list[Transform]): List of transformations to apply in order.
    """

    def __init__(self, transform_list):
        super().__init__()
        self.transform_list = nn.ModuleList(transform_list)
        # The chained functional is built once from the child transforms.
        self.functional = cat_semi_functionals(self.transform_list)

    def get_functional_param_args(self):
        """
        Return the concatenated parameter lists of all child transforms.

        Returns:
            list: Parameters of every child, in ``transform_list`` order.
        """
        params = []
        for child in self.transform_list:
            params.extend(child.get_functional_param_args())
        return params

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the product of the children's 4x4 matrices.

        The matrices are multiplied left to right in ``transform_list``
        order (``M_0 @ M_1 @ ...``).

        Args:
            device (torch.device, optional): Device for the matrix. When
                omitted, the device of the first child's matrix is used.
            dtype (torch.dtype, optional): Data type for the matrix. When
                omitted, the dtype of the first child's matrix is used.

        Returns:
            torch.Tensor: 4x4 transformation matrix. For an empty
            ``transform_list`` this is the identity.
        """
        matrices = [
            child.get_transformation_matrix(device, dtype)
            for child in self.transform_list
        ]
        # With device/dtype left as None each child picks its own, so the
        # matrices may disagree. Resolve a concrete target from the first
        # child; calling ``.to(None)`` would leave the tensor unchanged and
        # the matmul below would fail on mismatched operands.
        if matrices:
            if device is None:
                device = matrices[0].device
            if dtype is None:
                dtype = matrices[0].dtype
        out = torch.eye(4, device=device, dtype=dtype)
        for matrix in matrices:
            # ``.to`` returns the tensor itself when nothing needs converting.
            out = out @ matrix.to(device=device, dtype=dtype)
        return out
class Offset(Transform):
    r"""
    Translation transform using an offset vector.

    The offset transformation shifts a position by a specified vector
    \( \vec{w} = (w_x, w_y, w_z) \).
    The transformation matrix \( M \) for an offset transformation is:

    .. math::

        M^{offset}(w_x, w_y, w_z) =
        \begin{bmatrix}
        1 & 0 & 0 & w_x \\
        0 & 1 & 0 & w_y \\
        0 & 0 & 1 & w_z \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Offset([1.0, 2.0, 3.0], parent_transform=transf1)

    Args:
        pos (Tensor or list or float): The offset position as a 3D vector.
        parent_transform (Transform, optional): Optional parent
            transformation. Defaults to a fresh ``Identity``.
    """

    def __init__(self, pos, parent_transform=None):
        super().__init__()
        # Build the default per instance: a default evaluated in the
        # signature would be one Identity object shared by every Offset.
        if parent_transform is None:
            parent_transform = Identity()
        self.pos = make_parameter_from_input(pos)
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        """
        Return this transform's offset followed by the parent's parameters.

        Returns:
            list: ``[pos, *parent_params]``, matching the argument order of
            ``functional``.
        """
        return [self.pos] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, pos: torch.Tensor, *parent_param_args) -> torch.Tensor:
        """
        Map ``O`` from global to local coordinates.

        The parent's functional is applied first, then the offset is
        subtracted.

        Args:
            O (torch.Tensor): Input tensor to be transformed.
            pos (torch.Tensor): Offset vector.
            *parent_param_args: Parameters forwarded to the parent's functional.

        Returns:
            torch.Tensor: ``O`` expressed in this transform's local frame.
        """
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - pos

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 matrix ``parent_matrix @ offset_matrix``.

        Args:
            device (torch.device, optional): Device for the matrix.
                Defaults to the device of ``pos``.
            dtype (torch.dtype, optional): Data type for the matrix.
                Defaults to the dtype of ``pos``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.
        """
        if device is None:
            device = self.pos.device
        if dtype is None:
            dtype = self.pos.dtype
        parent_matrix = self.parent_transform.get_transformation_matrix(device=device, dtype=dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Write the offset into the translation column (rows 0..2, last column).
        this_matrix[[0, 1, 2], -1] = self.pos.to(device=device, dtype=dtype)
        return parent_matrix @ this_matrix
class Distance(Transform):
    r"""
    Applies a translation along a specific axis by a given distance.

    The distance transformation applies a translation by a specific distance
    along a given axis (e.g., \( x \)-, \( y \)-, or \( z \)-axis).
    The transformation matrix \( M \) for a distance transformation along
    the \( z \)-axis is given by:

    .. math::

        M^{dist}_z(d) =
        \begin{bmatrix}
        1 & 0 & 0 & 0 \\
        0 & 1 & 0 & 0 \\
        0 & 0 & 1 & d \\
        0 & 0 & 0 & 1
        \end{bmatrix},

    where \( d \) represents the distance of translation along the \( z \)-axis.

    Args:
        distance (float or Tensor): Distance to translate.
        axis (int): Axis along which translation is applied (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Optional parent
            transformation. Defaults to a fresh ``Identity``.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)

    Notes:
        The matrix maps local coordinates into the parent frame:

        .. math::

            \mathbf{x}_\text{parent} = \mathbf{x}_\text{local} + d \cdot \mathbf{e}_i

        ``functional`` applies the opposite direction (parent to local):

        .. math::

            \mathbf{x}_\text{local} = \mathbf{x}_\text{parent} - d \cdot \mathbf{e}_i
    """

    def __init__(self, distance, axis=2, parent_transform=None):
        super().__init__()
        # Build the default per instance: a default evaluated in the
        # signature would be one Identity object shared by every Distance.
        if parent_transform is None:
            parent_transform = Identity()
        self.distance = make_parameter_from_input(distance)
        # Constant unit vector along the chosen axis. It is stored as a plain
        # tensor attribute (not a registered buffer), which is why the
        # methods below move it to the required device explicitly.
        self.unit_vec = torch.tensor([0., 0., 0.])
        self.unit_vec[axis] = 1.0
        self.parent_transform = parent_transform.get_transform()

    def get_functional_param_args(self):
        """
        Return this transform's parameters followed by the parent's.

        Returns:
            list: ``[distance, unit_vec, *parent_params]``, matching the
            argument order of ``functional``. ``unit_vec`` is moved to the
            device of ``distance``.
        """
        unit_vec = self.unit_vec.to(self.distance.device)
        return [self.distance, unit_vec] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, distance: torch.Tensor, unit_vec: torch.Tensor,
                   *parent_param_args) -> torch.Tensor:
        """
        Map ``O`` from global to local coordinates.

        The parent's functional is applied first, then ``distance * unit_vec``
        is subtracted.

        Args:
            O (torch.Tensor): Input tensor to be transformed.
            distance (torch.Tensor): Translation distance.
            unit_vec (torch.Tensor): Unit vector of the translation axis.
            *parent_param_args: Parameters forwarded to the parent's functional.

        Returns:
            torch.Tensor: ``O`` expressed in this transform's local frame.
        """
        O = self.parent_transform.functional(O, *parent_param_args)
        return O - distance * unit_vec

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 matrix ``parent_matrix @ distance_matrix``.

        Args:
            device (torch.device, optional): Device for the matrix.
                Defaults to the device of ``distance``.
            dtype (torch.dtype, optional): Data type for the matrix.
                Defaults to the dtype of ``distance``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.
        """
        if device is None:
            device = self.distance.device
        if dtype is None:
            dtype = self.distance.dtype
        unit_vec = self.unit_vec.to(device=device, dtype=dtype)
        parent_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        # Translation column = d * e_axis.
        this_matrix[[0, 1, 2], -1] = self.distance.to(device=device, dtype=dtype) * unit_vec
        return parent_matrix @ this_matrix
def rotation_matrix_x(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the X-axis.

    Args:
        angle (Tensor): Angle in degrees. Python scalars are also accepted
            and converted with ``torch.as_tensor``.

    Returns:
        Tensor: 3x3 rotation matrix with the dtype and device of the
        (converted) angle.
    """
    # Convert angle from degrees to radians
    theta = torch.as_tensor(angle) * (2.0 * torch.pi / 360.0)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    # Start from a 3x3 identity and overwrite the Y/Z block. Element-wise
    # assignment keeps the entries connected to ``angle`` for autograd.
    rot_x = torch.eye(3, dtype=theta.dtype, device=theta.device)
    rot_x[1, 1] = cos_t
    rot_x[1, 2] = -sin_t
    rot_x[2, 1] = sin_t
    rot_x[2, 2] = cos_t
    return rot_x
def rotation_matrix_y(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Y-axis.

    Args:
        angle (Tensor): Angle in degrees. Python scalars are also accepted
            and converted with ``torch.as_tensor``.

    Returns:
        Tensor: 3x3 rotation matrix with the dtype and device of the
        (converted) angle.
    """
    # Convert angle from degrees to radians
    theta = torch.as_tensor(angle) * (2.0 * torch.pi / 360.0)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    # Start from a 3x3 identity and overwrite the X/Z block. Element-wise
    # assignment keeps the entries connected to ``angle`` for autograd.
    rot_y = torch.eye(3, dtype=theta.dtype, device=theta.device)
    rot_y[0, 0] = cos_t
    rot_y[0, 2] = sin_t
    rot_y[2, 0] = -sin_t
    rot_y[2, 2] = cos_t
    return rot_y
def rotation_matrix_z(angle: torch.Tensor) -> torch.Tensor:
    """
    Construct a 3x3 rotation matrix around the Z-axis.

    Args:
        angle (Tensor): Angle in degrees. Python scalars are also accepted
            and converted with ``torch.as_tensor``.

    Returns:
        Tensor: 3x3 rotation matrix with the dtype and device of the
        (converted) angle.
    """
    # Convert angle from degrees to radians
    theta = torch.as_tensor(angle) * (2.0 * torch.pi / 360.0)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    # Start from a 3x3 identity and overwrite the X/Y block. Element-wise
    # assignment keeps the entries connected to ``angle`` for autograd.
    rot_z = torch.eye(3, dtype=theta.dtype, device=theta.device)
    rot_z[0, 0] = cos_t
    rot_z[0, 1] = -sin_t
    rot_z[1, 0] = sin_t
    rot_z[1, 1] = cos_t
    return rot_z
class Rotation(Transform):
    r"""
    Applies a 3D rotation around a principal axis.

    The rotational transformation rotates a point or direction around a
    specific axis (e.g., \( x \)-, \( y \)-, and \( z \)-axis).
    For example, the rotation matrix around the \( z \)-axis is:

    .. math::

        M^{rot}_z(\theta_z) =
        \begin{bmatrix}
        \cos \theta_z & -\sin \theta_z & 0 & 0 \\
        \sin \theta_z & \cos \theta_z & 0 & 0 \\
        0 & 0 & 1 & 0 \\
        0 & 0 & 0 & 1
        \end{bmatrix}

    Args:
        angle (float or Tensor): Rotation angle in degrees.
        axis (int): Axis index (0=X, 1=Y, 2=Z).
        parent_transform (Transform, optional): Optional parent
            transformation. Defaults to a fresh ``Identity``.

    Example:
        >>> import diffinytrace as dit
        >>> transf1 = dit.transforms.Identity()
        >>> transf2 = dit.transforms.Distance(10.0,axis=2,parent_transform=transf1)
        >>> transf3 = dit.transforms.Rotation(45.,axis=0,parent_transform=transf2)
    """

    def __init__(self, angle: float, axis: int, parent_transform=None):
        # TODO: test combined rotations about x, y and z - does the order matter?
        super().__init__()
        # Build the default per instance: a default evaluated in the
        # signature would be one Identity object shared by every Rotation.
        if parent_transform is None:
            parent_transform = Identity()
        self.angle = make_parameter_from_input(angle)
        self.axis = axis
        self.parent_transform = parent_transform.get_transform()

    def _axis_rotation(self, angle: torch.Tensor) -> torch.Tensor:
        """Return the 3x3 rotation by ``angle`` degrees about ``self.axis``."""
        if self.axis == 0:
            return rotation_matrix_x(angle)
        if self.axis == 1:
            return rotation_matrix_y(angle)
        if self.axis == 2:
            return rotation_matrix_z(angle)
        # Fail with a clear message rather than an AttributeError on None.
        raise ValueError(f"axis must be 0, 1 or 2, got {self.axis!r}")

    def get_functional_param_args(self):
        """
        Return this transform's angle followed by the parent's parameters.

        Returns:
            list: ``[angle, *parent_params]``, matching the argument order
            of ``functional``.
        """
        return [self.angle] + self.parent_transform.get_functional_param_args()

    def functional(self, O: torch.Tensor, angle: torch.Tensor, *parent_param_args) -> torch.Tensor:
        """
        Map ``O`` from global to local coordinates.

        The parent's functional is applied first, then the inverse rotation.

        Args:
            O (torch.Tensor): Input tensor to be transformed.
            angle (torch.Tensor): Rotation angle in degrees.
            *parent_param_args: Parameters forwarded to the parent's functional.

        Returns:
            torch.Tensor: ``O`` expressed in this transform's local frame.

        Raises:
            ValueError: If ``axis`` is not 0, 1 or 2.
        """
        O = self.parent_transform.functional(O, *parent_param_args)
        # Rotating by (360 - angle) degrees is the same rotation as by
        # -angle, i.e. the inverse of the matrix used in
        # get_transformation_matrix.
        R = self._axis_rotation(360.0 - angle)
        return O @ R.T

    def get_transformation_matrix(self, device=None, dtype=None) -> torch.Tensor:
        """
        Return the 4x4 matrix ``parent_matrix @ rotation_matrix``.

        Args:
            device (torch.device, optional): Device for the matrix.
                Defaults to the device of ``angle``.
            dtype (torch.dtype, optional): Data type for the matrix.
                Defaults to the dtype of ``angle``.

        Returns:
            torch.Tensor: 4x4 transformation matrix.

        Raises:
            ValueError: If ``axis`` is not 0, 1 or 2.
        """
        if device is None:
            device = self.angle.device
        if dtype is None:
            dtype = self.angle.dtype
        parent_matrix = self.parent_transform.get_transformation_matrix(device, dtype)
        # The rotation is built on angle's own device/dtype; convert it to
        # the requested ones before embedding it.
        R = self._axis_rotation(self.angle).to(device=device, dtype=dtype)
        this_matrix = torch.eye(4, device=device, dtype=dtype)
        this_matrix[:3, :3] = R
        return parent_matrix @ this_matrix