# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.
__all__ = [
"PlotableWavelength",
"add_colour_bar",
"plot"
]
import matplotlib.pyplot as plt
import numpy as np
import torch
from colour import XYZ_to_sRGB,wavelength_to_XYZ
from typing import Tuple,Optional,Union
[docs]
class PlotableWavelength:
"""
Represents a wavelength range and y-axis label for plotting spectral data.
Attributes:
bounds (tuple): Lower and upper bounds for the wavelength range.
ylabel (str): Label for the y-axis in plots.
"""
def __init__(self, bounds: Tuple[float, float], ylabel: str):
self.bounds = bounds
self.ylabel = ylabel
[docs]
def add_colour_bar(fig, ax, wl):
"""
Add a color strip below the plot to represent the wavelength spectrum.
Args:
fig (matplotlib.figure.Figure): The figure object.
ax (matplotlib.axes.Axes): The main axis of the plot.
wl (array-like): Wavelengths in µm.
"""
left, bottom, width, height = ax.get_position().bounds
color_ax = fig.add_axes([left, bottom - 0.15, width, 0.03]) # Position further below to avoid overlap
def wavelength_to_rgb(wl):
wl = wl*1000.
if 360.0 < wl and wl < 780.0:
rgb = XYZ_to_sRGB(wavelength_to_XYZ(wl))
return np.clip(rgb, 0.0, 1.0) # Ensure RGB values are within [0, 1]
else:
return (0.,0.,0.)
colors = [wavelength_to_rgb(_wl) for _wl in wl]
for i in range(len(wl) - 1):
color_ax.fill_between([wl[i], wl[i + 1]], 0, 1, color=colors[i])
color_ax.set_xlim(np.min(wl),np.max(wl))
color_ax.axis('off') # Hide axis for a clean color strip
#TODO change bmin and bmax to bounds
#refractive_index
[docs]
def plot(wl,vals=None,title="",xlabel="Wavelength [µm]",ylabel="y",labels=None,colour_bar=True,linewidth=2,legend=True,resolution=500,show=True):
"""
Plot a spectrum with a color strip below it.
Args:
wl (array-like): Wavelengths in nm or µm.
vals (array-like): Values of the spectrum at the given wavelengths.
title (str): Title of the plot.
xlabel (str): Label for the x-axis.
ylabel (str): Label for the y-axis.
labels (list): Labels for the different curves.
colour_bar (bool): Whether to show a color bar.
linewidth (int): Line width of the plot.
legend (bool): Whether to show a legend.
resolution (int): Resolution of the plot.
show (bool): Whether to show the plot.
Returns:
None
"""
if vals is None:
if not isinstance(wl,PlotableWavelength):
raise RuntimeError("if vals=None, wl must be a PlotableWavelength!")
plotable_func = wl
wl = np.linspace(*plotable_func.bounds,resolution)
vals = plotable_func(wl)
if ylabel=="y":
ylabel = plotable_func.ylabel
# Create figure and main axis
fig, ax = plt.subplots(figsize=(10, 5))
plt.subplots_adjust(bottom=0.3) # Increase space at the bottom
vals = np.array(vals)
wl = np.array(wl)
if (wl>100.).any():
print("wl is µm not nm! Setting wl to µm")
wl = wl/1000.
vmin = np.min(vals)
vmax = np.max(vals)
if len(vals.shape) == 1:
val = vals
vmin = np.min(val)
ax.plot(wl, val, color='black', linewidth=linewidth)
ax.fill_between(wl, val, color='gray', alpha=0.2)
else:
if vals.shape[1] != wl.shape[0]:
vals = vals.T
for i in range(len(vals)):
val = vals[i]
label = None
if labels is not None:
label = labels[i]
ax.plot(wl, val,label=label, linewidth=linewidth)
ax.set_xlim(np.min(wl),np.max(wl))
ax.set_ylim(vmin,vmax+(vmax-vmin)*0.1)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
if colour_bar:
add_colour_bar(fig, ax, wl)
if labels is not None:
if legend:
ax.legend(loc='upper right')
if show:
plt.show()
"""
import matplotlib.pyplot as plt
import numpy as np
import colour
import matplotlib.pyplot as plt
import numpy as np
import colour
# Define the wavelength range (in nm) and the spectrum curve (e.g., Gaussian example)
bmin = 300.
bmax = 3000.
wavelengths = np.linspace(bmin, bmax, 1000)
spectrum = np.exp(-((wavelengths - 550) / 40) ** 2) # Gaussian curve centered at 550 nm
# Function to map wavelength to RGB color
def wavelength_to_rgb(wavelength):
if 360.0 < wavelength and wavelength < 780.0:
rgb = colour.XYZ_to_sRGB(colour.wavelength_to_XYZ(wavelength))
return np.clip(rgb, 0.0, 1.0) # Ensure RGB values are within [0, 1]
else:
return (0.,0.,0.)
def add_color_strip(ax, bmin, bmax, resolution=1000):
wavelengths = np.linspace(bmin, bmax, resolution)
colors = [wavelength_to_rgb(wl) for wl in wavelengths]
# Create the color strip as a series of filled segments
for i in range(len(wavelengths) - 1):
ax.fill_between([wavelengths[i], wavelengths[i + 1]], 0, 1, color=colors[i])
ax.set_xlim(bmin, bmax)
ax.axis('off') # Hide axis for a clean color strip
# Create figure and main axis
fig, ax = plt.subplots(figsize=(10, 5))
plt.subplots_adjust(bottom=0.3) # Increase space at the bottom
# Plot the spectrum curve
ax.plot(wavelengths, spectrum, color='black', linewidth=2)
ax.fill_between(wavelengths, spectrum, color='gray', alpha=0.2)
# Add labels and limits for the main plot
ax.set_xlim(bmin, bmax)
ax.set_ylim(0, 1.1)
ax.set_xlabel("Wavelength (nm)")
ax.set_ylabel("Intensity (a.u.)")
ax.set_title("Spectrum with Corresponding Colors")
# Add an extra axis for the color strip below the main plot, with extra spacing
left, bottom, width, height = ax.get_position().bounds
color_ax = fig.add_axes([left, bottom - 0.15, width, 0.03]) # Position further below to avoid overlap
add_color_strip(color_ax, bmin, bmax, resolution=1000)
plt.show()
#%%
"""
"""
#%%
import matplotlib.pyplot as plt
import numpy as np
import colour
import matplotlib.pyplot as plt
import numpy as np
import colour
# Define the wavelength range (in nm) and the spectrum curve (e.g., Gaussian example)
wavelengths = np.linspace(380, 780, 1000)
spectrum = pvlib.spectrum.get_am15g(wavelengths)
spectrum = np.array(spectrum)
# Function to map wavelength to RGB color
def wavelength_to_rgb(wavelength):
rgb = colour.XYZ_to_sRGB(colour.wavelength_to_XYZ(wavelength))
return np.clip(rgb, 0.0, 1.0) # Ensure RGB values are within [0, 1]
colors = [wavelength_to_rgb(wl) for wl in wavelengths]
# Create the color strip as a series of filled segments
fig, ax = plt.subplots(figsize=(10, 5))
for i in range(len(wavelengths) - 1):
plt.fill_between([wavelengths[i], wavelengths[i + 1]], 0, spectrum[i+1], color=colors[i],interpolate=True)
# Plot the spectrum curve
plt.xlim(380, 750)
plt.ylim(0., 2.0)
plt.plot(wavelengths, spectrum, color='black')
plt.show()
"""