# Source code for gfn.gfn

import warnings

import numpy as np
import torch

from gfn.nn_lookup import NNLookupSciPy, NNLookupFaiss


class GFN(torch.nn.Linear):
    r"""The graph feedforward network (GFN) layer from `"GFN: A graph
    feedforward network for resolution-invariant reduced operator learning in
    multifidelity applications" <https://doi.org/10.1016/j.cma.2024.117458>`_.

    The layer is an extension of the standard :class:`torch.nn.Linear`, but
    with weights and biases optionally associated to original input
    :math:`\mathcal{M}^i_o` and output :math:`\mathcal{M}^o_o` graphs. GFN
    then defines new weights and biases for new input :math:`\mathcal{M}^i_n`
    and output :math:`\mathcal{M}^o_n` graphs, allowing for
    resolution-invariant learning.

    The general GFN equation is given by:

    .. math::
        :nowrap:

        \begin{equation*}
        \begin{aligned}
        \tilde{W}_{i_{{\mathcal{M}^o_{n}}}j_{{\mathcal{M}^i_{n}}}} &=
        \underset{\forall l_{\mathcal{M}^o_{o}} \text{ s.t. } l_{\mathcal{M}^o_{o}}
        {\leftarrow}\!{\backslash}\!{\rightarrow} i_{\mathcal{M}^o_{n}}}{\operatorname{mean}}
        \sum_{\forall k_{\mathcal{M}^i_{o}} \text{ s.t. } k_{\mathcal{M}^i_{o}}
        {\leftarrow}\!{\backslash}\!{\rightarrow} j_{\mathcal{M}^i_{n}}}
        \frac{W_{l_{\mathcal{M}^o_{o}}k_{\mathcal{M}^i_{o}}}}{\lvert \{ h_{\mathcal{M}^i_{n}}
        \text{ s.t. } k_{\mathcal{M}^i_{o}} ~{\leftarrow}\!{\backslash}\!{\rightarrow}~
        h_{\mathcal{M}^i_{n}} \}\rvert}, \\
        \tilde{b}^d_{i_{\mathcal{M}^o_{n}}} &=
        \underset{\forall k_{\mathcal{M}^o_{o}} \text{ s.t. } k_{\mathcal{M}^o_{o}}
        {\leftarrow}\!{\backslash}\!{\rightarrow} i_{\mathcal{M}^o_{n}}}{\operatorname{mean}}
        {b}^d_{k_{\mathcal{M}^o_{o}}}.
        \end{aligned}
        \end{equation*}

    where:

    - :math:`\mathcal{M}^i_{o}` is the original input graph,
    - :math:`\mathcal{M}^o_{o}` is the original output graph,
    - :math:`\mathcal{M}^i_{n}` is the new input graph,
    - :math:`\mathcal{M}^o_{n}` is the new output graph,
    - :math:`W` and :math:`b` are the weights and biases associated to the
      original graphs,
    - :math:`\tilde{W}` and :math:`\tilde{b}` are the new weights and biases
      associated to the new graphs,
    - :math:`i_{\mathcal{M}_1} {\leftarrow}\!{\backslash}\!{\rightarrow}
      j_{\mathcal{M}_2}` indicates that either node :math:`i` in graph
      :math:`\mathcal{M}_1` is the nearest neighbor of node :math:`j` in
      graph :math:`\mathcal{M}_2` or vice versa.

    GFN also supports mapping from graphs to vectors or vectors to graphs, in
    which case the above equation can be applied with
    :math:`\mathcal{M}^i_{o}=\mathcal{M}^i_{n}` for a vector input or
    :math:`\mathcal{M}^o_{o}=\mathcal{M}^o_{n}` for a vector output.

    Args:
        in_features (int or torch.Tensor): Either a tensor of shape
            :math:`(N_{\text{in}}, D_{\text{in}})` containing the coordinates
            of each original input node (graph input) or the size
            :math:`N_{\text{in}}` of each input (vector input).
        out_features (int or torch.Tensor): Either a tensor of shape
            :math:`(N_{\text{out}}, D_{\text{out}})` containing the
            coordinates of each original output node (graph output) or the
            size :math:`N_{\text{out}}` of each output (vector output).
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        device (torch.device, optional): The device on which the layer
            should be allocated. (default: :obj:`None`)
        dtype (torch.dtype, optional): The data type for the layer's
            parameters. (default: :obj:`None`)
        nn_backend (str, optional): The backend to use for nearest neighbor
            lookup. Can be either :obj:`"scipy"` or :obj:`"faiss"`.
            (default: :obj:`"scipy"`)

    Raises:
        ValueError: If :obj:`nn_backend` is neither :obj:`"scipy"` nor
            :obj:`"faiss"`.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias: bool = True,
        device=None,
        dtype=None,
        nn_backend: str = "scipy",
    ):
        in_is_vector = isinstance(in_features, int)
        out_is_vector = isinstance(out_features, int)

        if in_is_vector and out_is_vector:
            warnings.warn(
                "No graphical data provided to GFN layer. "
                "Behaves like a standard Linear layer.",
                stacklevel=2,
            )

        # Validate the backend before building the Linear parameters so an
        # invalid name fails fast.
        if nn_backend == "scipy":
            lookup = NNLookupSciPy
        elif nn_backend == "faiss":
            lookup = NNLookupFaiss
        else:
            raise ValueError(f"Unknown nn_backend: {nn_backend}")

        # For a graph, the layer width is the number of nodes (rows).
        in_features_size = in_features if in_is_vector else in_features.shape[0]
        out_features_size = (
            out_features if out_is_vector else out_features.shape[0]
        )
        super().__init__(in_features_size, out_features_size, bias, device, dtype)

        # Assigned after Module.__init__ so the module state already exists.
        self.lookup = lookup

        if in_is_vector:
            self.in_tree = None
            self.in_graph = None
        else:
            self.in_tree = self.lookup(in_features, device=device, dtype=dtype)
            self.in_graph = in_features

        if out_is_vector:
            self.out_tree = None
            self.out_graph = None
        else:
            self.out_tree = self.lookup(out_features, device=device, dtype=dtype)
            self.out_graph = out_features

    def forward(self, x, in_graph=None, out_graph=None) -> torch.Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input tensor of shape
                :math:`(..., N_{\text{in}}^{\prime})`.
            in_graph (torch.Tensor, optional): The input graph matching the
                shape of the input tensor
                :math:`(N_{\text{in}}^{\prime}, D_{\text{in}})`. If
                :obj:`None`, treats as vector input i.e. assumes no change
                from the original input graph. (default: :obj:`None`)
            out_graph (torch.Tensor, optional): The output graph of shape
                :math:`(N_{\text{out}}^{\prime}, D_{\text{out}})`. If
                :obj:`None`, treats as vector output i.e. assumes no change
                to the original output graph. (default: :obj:`None`)

        Returns:
            torch.Tensor: The output tensor. Matches the shape of the new
            output graph if provided.

        Raises:
            ValueError: If a graph is passed for a side (input or output)
                that was initialized without graphical data.
        """
        if in_graph is None and out_graph is None:
            return super().forward(x)
        if in_graph is not None and self.in_tree is None:
            raise ValueError(
                "Input graphical data provided but GFN layer was not initialized with input graphical data."
            )
        if out_graph is not None and self.out_tree is None:
            raise ValueError(
                "Output graphical data provided but GFN layer was not initialized with output graphical data."
            )

        device = x.device
        weight = self.weight
        bias = self.bias

        # -- ENCODER-style: re-index the weight columns onto the new inputs --
        if in_graph is not None:
            weight = self._remap_inputs(weight, in_graph, device)

        # -- DECODER-style: re-index the weight rows / bias onto new outputs --
        if out_graph is not None:
            weight, bias = self._remap_outputs(weight, bias, out_graph, device)

        out = x @ weight.T
        return out if bias is None else out + bias

    def _match_nodes(self, orig_tree, orig_graph, new_graph):
        """Compute nearest-neighbour links between an original and a new graph.

        Returns a tuple ``(new_to_orig, orig_to_new, unreciprocated)`` where
        ``new_to_orig[j]`` is the original node nearest to new node ``j``,
        ``orig_to_new[k]`` is the new node nearest to original node ``k``,
        and ``unreciprocated`` holds the original nodes ``k`` whose nearest
        new node does not point back to ``k``.

        NOTE(review): assumes ``query`` returns NumPy integer index arrays,
        as implied by the NumPy indexing below -- confirm against the
        ``gfn.nn_lookup`` backends.
        """
        with torch.no_grad():
            new_tree = self.lookup(new_graph)
            new_to_orig = orig_tree.query(new_graph)
            orig_to_new = new_tree.query(orig_graph)

        orig_ids = np.arange(orig_graph.shape[0])
        unreciprocated = orig_ids[orig_ids != new_to_orig[orig_to_new]]
        return new_to_orig, orig_to_new, unreciprocated

    def _remap_inputs(self, weight, new_in_graph, device):
        """Build the weight matrix whose columns correspond to ``new_in_graph``."""
        new_to_orig, orig_to_new, unreciprocated = self._match_nodes(
            self.in_tree, self.in_graph, new_in_graph
        )
        has_extra = unreciprocated.shape[0] > 0

        with torch.no_grad():
            gather = torch.as_tensor(new_to_orig, device=device)
            # Number of new nodes each original node is linked to: those that
            # picked it as nearest neighbour, plus one for the node it picked
            # itself when that link is not already counted.
            denominator = torch.bincount(
                gather, minlength=self.in_graph.shape[0]
            )
            if has_extra:
                extra_src = torch.as_tensor(unreciprocated, device=device)
                extra_dst = torch.as_tensor(
                    orig_to_new[unreciprocated], device=device
                )
                denominator.index_add_(0, extra_src, torch.ones_like(extra_src))

        # Split every original column evenly between the new nodes it feeds.
        scaled_weight = weight / denominator
        new_weight = scaled_weight[..., gather]
        if has_extra:
            new_weight = new_weight.index_add_(
                1, extra_dst, scaled_weight[..., extra_src]
            )
        return new_weight

    def _remap_outputs(self, weight, bias, new_out_graph, device):
        """Build the weight rows and bias entries for ``new_out_graph``.

        Each new output node receives the mean of the rows (and bias entries)
        of all original output nodes linked to it. ``bias`` may be ``None``,
        in which case the returned bias is ``None`` as well.
        """
        new_to_orig, orig_to_new, unreciprocated = self._match_nodes(
            self.out_tree, self.out_graph, new_out_graph
        )

        gather = torch.as_tensor(new_to_orig, device=device)
        new_weight = weight[gather]
        new_bias = bias[gather] if bias is not None else None

        # Every new node has at least its own nearest original node.
        denominator = torch.ones(
            new_out_graph.shape[0], device=device, dtype=torch.long
        )

        if unreciprocated.shape[0] > 0:
            extra_src = torch.as_tensor(unreciprocated, device=device)
            extra_dst = torch.as_tensor(
                orig_to_new[unreciprocated], device=device
            )
            with torch.no_grad():
                denominator.index_add_(0, extra_dst, torch.ones_like(extra_dst))

            new_weight = new_weight.index_add_(0, extra_dst, weight[extra_src])
            if new_bias is not None:
                new_bias = new_bias.index_add_(0, extra_dst, bias[extra_src])

        new_weight = new_weight / denominator.unsqueeze(1)
        # Guarded so that bias-free layers do not attempt ``None / tensor``.
        if new_bias is not None:
            new_bias = new_bias / denominator
        return new_weight, new_bias