Source code for diffinytrace.utils.autograd

# Copyright (c) 2025 Martin Pflaum
# This file is part of the diffinytrace project, licensed under the MIT License.


__all__ = [
    "grad"
]

import typing
import torch


[docs] def grad( outputs: torch.types._TensorOrTensors, inputs: torch.types._TensorOrTensorsOrGradEdge, grad_outputs: typing.Optional[torch.types._TensorOrTensors] = None, retain_graph: typing.Optional[bool] = True, create_graph: bool = True, only_inputs: bool = True, is_grads_batched: bool = False, materialize_grads: bool = False, remove_no_grad_outputs: bool = True ): """ Computes the gradients of the outputs with respect to the inputs. Args: outputs (torch.Tensor or tuple of torch.Tensor): The output tensors. inputs (torch.Tensor or tuple of torch.Tensor): The input tensors. grad_outputs (torch.Tensor or tuple of torch.Tensor, optional): The gradients of the outputs. retain_graph (bool, optional): Whether to retain the graph after computing gradients. create_graph (bool, optional): Whether to create the graph for higher-order gradients. only_inputs (bool, optional): Whether to only compute gradients for the inputs. is_grads_batched (bool, optional): Whether the gradients are batched. materialize_grads (bool, optional): Whether to materialize the gradients. remove_no_grad_outputs (bool, optional): Whether to remove outputs that do not require gradients. Returns: list: A list of gradients for each input tensor. """ if torch.is_tensor(inputs): inputs = [inputs] inputs = [elem for elem in inputs] if remove_no_grad_outputs: if torch.is_tensor(grad_outputs) or torch.is_tensor(outputs): if torch.is_tensor(outputs): if not outputs.requires_grad: if torch.is_tensor(inputs): raise RuntimeError("this branch should not be called!") else: out = [] for elem in inputs: if torch.is_tensor(elem): out += [torch.zeros_like(elem)] else: out += [None] return out else: _grad_outputs = [elem for elem in grad_outputs] new_grad_outputs = [] new_outputs = [] for k,elem in enumerate(outputs): if torch.is_tensor(elem): if elem.requires_grad: grad_elem = _grad_outputs[k] new_outputs += [elem] new_grad_outputs += [grad_elem] grad_outputs = new_grad_outputs outputs = new_outputs inputs_requires_grad = [] back_map_input = {} inputs_map_i = 0 for k,param in enumerate(inputs): #param = inputs[k] if param is None: continue if param.requires_grad: inputs_requires_grad += [param] back_map_input[k] = inputs_map_i inputs_map_i += 1 grad_tmp = [] if len(inputs_requires_grad)!=0: grad_tmp = torch.autograd.grad(outputs=outputs, inputs=inputs_requires_grad, grad_outputs=grad_outputs, retain_graph=retain_graph, create_graph=create_graph, only_inputs=only_inputs, allow_unused=True, is_grads_batched=is_grads_batched, materialize_grads=materialize_grads) else: pass grad = [None for input in inputs] for k in range(len(grad)): if k in back_map_input.keys(): inputs_map_i = back_map_input[k] grad[k] = grad_tmp[inputs_map_i] else: if materialize_grads: if inputs[k] is None: grad[k] = None else: grad[k] = torch.zeros_like(inputs[k]) else: grad[k] = None return grad