# Source code for diffinytrace.target_grid

"""
This module implements grid-based spatial aggregation for ray optics.

Classes:
    - Grid: Represents a 2D grid for spatial aggregation and statistics.
    - GridSquare: Square variant of Grid for symmetric apertures.

Functions:
    - (none at top level)

Example:
    >>> grid = Grid([0, 1], [0, 1], 10, 10)
    >>> area = grid.get_area()
"""

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

# Public names exported by ``from diffinytrace.target_grid import *``.
__all__ = [
    "Grid",
    "GridSquare"
]



import torch
import torch.nn as nn
from sklearn.neighbors import NearestNeighbors
import numpy as np

class Grid():
    """
    Represents a 2D grid over a rectangular area with aggregation and indexing utilities.

    Points are passed in local coordinates as an ``(N, 2)`` tensor whose column 0
    is ``x`` and column 1 is ``y``. Cells are addressed either by ``(yi, xi)`` or by
    the row-major flat index ``k = yi * x_grid_size + xi``; aggregation results are
    returned with shape ``(y_grid_size, x_grid_size)``.

    Args:
        y_range (tuple[float, float]): The range in y-direction, as (y_min, y_max).
        x_range (tuple[float, float]): The range in x-direction, as (x_min, x_max).
        y_grid_size (int): Number of grid cells in y-direction.
        x_grid_size (int): Number of grid cells in x-direction.
    """

    def __init__(self, y_range, x_range, y_grid_size, x_grid_size):
        super().__init__()
        self.y_range = np.array(y_range)
        self.x_range = np.array(x_range)
        self.x_grid_size = x_grid_size
        self.y_grid_size = y_grid_size
        # Edge length of a single cell along each axis.
        self.x_delta = (self.x_range[1] - self.x_range[0]) / x_grid_size
        self.y_delta = (self.y_range[1] - self.y_range[0]) / y_grid_size

    def get_area(self):
        r"""
        Computes the total area of the grid.

        Returns:
            float: Total area of the grid.

        .. math::
            A = (x_{max} - x_{min}) \cdot (y_{max} - y_{min})
        """
        return (self.x_range[1] - self.x_range[0]) * (self.y_range[1] - self.y_range[0])

    def get_pixel_area(self):
        r"""
        Returns the area of a single pixel/grid cell.

        Returns:
            float: Area of a single grid cell.

        .. math::
            A_{pixel} = \Delta x \cdot \Delta y
        """
        return self.x_delta * self.y_delta

    def get_yi_xi(self, local_points, round_to_bounds=True):
        r"""
        Converts 2D local coordinates to integer grid indices.

        The returned indices are always clamped into the grid, so they are safe to
        use for indexing. Points outside the grid therefore land in a border cell.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) holding ``(x, y)`` points.
            round_to_bounds (bool): If True, return only the clamped indices.
                If False, additionally return a boolean mask marking the points
                that were inside the grid before clamping.

        Returns:
            ``(yi, xi)`` (each a long tensor of shape (N,)) when ``round_to_bounds``
            is True, otherwise ``((yi, xi), valid)``.

        Raises:
            RuntimeError: If ``local_points`` is not of shape (N, 2).
        """
        if len(local_points.shape) != 2 or local_points.shape[1] != 2:
            raise RuntimeError("The local_points must be in local coordinates and of shape [#points,2]")
        local_points = local_points.detach()
        xi = torch.floor((local_points[:, 0] - self.x_range[0]) / self.x_delta).long()
        yi = torch.floor((local_points[:, 1] - self.y_range[0]) / self.y_delta).long()
        # The validity mask must be computed before clamping hides out-of-range points.
        valid = (xi >= 0) & (xi < self.x_grid_size) & (yi >= 0) & (yi < self.y_grid_size)
        # Clamping is a no-op for in-range indices, so it is applied unconditionally;
        # this avoids the device->host syncs a data-dependent `.any()` check would cost.
        yi = torch.clamp(yi, min=0, max=self.y_grid_size - 1)
        xi = torch.clamp(xi, min=0, max=self.x_grid_size - 1)
        if round_to_bounds:
            return (yi, xi)
        return (yi, xi), valid

    def get_k(self, local_points, round_to_bounds=True):
        r"""
        Maps local coordinates to flattened (row-major) grid indices.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            round_to_bounds (bool): Forwarded to :meth:`get_yi_xi`.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                - If `round_to_bounds` is True: long tensor of shape (N,).
                - Otherwise: Tuple (indices, validity_mask).
        """
        if round_to_bounds:
            yi, xi = self.get_yi_xi(local_points, round_to_bounds=True)
            return (yi * self.x_grid_size + xi).long()
        (yi, xi), valid = self.get_yi_xi(local_points, round_to_bounds=False)
        return (yi * self.x_grid_size + xi).long(), valid

    def map_matrix_to_ray(self, local_points, old_matrix):
        r"""
        Samples a per-cell matrix at the cell containing each point.

        Points outside the grid are clamped to the nearest border cell.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            old_matrix (torch.Tensor): Matrix of shape (H, W, ...) (or a flat (H*W,) buffer).

        Returns:
            torch.Tensor: Sampled values of shape (N, ...).
        """
        k = self.get_k(local_points)
        # Flatten only the two grid axes so any trailing per-cell dims survive indexing.
        flat = old_matrix.reshape(self.y_grid_size * self.x_grid_size, *old_matrix.shape[2:])
        return flat[k]

    def _flat_buffer(self, old_matrix, fill_value, device, dtype):
        """Return a flat (H*W,) buffer: ``old_matrix`` flattened, or a new one filled with ``fill_value``."""
        if old_matrix is not None:
            # Accept both the flat buffer and the (H, W) result of an earlier call.
            # For contiguous input this is a view, so in-place updates reach old_matrix.
            return old_matrix.reshape(-1)
        return torch.full((self.y_grid_size * self.x_grid_size,), fill_value, device=device, dtype=dtype)

    def _cell_indices(self, local_points, values, round_to_bounds):
        """Return flat cell indices and matching values; out-of-grid points are clamped or dropped."""
        if round_to_bounds:
            return self.get_k(local_points, True), values
        k, valid = self.get_k(local_points, False)
        return k[valid], values[valid]

    def sum(self, local_points: torch.Tensor, values: torch.Tensor, old_matrix=None, round_to_bounds: bool = False):
        """
        Sums values over the grid based on point locations.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,), same dtype as ``local_points``.
            old_matrix (torch.Tensor or None): Previous result (flat or (H, W)) to add into.
            round_to_bounds (bool): Clamp out-of-grid points to the border if True,
                otherwise ignore them.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W).
        """
        out = self._flat_buffer(old_matrix, 0.0, local_points.device, local_points.dtype)
        k, values = self._cell_indices(local_points, values, round_to_bounds)
        out.scatter_add_(0, k, values)
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def prod(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Multiplies values over the grid based on point locations.

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,), same dtype as ``local_points``.
            old_matrix (torch.Tensor or None): Previous result (flat or (H, W)) to multiply into.
            round_to_bounds (bool): Clamp out-of-grid points to the border if True,
                otherwise ignore them.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W); empty cells keep their
            initial value (1 for a fresh buffer).
        """
        out = self._flat_buffer(old_matrix, 1.0, local_points.device, local_points.dtype)
        k, values = self._cell_indices(local_points, values, round_to_bounds)
        out.scatter_reduce_(0, k, values, reduce="prod")
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def mean(self, local_points, values, old_matrix=None, round_to_bounds=False):
        """
        Computes the mean of values over the grid based on point locations.

        Because the reduction uses ``include_self=False``, a cell that receives points
        is overwritten by the mean of those points (its previous content is not
        averaged in); cells without points keep their previous value (0 for a fresh
        buffer).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,), same dtype as ``local_points``.
            old_matrix (torch.Tensor or None): Previous result (flat or (H, W)) to write into.
            round_to_bounds (bool): Clamp out-of-grid points to the border if True,
                otherwise ignore them.

        Returns:
            torch.Tensor: Aggregated result of shape (H, W).
        """
        out = self._flat_buffer(old_matrix, 0.0, local_points.device, local_points.dtype)
        k, values = self._cell_indices(local_points, values, round_to_bounds)
        out.scatter_reduce_(0, k, values, reduce="mean", include_self=False)
        return out.reshape(self.y_grid_size, self.x_grid_size)

    def __get_args(self, M, b, v):
        """
        Return, per flat cell, the index of a value in ``v`` equal to the reduced value ``M``.

        Cells whose reduced value is not matched by any scattered value (e.g. empty
        cells) get -1. If several points tie, which index is stored is not
        specified here (it depends on ``scatter_`` with duplicate indices).
        """
        device = v.device
        M_argmin = torch.full((self.y_grid_size * self.x_grid_size,), -1, dtype=torch.long, device=device)
        mask = v == M[b]
        # Created on v's device so the masked scatter does not mix CPU and GPU tensors.
        indices = torch.arange(len(v), device=device)
        M_argmin.scatter_(0, b[mask], indices[mask])
        return M_argmin

    def min(self, local_points, values, old_matrix=None, return_args=False):
        """
        Finds the minimum value for each grid cell based on local points.

        Out-of-grid points are clamped into the border cells. Cells without points
        hold ``inf`` (for a fresh buffer).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,), same dtype as ``local_points``.
            old_matrix (torch.Tensor or None): Previous result (flat or (H, W)) to reduce into.
            return_args (bool): If True, also return per-cell point indices (-1 if none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Minimum values of
            shape (H, W), optionally with indices of the same shape.
        """
        out = self._flat_buffer(old_matrix, float("inf"), local_points.device, local_points.dtype)
        k = self.get_k(local_points)
        out.scatter_reduce_(0, k, values, reduce="amin")
        grid_shape = (self.y_grid_size, self.x_grid_size)
        if return_args:
            out_args = self.__get_args(out, k, values)
            return out.reshape(grid_shape), out_args.reshape(grid_shape)
        return out.reshape(grid_shape)

    def max(self, local_points, values, old_matrix=None, return_args=False):
        """
        Finds the maximum value for each grid cell based on local points.

        Out-of-grid points are clamped into the border cells. Cells without points
        hold ``-inf`` (for a fresh buffer).

        Args:
            local_points (torch.Tensor): Points of shape (N, 2).
            values (torch.Tensor): Values of shape (N,), same dtype as ``local_points``.
            old_matrix (torch.Tensor or None): Previous result (flat or (H, W)) to reduce into.
            return_args (bool): If True, also return per-cell point indices (-1 if none).

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Maximum values of
            shape (H, W), optionally with indices of the same shape.
        """
        out = self._flat_buffer(old_matrix, float("-inf"), local_points.device, local_points.dtype)
        k = self.get_k(local_points)
        out.scatter_reduce_(0, k, values, reduce="amax")
        grid_shape = (self.y_grid_size, self.x_grid_size)
        if return_args:
            out_args = self.__get_args(out, k, values)
            return out.reshape(grid_shape), out_args.reshape(grid_shape)
        return out.reshape(grid_shape)

    def __get_x_middle(self, device=None, dtype=None):
        """Return the x-coordinates of the cell centers, shape (x_grid_size,)."""
        idx = torch.arange(0, self.x_grid_size, device=device)
        if dtype is not None:
            idx = idx.to(dtype)
        # Tensor operand first so the arithmetic stays in torch on any device.
        return idx * self.x_delta + self.x_delta * 0.5 + self.x_range[0]

    def __get_y_middle(self, device=None, dtype=None):
        """Return the y-coordinates of the cell centers, shape (y_grid_size,)."""
        idx = torch.arange(0, self.y_grid_size, device=device)
        if dtype is not None:
            idx = idx.to(dtype)
        return idx * self.y_delta + self.y_delta * 0.5 + self.y_range[0]

    def get_y_middle(self):
        """Return the y-coordinates of the cell centers, shape (y_grid_size,)."""
        return self.__get_y_middle()

    def get_x_middle(self):
        """Return the x-coordinates of the cell centers, shape (x_grid_size,)."""
        return self.__get_x_middle()

    def nearest(self, local_points, return_args=False):
        """
        For each cell, finds the squared L2 distance of its closest point to the cell center.

        Each point is assigned to its (clamped) cell, its squared distance to that
        cell's center is computed, and the per-cell minimum is taken via :meth:`min`.
        Cells without points hold ``inf``.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2).
            return_args (bool): If True, also return the index of the closest point per cell.

        Returns:
            torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Minimum squared
            distances of shape (H, W), optionally with indices.
        """
        # Build centers on the points' device/dtype so GPU inputs work and `min` sees matching dtypes.
        x_middle = self.__get_x_middle(device=local_points.device, dtype=local_points.dtype)
        y_middle = self.__get_y_middle(device=local_points.device, dtype=local_points.dtype)
        yi, xi = self.get_yi_xi(local_points)
        xdiff = (x_middle[xi] - local_points[:, 0]) ** 2.0
        ydiff = (y_middle[yi] - local_points[:, 1]) ** 2.0
        return self.min(local_points, xdiff + ydiff, return_args=return_args)

    def get_pixel_centers(self, device=None, dtype=None):
        """
        Returns the 2D center coordinates of each grid cell.

        Args:
            device (torch.device or None): Device for the result (default: CPU).
            dtype (torch.dtype or None): dtype for the result (default: torch's default).

        Returns:
            torch.Tensor: Tensor of shape (H, W, 2); ``[..., 0]`` is x and ``[..., 1]`` is y.
        """
        x_middle = self.__get_x_middle(device=device, dtype=dtype)
        y_middle = self.__get_y_middle(device=device, dtype=dtype)
        grid_y, grid_x = torch.meshgrid(y_middle, x_middle, indexing="ij")
        return torch.stack([grid_x, grid_y], dim=-1)

    def get_nearest_ray(self, local_points):
        """
        Finds the index of the nearest ray for each grid cell center.

        Uses `sklearn.neighbors.NearestNeighbors` (kd-tree) on CPU copies of the data.

        Args:
            local_points (torch.Tensor): Tensor of shape (N, 2) representing sampled rays.

        Returns:
            torch.Tensor: Tensor of shape (H, W) with ray indices, on ``local_points``' device.
        """
        device = local_points.device
        # sklearn needs host numpy arrays; CUDA tensors cannot be converted implicitly.
        rays_np = local_points.detach().cpu().numpy()
        centers_np = self.get_pixel_centers().reshape(-1, 2).cpu().numpy()
        nn_model = NearestNeighbors(n_neighbors=1, algorithm="kd_tree")
        nn_model.fit(rays_np)
        _, indices = nn_model.kneighbors(centers_np)
        return torch.tensor(indices.reshape(self.y_grid_size, self.x_grid_size), device=device)
# implement nearest_ray
class GridSquare(Grid):
    """
    Square grid centered at the origin.

    Both axes span ``[-aperture_radius, aperture_radius]`` and are divided into
    the same number of cells.

    Args:
        aperture_radius (float): Half-width of the square domain.
        grid_size (int): Number of grid cells along each axis.
    """

    def __init__(self, aperture_radius, grid_size):
        half_width = aperture_radius
        y_bounds = [-half_width, half_width]
        x_bounds = [-half_width, half_width]
        super().__init__(y_bounds, x_bounds, grid_size, grid_size)