# Source code for stgraph.graph.static.static_graph

"""Represent Static graphs in STGraph."""

from __future__ import annotations

import copy

import numpy as np
from rich.console import Console

from stgraph.graph.static.csr import CSR
from stgraph.graph.stgraph_base import STGraphBase

# Module-level rich console instance. It is not referenced by any of the code
# visible in this module.
console = Console()


class StaticGraph(STGraphBase):
    r"""Represent Static graphs in STGraph.

    This class implements a static graph for use in STGraph. As of now the
    static graph is stored using the Compressed Sparse Row (CSR) format: one
    CSR structure for the forward pass and one for the backward pass.

    Example:
    -------
    .. code-block:: python

        from stgraph.graph import StaticGraph
        from stgraph.dataset import HungaryCPDataLoader

        hungary = HungaryCPDataLoader()

        graph = StaticGraph(
            edge_list = hungary.get_edges(),
            edge_weights = hungary.get_edge_weights(),
            num_nodes = hungary.gdata["num_nodes"]
        )

    """

    def __init__(
        self: StaticGraph,
        edge_list: list,
        edge_weights: list,
        num_nodes: int,
    ) -> None:
        r"""Build the forward and backward CSR structures of the graph.

        Parameters
        ----------
        edge_list : list
            Edges of the graph. Each edge must be hashable and indexable at
            positions 0 and 1 (presumably ``(source, destination)`` tuples;
            verify against the dataset loaders). The list is not modified.
        edge_weights : list
            Edge weights, forwarded unchanged to both ``CSR`` constructors.
        num_nodes : int
            Number of nodes in the graph.

        """
        super().__init__()

        self._num_nodes = num_nodes
        # Duplicate edges are counted once here, although they are kept in
        # the edge lists handed to the CSR constructors below.
        self._num_edges = len(set(edge_list))

        self._prepare_edge_lst_fwd(edge_list)
        self._forward_graph = CSR(
            self.fwd_edge_list,
            edge_weights,
            self._num_nodes,
            is_edge_reverse=True,
        )

        # The backward list is derived from the forward list so that both
        # share the same edge ids.
        self._prepare_edge_lst_bwd(self.fwd_edge_list)
        self._backward_graph = CSR(self.bwd_edge_list, edge_weights, self._num_nodes)

        self._get_graph_csr_ptrs()

    def _prepare_edge_lst_fwd(self: StaticGraph, edge_list: list) -> None:
        r"""Create the edge list used to build the forward CSR.

        The edges are ordered by their second endpoint and then by their
        first endpoint. Each edge is stored in ``self.fwd_edge_list`` as
        ``(edge[0], edge[1], eid)``, where ``eid`` is the position of the
        edge in that ordering.

        Parameters
        ----------
        edge_list : list
            Edges of the graph. A sorted copy is made, so the caller's list
            keeps its original order.

        """
        # sorted() is stable, like list.sort(), but leaves the input untouched.
        ordered_edges = sorted(edge_list, key=lambda edge: (edge[1], edge[0]))
        self.fwd_edge_list = [
            (edge[0], edge[1], eid) for eid, edge in enumerate(ordered_edges)
        ]

    def _prepare_edge_lst_bwd(self: StaticGraph, edge_list: list) -> None:
        r"""Create the edge list used to build the backward CSR.

        A deep copy of ``edge_list`` is sorted in its natural order and stored
        in ``self.bwd_edge_list``.

        Parameters
        ----------
        edge_list : list
            Edges to copy and sort; ``__init__`` passes ``self.fwd_edge_list``.

        """
        self.bwd_edge_list = sorted(copy.deepcopy(edge_list))

    def _get_graph_csr_ptrs(self: StaticGraph) -> None:
        r"""Expose the pointer attributes of both CSRs on the graph object.

        Copies ``row_offset_ptr``, ``column_indices_ptr``, ``eids_ptr`` and
        ``node_ids_ptr`` of the forward and backward CSR into attributes
        prefixed with ``fwd_`` and ``bwd_`` respectively.
        """
        forward = self._forward_graph
        self.fwd_row_offset_ptr = forward.row_offset_ptr
        self.fwd_column_indices_ptr = forward.column_indices_ptr
        self.fwd_eids_ptr = forward.eids_ptr
        self.fwd_node_ids_ptr = forward.node_ids_ptr

        backward = self._backward_graph
        self.bwd_row_offset_ptr = backward.row_offset_ptr
        self.bwd_column_indices_ptr = backward.column_indices_ptr
        self.bwd_eids_ptr = backward.eids_ptr
        self.bwd_node_ids_ptr = backward.node_ids_ptr

    def get_num_nodes(self: StaticGraph) -> int:
        r"""Return the number of nodes in the static graph."""
        return self._num_nodes

    def get_num_edges(self: StaticGraph) -> int:
        r"""Return the number of edges in the static graph.

        This is the number of distinct edges passed to the constructor.
        """
        return self._num_edges

    def get_ndata(self: StaticGraph, field: object) -> object:
        r"""Return the graph metadata.

        Parameters
        ----------
        field : object
            Key of the metadata entry to look up.

        Returns
        -------
        object
            The value stored under ``field``, or ``None`` when ``field`` is
            not present.

        """
        if field in self._ndata:
            return self._ndata[field]
        return None

    def set_ndata(self: StaticGraph, field: str, val: object) -> None:
        r"""Set the graph metadata.

        Parameters
        ----------
        field : str
            Key under which the value is stored.
        val : object
            Value to store.

        """
        self._ndata[field] = val

    def graph_type(self: StaticGraph) -> str:
        r"""Return the graph type."""
        return "csr_unsorted"

    def in_degrees(self: StaticGraph) -> np.ndarray:
        r"""Return the graph inwards node degree array.

        The values are read from ``out_degrees`` of the forward CSR, which is
        built with ``is_edge_reverse=True``.
        """
        return np.array(self._forward_graph.out_degrees, dtype="int32")

    def out_degrees(self: StaticGraph) -> np.ndarray:
        r"""Return the graph outwards node degree array.

        The values are read from ``in_degrees`` of the forward CSR, which is
        built with ``is_edge_reverse=True``.
        """
        return np.array(self._forward_graph.in_degrees, dtype="int32")

    def weighted_in_degrees(self: StaticGraph) -> np.ndarray:
        r"""Return the graph weighted inwards node degree array.

        The values are read from ``weighted_out_degrees`` of the forward CSR,
        which is built with ``is_edge_reverse=True``.
        """
        # NOTE(review): the int32 cast drops any fractional part if the edge
        # weights are not integers - confirm this is intended.
        return np.array(self._forward_graph.weighted_out_degrees, dtype="int32")