Graph Operator Models API

deepuq.models.graph_operator contains grid-as-graph neural operators for scientific fields. The public GraphNeuralOperator2D accepts channels-last regular-grid tensors, converts them to a local graph internally, and ends in a final nn.Linear head, which keeps the model compatible with last-layer Laplace approximation.

Typical shape contract:

  • input: [batch, height, width, channels]
  • output: [batch, height, width, out_channels]

The model appends normalized coordinates (x, y) to each node feature and uses radius-based local neighborhoods to pass messages.
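
The coordinate features and local neighborhoods described above can be pictured with a small pure-Python sketch. It is not the library's implementation: the helper names grid_coordinates and radius_edges, the [0, 1] coordinate range, the square neighborhood, and the non-periodic boundary handling are all assumptions made for illustration.

```python
from itertools import product


def grid_coordinates(height, width):
    """Return one normalized (x, y) pair per node, in row-major order.

    Assumption: coordinates are scaled to [0, 1]; the library may use a
    different range.
    """
    def scale(index, size):
        return index / (size - 1) if size > 1 else 0.0

    return [(scale(j, width), scale(i, height))
            for i, j in product(range(height), range(width))]


def radius_edges(height, width, radius=1):
    """Return directed (source, destination) node-index pairs.

    Assumption: a node is linked to every other node at most `radius` steps
    away along each axis, with no wrap-around at the grid boundary.
    """
    edges = []
    for i, j in product(range(height), range(width)):
        for di, dj in product(range(-radius, radius + 1), repeat=2):
            ni, nj = i + di, j + dj
            if (di, dj) != (0, 0) and 0 <= ni < height and 0 <= nj < width:
                edges.append((i * width + j, ni * width + nj))
    return edges


coords = grid_coordinates(3, 3)
edges = radius_edges(3, 3, radius=1)
# Node 4 is the center of the 3 x 3 grid, so it has a full 8-neighbor stencil.
center_degree = sum(1 for src, _ in edges if src == 4)
```

Under these assumptions, boundary nodes simply have fewer neighbors than interior nodes; a periodic field would need wrap-around indexing instead.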

deepuq.models.graph_operator

Graph-based neural operators for regular-grid scientific fields.

The models in this module treat a Cartesian lattice as a graph and perform message passing directly on node embeddings. They are intended as pragmatic operator-learning baselines for settings where a graph inductive bias is useful or where a future migration to unstructured meshes is anticipated.

GraphNeuralOperator2D

Bases: Module

Message-passing neural operator on a 2D grid treated as a graph.

Parameters:

  • in_channels (int, required): Number of input field channels. In the Gray-Scott notebook this is 2 for the species fields A and B.
  • hidden_dim (int, default 64): Node embedding width.
  • message_dim (int, default 64): Width of the edge-message hidden state.
  • n_message_passing_steps (int, default 4): Number of residual message-passing blocks.
  • out_channels (int, default 2): Number of predicted output channels.
  • radius (int, default 1): Grid-neighborhood radius. 1 corresponds to an 8-neighbor stencil.
  • use_edge_mlp (bool, default True): If True, use a learned edge MLP that mixes source, destination, and edge features. Otherwise use a lighter linear message path.

Notes

The public forward method accepts channels-last grid tensors with shape [batch, H, W, C] and returns channels-last outputs with shape [batch, H, W, out_channels]. The final decoder ends in nn.Linear, which keeps the model compatible with last-layer Laplace approximation.
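
The radius and n_message_passing_steps parameters together determine how far information can travel across the grid. The arithmetic below is a rough sketch, assuming a square stencil (consistent with radius 1 giving 8 neighbors), one round of message passing per block, and an encoder and decoder that act on each node independently. The function names are hypothetical and not part of deepuq.

```python
def stencil_size(radius):
    """Neighbors of an interior node for a square stencil (self excluded)."""
    return (2 * radius + 1) ** 2 - 1


def receptive_field_width(radius, n_steps):
    """Side length, in grid cells, of the window that can influence one node.

    Each message-passing step extends the reach by `radius` cells in every
    direction, so n_steps steps reach radius * n_steps cells from the node.
    """
    return 2 * radius * n_steps + 1


# Defaults from the parameter list: radius=1, n_message_passing_steps=4.
neighbors_default = stencil_size(1)
window_default = receptive_field_width(1, 4)
```

With the defaults, each interior node exchanges messages with 8 neighbors and, under the stated assumptions, its prediction depends on a 9 x 9 window of the input. Fields with longer-range interactions would call for a larger radius or more blocks.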

forward

forward(x: Tensor) -> torch.Tensor

Map a channels-last 2D input field to a channels-last output field.
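
The shape contract of forward can be written down as a small check. This is a hypothetical helper for illustration only; whether the library itself validates shapes this way is not stated here.

```python
def expected_output_shape(input_shape, in_channels, out_channels=2):
    """Check a channels-last input shape and return the matching output shape.

    Hypothetical helper for illustration; it is not part of deepuq.
    """
    if len(input_shape) != 4:
        raise ValueError(
            f"expected [batch, H, W, C], got {len(input_shape)} dimensions"
        )
    batch, height, width, channels = input_shape
    if channels != in_channels:
        raise ValueError(
            f"expected {in_channels} channels on the last axis, got {channels}"
        )
    # Only the trailing channel axis changes; batch and grid axes pass through.
    return (batch, height, width, out_channels)


# Gray-Scott-style input: 8 samples on a 64 x 64 grid with fields A and B.
output_shape = expected_output_shape((8, 64, 64, 2), in_channels=2)
```

A check like this catches the common mistake of passing a channels-first tensor such as [8, 2, 64, 64], which would otherwise be read as a 2 x 64 grid with 64 channels.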