Source code for diffinytrace.spectrum

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.

__all__ = [
    "Spectrum",
    "VisibleSunlight_am15g"
]

import torch
import torch.nn as nn
import numpy as np
from .plotting.wavelength import PlotableWavelength
from typing import Callable,Union,List,Tuple

class Spectrum(nn.Module, PlotableWavelength):
    """
    A class to represent a spectrum as a function of wavelength.

    The wrapped function is evaluated at the requested wavelengths and the
    result is forced to zero for every wavelength outside ``bounds``.
    """

    def __init__(self, func: Callable[[torch.Tensor], torch.Tensor], bounds: Tuple[float, float]):
        """
        Initialize the Spectrum class.

        Args:
            func (callable): A function that takes a wavelength and returns the spectrum value.
                It may return a tensor, a NumPy array or a plain Python scalar.
            bounds (tuple): A tuple containing the minimum and maximum wavelength.
        """
        nn.Module.__init__(self)
        PlotableWavelength.__init__(self, bounds, "Intensity [1]")
        self.func = func
        self.bounds = bounds

    def forward(self, wl: Union[torch.Tensor, float]) -> torch.Tensor:
        """
        Calculate the spectrum for given wavelengths.

        Args:
            wl (torch.Tensor or float): Wavelength in μm.

        Returns:
            torch.Tensor: Spectrum value at the given wavelengths; zero for
            wavelengths below the lower or above the upper bound.
        """
        if not torch.is_tensor(wl):
            wl = torch.tensor(wl)
        vmin, vmax = self.bounds
        out = self.func(wl)

        if isinstance(out, np.ndarray):
            out = torch.tensor(out, device=wl.device, dtype=wl.dtype)

        # Broadcast scalar results (Python numbers or 0-dim tensors) to the
        # shape of ``wl`` *before* masking, so that constant spectra are also
        # cut off at the bounds and the boolean mask always fits ``out``.
        if not torch.is_tensor(out) or out.dim() == 0:
            out = torch.ones_like(wl) * out

        outside = (wl < vmin) | (wl > vmax)
        if outside.any():
            # Work on a copy: ``func`` may hand back ``wl`` itself, a cached
            # tensor or an autograd leaf, none of which may be edited in place.
            out = out.clone()
            out[outside] = 0.0
        return out
class VisibleSunlight_am15g(Spectrum):
    """
    A class to represent the AM 1.5 G spectrum.

    This class uses the pvlib library to calculate the spectrum. The
    wavelength bounds passed to the base class are 0.360 and 0.780.
    """

    def __init__(self):
        # Imported here so that pvlib is only needed when this class is instantiated.
        from pvlib.spectrum import get_am15g
        # NOTE(review): newer pvlib releases appear to deprecate ``get_am15g``
        # in favour of ``get_reference_spectra`` - confirm against the pinned version.

        def func(wl):
            """Evaluate the AM 1.5 G reference spectrum at the wavelengths ``wl``."""
            device = wl.device
            dtype = wl.dtype
            shape = wl.shape
            # Hand pvlib a flat 1-D array so that scalar (0-dim) and
            # multi-dimensional wavelength tensors are supported as well;
            # the original shape is restored below.
            flat = wl.detach().cpu().numpy().reshape(-1)
            # Factor 1000 rescales the wavelengths before the pvlib lookup
            # (presumably μm -> nm; verify against the pvlib documentation).
            values = np.asarray(get_am15g(flat * 1000.))
            return torch.tensor(values, device=device, dtype=dtype).reshape(shape)

        super().__init__(func, (0.360, 0.780))