# Source code for stgraph.nn.pytorch.static.gcn_conv

"""Graph Convolutional Network Layer."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    from stgraph.compiler.node import CentralNode
    from stgraph.graph import StaticGraph

import torch
from torch import Tensor, nn

from stgraph.compiler import STGraph
from stgraph.compiler.backend.pytorch.torch_callback import STGraphBackendTorch
from stgraph.utils.constants import SizeConstants


class GCNConv(nn.Module):
    r"""Graph Convolutional Network Layer.

    Vertex-centric implementation for Graph Convolutional Network (GCN) layer
    as described in `Semi-supervised Classification with Graph Convolutional
    Networks <https://arxiv.org/abs/1609.02907>`_.

    A multi-layer GCN model has the following layer-wise propagation rule

    .. math::

        H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}
        H^{(l)} W^{(l)} \right)

    - :math:`H^{(l)}`: Matrix of activations in the :math:`l`-th layer;
      :math:`H^{(0)} = X` is the input feature matrix.
    - :math:`\sigma`: Activation function (e.g., ReLU).
    - :math:`\tilde{A} = A + I_N`: Adjacency matrix of the graph with added
      self-connections.
    - :math:`I_N`: Identity matrix.
    - :math:`\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}`: Degree matrix of
      :math:`\tilde{A}`.
    - :math:`W^{(l)}`: Trainable weight matrix for the :math:`l`-th layer.

    **Vertex-Centric Formula**

    The vertex-centric implementation can be achieved by aggregating all the
    features of the neighbouring nodes of the central node

    .. math::

        h^{(l+1)} = \left( \sum_{\text{nb} \in \text{innbs}(v)}
        \text{nb}_{h^{(l)}} \cdot \text{nb}_{\text{norm}} \cdot
        \text{weight}_{\text{nb,v}} \right) \cdot v_{\text{norm}}

    - :math:`h^{(l)}`: Activations of central-node in the :math:`l`-th layer.
      In the implementation these features have already been multiplied by
      the layer weight :math:`W^{(l)}` before aggregation; the bias and
      activation are applied after aggregation.
    - :math:`\text{innbs}(v)`: In-neighbours of central-node :math:`v`.
    - :math:`\text{weight}_{\text{nb,v}}`: Weight of edge from :math:`nb` to
      :math:`v`. In case no edge weights are present, it is set to 1.
    - :math:`norm`: Node wise normalization factor,
      :math:`v_{\text{norm}} = \text{in\_degrees}(v)^{-0.5}`.

    **Node Data**

    The following node data needs to be set using
    :class:`StaticGraph.set_ndata <stgraph.graph.static.static_graph.StaticGraph>`
    before calling the
    :func:`~stgraph.nn.pytorch.static.gcn_conv.GCNConv.forward` method.

    =============  ==============================  ==========================================
    Node Property  Description                     Type
    =============  ==============================  ==========================================
    norm           Node-wise normalization factor  Tensor of shape ``(num_nodes, 1)`` holding
                                                   each node's normalization factor
    =============  ==============================  ==========================================

    Parameters
    ----------
    in_channels : int
        Size of each input sample passed into the layer.
    out_channels : int
        Size of each output sample produced by the layer.
    activation : Callable[..., Tensor], optional
        Non-linear activation function provided by `PyTorch
        <https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity>`_.
        Applied to the layer output when set. Default is ``None``.
    bias : bool, optional
        If set to *True*, learnable bias parameters are added to the layer.
        Default is *True*.
    """

    def __init__(
        self: GCNConv,
        in_channels: int,
        out_channels: int,
        activation: Callable[..., torch.Tensor] | None = None,
        bias: bool = True,
    ) -> None:
        """Create the weight/bias parameters and the STGraph executor."""
        super().__init__()
        # Allocated uninitialized on purpose: ``reset_parameters`` below
        # overwrites every value.
        self.weight = nn.Parameter(torch.empty(in_channels, out_channels))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None
        self.activation = activation
        self.stgraph = STGraph(STGraphBackendTorch())
        self.reset_parameters()

    def reset_parameters(self: GCNConv) -> None:
        r"""Reset the learnable weight and bias parameters.

        The weight parameter is initialized using a Xavier Uniform
        distribution. The bias parameter, if present, is initialized by
        setting all values to zero.
        """
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(
        self: GCNConv,
        graph: StaticGraph,
        h: Tensor,
        edge_weight: Tensor | None = None,
    ) -> Tensor:
        r"""Execute a single forward pass for the GCN layer.

        Runs a single forward pass using the vertex-centric implementation of
        the GCN layer. The input features are first multiplied by the layer
        weight, then aggregated over the in-neighbours (or in-edges, when
        ``edge_weight`` is given) of every node and scaled by the ``norm``
        node data. The bias and activation, if any, are applied last.

        Parameters
        ----------
        graph : StaticGraph
            A StaticGraph graph object. It must carry ``norm`` node data of
            shape ``(num_nodes, 1)``.
        h : Tensor
            Input for the GCN forward pass. A 2-D tensor whose second
            dimension is ``in_channels`` (required by ``torch.mm``); the first
            dimension is presumably ``num_nodes``.
        edge_weight : Tensor, optional
            Edge weights for each edge in the graph. When omitted every edge
            contributes with weight 1.

        Returns
        -------
        Tensor
            The output after executing the GCN forward pass.

        Raises
        ------
        KeyError
            If ``norm`` node data is not present for the graph.
        ValueError
            If ``norm`` node data passed is not of the shape ``(num_nodes, 1)``.

        Example
        -------
        Example usage::

            # Defining a method to run forward pass with multiple GCN layers
            def forward(input: Tensor, layers: List[GCNConv], graph: StaticGraph):
                h = input
                for layer in layers:
                    h = layer.forward(graph, h)
                return h
        """
        # Look the normalization data up once; it is needed for validation
        # and again as a node feature for the compiled kernel.
        norm = graph.get_ndata("norm")
        if norm is None:
            raise KeyError(
                "StaticGraph passed to GCNConv forward pass does not contain "
                "'norm' node data",
            )
        # NODE_NORM_SIZE is the expected rank of ``norm`` (presumably 2).
        # ``or`` short-circuits, so ``shape[1]`` is only read for valid ranks.
        if (
            len(norm.shape) != SizeConstants.NODE_NORM_SIZE.value
            or norm.shape[1] != 1
            or norm.shape[0] != graph.get_num_nodes()
        ):
            raise ValueError(
                "Node data 'norm' passed to GCNConv should be of shape "
                "(num_nodes, 1)",
            )

        # Feature transform (H W) happens before message passing.
        h = torch.mm(h, self.weight)

        # The nb_compute bodies below are handed to the STGraph compiler, so
        # they are kept in exactly the original vertex-centric form.
        if edge_weight is None:

            @self.stgraph.compile(gnn_module=self)
            def nb_compute(v: CentralNode) -> Tensor:
                return sum([nb.h * nb.norm for nb in v.innbs]) * v.norm

            h = nb_compute(g=graph, n_feats={"norm": norm, "h": h})
        else:

            @self.stgraph.compile(gnn_module=self)
            def nb_compute(v: CentralNode) -> Tensor:
                return sum(
                    [
                        nb_edge.src.norm * nb_edge.src.h * nb_edge.edge_weight
                        for nb_edge in v.inedges
                    ],
                ) * v.norm

            h = nb_compute(
                g=graph,
                n_feats={"norm": norm, "h": h},
                e_feats={"edge_weight": edge_weight},
            )

        # bias
        if self.bias is not None:
            h = h + self.bias

        if self.activation:
            h = self.activation(h)

        return h