# Source code for stgraph.nn.pytorch.temporal.tgcn

import torch
from stgraph.nn.pytorch.static.gcn_conv import GCNConv

class TGCN(torch.nn.Module):
    """Gated recurrent cell whose gate inputs are produced by ``GCNConv`` layers.

    One call to :meth:`forward` performs a single recurrent step:

    1. ``Z = sigmoid(linear_z(cat(conv_z(g, X), H)))``          (update gate)
    2. ``R = sigmoid(linear_r(cat(conv_r(g, X), H)))``          (reset gate)
    3. ``H_tilde = tanh(linear_h(cat(conv_h(g, X), H * R)))``   (candidate)
    4. ``H_new = Z * H + (1 - Z) * H_tilde``

    Every ``conv_*`` output is clamped to ``[-1e6, 1e6]`` before it is
    concatenated with the hidden state.

    Args:
        in_channels: Feature size handed to each ``GCNConv`` as its input size.
        out_channels: Feature size handed to each ``GCNConv`` as its output
            size; also the number of columns of the hidden state.

    Note:
        Each ``Linear`` layer expects ``2 * out_channels`` input columns, so
        the ``GCNConv`` output and ``H`` must together provide that many
        columns along ``dim=1``. ``GCNConv`` is defined elsewhere in the
        package; it is assumed here to return a 2-D tensor with
        ``out_channels`` columns -- confirm against its implementation.
    """

    # Symmetric bound applied to every GCNConv output before gating, to keep
    # extreme values out of the sigmoid/tanh inputs.
    _CLAMP_LIMIT = 1e6

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Update gate (Z).
        self.conv_z = GCNConv(self.in_channels, self.out_channels, activation=None)
        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

        # Reset gate (R).
        self.conv_r = GCNConv(self.in_channels, self.out_channels, activation=None)
        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

        # Candidate hidden state (H_tilde).
        self.conv_h = GCNConv(self.in_channels, self.out_channels, activation=None)
        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _set_hidden_state(self, X, H):
        """Return ``H`` unchanged, or a zero state when ``H`` is ``None``.

        The zero state has shape ``(X.shape[0], out_channels)`` and is
        allocated directly on ``X.device`` (no intermediate CPU tensor).
        """
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels, device=X.device)
        return H

    def _clamped_conv(self, conv, g, X, edge_weight):
        """Apply ``conv`` to ``(g, X)`` and clamp the result to the safe range."""
        h = conv(g, X, edge_weight=edge_weight)
        return torch.clamp(h, min=-self._CLAMP_LIMIT, max=self._CLAMP_LIMIT)

    def _calculate_update_gate(self, g, X, edge_weight, H):
        """Compute the update gate ``Z`` (values in ``(0, 1)`` via sigmoid)."""
        h = self._clamped_conv(self.conv_z, g, X, edge_weight)
        Z = self.linear_z(torch.cat((h, H), dim=1))
        return torch.sigmoid(Z)

    def _calculate_reset_gate(self, g, X, edge_weight, H):
        """Compute the reset gate ``R`` (values in ``(0, 1)`` via sigmoid)."""
        h = self._clamped_conv(self.conv_r, g, X, edge_weight)
        R = self.linear_r(torch.cat((h, H), dim=1))
        return torch.sigmoid(R)

    def _calculate_candidate_state(self, g, X, edge_weight, H, R):
        """Compute the candidate state ``H_tilde`` from ``X`` and ``H * R``."""
        h = self._clamped_conv(self.conv_h, g, X, edge_weight)
        # The reset gate scales the previous state element-wise before it is
        # mixed with the convolved input.
        H_tilde = self.linear_h(torch.cat((h, H * R), dim=1))
        return torch.tanh(H_tilde)

    def _calculate_hidden_state(self, Z, H, H_tilde):
        """Blend previous and candidate states: ``Z * H + (1 - Z) * H_tilde``."""
        return Z * H + (1 - Z) * H_tilde

    def forward(self, g, X, edge_weight=None, H=None):
        """Run one recurrent step and return the new hidden state.

        Args:
            g: Graph object passed through unchanged to every ``GCNConv``.
            X: Input feature tensor; ``X.shape[0]`` sets the number of rows of
                the default hidden state.
            edge_weight: Optional edge weights forwarded to every ``GCNConv``.
            H: Previous hidden state. When ``None``, a zero tensor of shape
                ``(X.shape[0], out_channels)`` on ``X.device`` is used.

        Returns:
            The updated hidden state ``Z * H + (1 - Z) * H_tilde``.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(g, X, edge_weight, H)
        R = self._calculate_reset_gate(g, X, edge_weight, H)
        H_tilde = self._calculate_candidate_state(g, X, edge_weight, H, R)
        return self._calculate_hidden_state(Z, H, H_tilde)