"""
Zernike polynomial basis functions for optical wavefront representation.
This module provides functions to compute Zernike polynomials, which are commonly used
in optics for describing wavefront aberrations over a circular aperture. The polynomials
are orthogonal over the unit disk and are indexed by radial order (n) and azimuthal
frequency (m).
.. figure:: _static/zernike_plot1.png
:alt: Zernike polynomials visualization
:width: 60%
:align: center
Visualization of Zernike polynomials organized by radial order (rows) and azimuthal frequency (columns).
Example:
Basic usage for computing and visualizing Zernike polynomials:
>>> import torch
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import diffinytrace.basis_functions.zernike as zernike
>>>
>>> # Create unit circle grid
>>> grid_size = 256
>>> x = torch.linspace(-1, 1, grid_size)
>>> y = torch.linspace(-1, 1, grid_size)
>>> X, Y = torch.meshgrid(x, y, indexing='ij')
>>>
>>> # Create mask for unit circle
>>> mask = (X**2 + Y**2) <= 1.0
>>> x_points = X[mask]
>>> y_points = Y[mask]
>>> points = torch.stack([x_points, y_points], dim=1)
>>>
>>> # Evaluate Zernike polynomials
>>> max_n = 6 # Maximum radial degree
>>> basis_values = zernike.basis_function(max_n, points)
>>>
>>> # Group basis functions by radial degree
>>> basis_by_degree = {}
>>> for basis_idx in range(basis_values.shape[1]):
... radial_order = zernike.get_radial_order(basis_idx)
... if radial_order not in basis_by_degree:
... basis_by_degree[radial_order] = []
... basis_by_degree[radial_order].append(basis_idx)
>>>
>>> # Visualize the polynomials
>>> max_cols = max(len(indices) for indices in basis_by_degree.values())
>>> num_rows = len(basis_by_degree)
>>> fig, axes = plt.subplots(num_rows, max_cols, figsize=(3*max_cols, 3*num_rows))
>>>
>>> for row_idx, (radial_order, basis_indices) in enumerate(sorted(basis_by_degree.items())):
... for col_idx, basis_idx in enumerate(basis_indices):
... # Create 2D array with NaN outside unit circle
... tmp = torch.full((grid_size, grid_size), float('nan'))
... tmp[mask] = basis_values[:, basis_idx]
...
... # Plot
... ax = axes[row_idx, col_idx]
... im = ax.imshow(tmp.numpy(), extent=[-1, 1, -1, 1],
... origin='lower', cmap='jet', vmin=-1, vmax=1)
... azimuthal = zernike.get_azimuthal_frequency(basis_idx)
... ax.set_title(f"$Z^{{{azimuthal}}}_{{{radial_order}}}$", fontsize=25)
... ax.set_xticks([])
... ax.set_yticks([])
... ax.set_aspect('equal')
>>>
>>> plt.tight_layout()
>>> plt.show()
Notes:
- Zernike polynomials are only defined for points within the unit circle (r ≤ 1)
- Radial order n determines the number of radial variations
- Azimuthal frequency m determines the angular variations and symmetry
"""
# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.
__all__ = [
"basis_function",
"get_num_basis",
"get_radial_order",
"get_azimuthal_frequency"
]
import torch
import math
def __zernike_calc(n:int, m:int, r_powers: list) -> torch.Tensor:
radial_sum = torch.zeros_like(r_powers[0])
m = abs(m)
for k in range((n - m) // 2 + 1):
coef = math.factorial(n - k) / (
math.factorial(k) * math.factorial((n + m) // 2 - k) * math.factorial((n - m) // 2 - k))
if k%2==1:
coef = -coef
power_idx = n - 2 * k - m
if power_idx < 0:
raise RuntimeError("Potential zero division!")
#tmp = r_powers[abs(power_idx)]
#radial_sum += coef / tmp
if power_idx % 2 == 1:
raise RuntimeError("tried to acces odd power idx!")
radial_sum += coef*r_powers[abs(power_idx)]
return radial_sum
def basis_2D(points: torch.Tensor, max_radial_order: int) -> torch.Tensor:
"""
Compute Zernike polynomials for a given set of points.
Args:
max_radial_order (int): Maximum radial order.
points (torch.Tensor): Tensor of shape (N, 2) containing the x and y coordinates of the points.
Returns:
torch.Tensor: Tensor of shape (N, num_coeffs) containing the Zernike polynomial values.
"""
x = points[:, 0]
y = points[:, 1]
#r = torch.sqrt(x**2 + y**2)
r2 = x**2 + y**2
# Precompute powers of r from r^0 to r^max_radial_order
r_powers = [] #[r ** i for i in range(max_radial_order+1)]
for i in range(max_radial_order+1):
if i%2 == 0:
r_powers += [r2**(i/2.0)]
else:
r_powers += [None]
# List to store Zernike polynomial results
zernike_polynomials = []
# Loop over radial and azimuthal degrees
for n in range(max_radial_order + 1):
for m in range(-n, n + 1, 2): # m must have the same parity as n
#r_m = r_powers[abs(m)] # Precompute r^m for both cos and sin components
if m >= 0:
#TODO Remove weird complex number stuff!
multiplier = torch.real((y + 1j * x)**abs(m))#TODO multiply after zerinke_calc!
#this is to slow!!
zernike_polynomials.append(multiplier*__zernike_calc(n, m, r_powers))
else:
multiplier = torch.imag((y + 1j * x)**abs(m))
zernike_polynomials.append(multiplier*__zernike_calc(n, m, r_powers))
# Stack all Zernike polynomials into a single tensor
zernike_basis = torch.stack(zernike_polynomials, dim=1)
return zernike_basis
[docs]
def get_num_basis(max_radial_order:int) -> int:
"""
Calculate the number of basis functions for Zernike polynomials up to a given radial order.
The number of basis functions is given by the formula (n + 1) * (n + 2) / 2.
Args:
max_radial_order (int): Maximum radial order.
Returns:
int: Number of coefficients.
"""
n = max_radial_order+1
return int(n*(n+1) / 2)
[docs]
def get_radial_order(basis_idx:int) -> int:
"""
Calculate the radial degree from the basis function index.
Args:
basis_idx (int): Index of the basis function.
Returns:
int: Radial degree.
"""
basis_idx_runner = basis_idx
num_azimuthal_frequencies = 1
row_idx = 0
while True:
if basis_idx_runner < num_azimuthal_frequencies:
return row_idx
basis_idx_runner = basis_idx_runner-num_azimuthal_frequencies
num_azimuthal_frequencies += 1
row_idx += 1
[docs]
def get_azimuthal_frequency(basis_idx:int) -> int:
"""
Calculate the azimuthal frequency from the basis function index.
Args:
basis_idx (int): Index of the basis function.
Returns:
int: Azimuthal frequency (m value).
"""
# First get the radial degree n
row_idx = get_radial_order(basis_idx)
num_azimuthal_frequencies = 1
basis_idx_runner = basis_idx
for k in range(row_idx):
basis_idx_runner = basis_idx_runner-num_azimuthal_frequencies
num_azimuthal_frequencies += 1
x_idx_start = None
if num_azimuthal_frequencies % 2 == 0:
half = num_azimuthal_frequencies // 2
tmp = (half-1)*2+1
x_idx_start = - tmp
else:
half = num_azimuthal_frequencies // 2
tmp = (half)*2
x_idx_start = - tmp
#x_idx_start = - (num_azimuthal_frequencies // 2)
# For radial degree n, we have m values: -n, -n+2, ..., -2, 0, 2, ..., n-2, n
# The azimuthal frequency m follows the pattern:
# pos_in_row = 0 -> m = -n
# pos_in_row = 1 -> m = -n + 2
# ...
# pos_in_row = n -> m = n
return x_idx_start + basis_idx_runner*2