HEALPix Neural Networks#

FOSCAT provides neural-network architectures designed to operate directly on HEALPix geometry using oriented spherical convolutions as the spatial primitive. Unlike standard convolutions applied to equirectangular projections, FOSCAT’s convolutions are defined on the sphere and do not introduce projection distortions.

This workflow corresponds to CNN_local.ipynb and CNN_ecmwf.ipynb in demo-foscat-pangeo-eosc.


Spatial primitive: oriented spherical convolution#

All neural-network modules in FOSCAT use SphericalStencil.Convol_torch as the convolution primitive. For each HEALPix pixel \(p\), a \(K \times K\) stencil of neighbour pixel indices is looked up from a precomputed table; the convolution is then a dot product between the kernel weights and the neighbour values.
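
The gather-then-dot-product idea can be illustrated with a minimal NumPy sketch. It does not call FOSCAT: the neighbour table below is built for a small periodic 4 × 4 grid rather than a HEALPix map, and the names `stencil_convolve`, `neighbours`, and `kernel` are hypothetical, chosen only for this illustration.

```python
import numpy as np

def stencil_convolve(values, neighbours, kernel):
    """Gather-based convolution.

    values:     (N,)      one value per pixel
    neighbours: (N, K*K)  precomputed indices into `values`
    kernel:     (K*K,)    one weight per stencil position
    """
    gathered = values[neighbours]      # (N, K*K): neighbour values per pixel
    return gathered @ kernel           # (N,): dot product with the weights

# Toy stand-in for a HEALPix stencil table: 3x3 neighbours on a periodic grid.
K, side = 3, 4
rows, cols = np.divmod(np.arange(side * side), side)
offsets = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1)]
neighbours = np.stack(
    [((rows + dr) % side) * side + (cols + dc) % side for dr, dc in offsets],
    axis=1,
)                                      # shape (16, 9)

values = np.arange(side * side, dtype=float)

# A kernel that is 1 at the centre position leaves the map unchanged.
identity_kernel = np.zeros(K * K)
identity_kernel[4] = 1.0               # offsets[4] == (0, 0)
unchanged = stencil_convolve(values, neighbours, identity_kernel)

# A uniform kernel averages each pixel with its 8 neighbours.
mean_kernel = np.full(K * K, 1.0 / (K * K))
smoothed = stencil_convolve(values, neighbours, mean_kernel)
```

Because the index table depends only on the geometry, it can be computed once and reused for every kernel and every input map, which is the reason for precomputing it.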

The stencil tables are the same ones used by FoCUS for the scattering operator and are cached in ~/.FOSCAT/data/. The first call at a given (nside, KERNELSZ) pair therefore triggers a one-time initialisation; later calls with the same pair reuse the cached tables.


HealpixUNet#

Module: foscat.healpix_unet_torch

A U-Net-style encoder–decoder on the HEALPix sphere. Each level applies two oriented spherical convolutions with BatchNorm and ReLU, then downsamples or upsamples via ud_grade_2 / up_grade. Skip connections carry encoder features to the decoder at matching resolutions.

Architecture#

Input  [B, C_in, N_cells]     N_cells = len(cell_ids)

  enc[0]  DoubleConv  C_in     → chanlist[0]
  down[0] ud_grade_2           N_cells → N_cells/4
  enc[1]  DoubleConv  chanlist[0] → chanlist[1]
  down[1] ud_grade_2
  ...
  enc[L]  DoubleConv  chanlist[L-1] → chanlist[L]   (bottleneck)

  up[0]   up_grade             N_cells/4^L → N_cells/4^(L-1)
  dec[0]  DoubleConv  chanlist[L]+chanlist[L-1] → chanlist[L-1]  (concat skip)
  ...
  dec[L-1] DoubleConv chanlist[1]+chanlist[0] → chanlist[0]

  out_conv  Conv1d(chanlist[0], out_channels, 1)  → [B, out_channels, N_cells]

Output [B, out_channels, N_cells]

Each “DoubleConv” block: SphericalConv → BN → ReLU → SphericalConv → BN → ReLU.
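
To make the diagram concrete, the following sketch traces channel and cell counts through the levels. It is plain bookkeeping, not FOSCAT code: `unet_shapes` is a hypothetical helper, and it assumes that `N_cells` is divisible by 4 at every level (as on the full sphere), which may not hold for an arbitrary regional domain.

```python
def unet_shapes(n_cells, c_in, chanlist, out_channels):
    """Return (stage, channels, cells) triples following the diagram above."""
    trace = [("input", c_in, n_cells)]
    cells = n_cells
    depth = len(chanlist) - 1              # number of down/up steps (L above)

    # Encoder: channels follow chanlist, cells shrink by 4 per downsampling.
    for level, chan in enumerate(chanlist):
        trace.append((f"enc[{level}]", chan, cells))
        if level < depth:
            cells //= 4                    # 4 NESTED children per parent pixel

    # Decoder: cells grow by 4 per upsampling, channels walk back down chanlist.
    for step in range(depth):
        cells *= 4
        trace.append((f"dec[{step}]", chanlist[depth - 1 - step], cells))

    trace.append(("out_conv", out_channels, cells))
    return trace

# Full sphere at nside = 64: 12 * 64**2 = 49152 cells.
for stage, chan, cells in unet_shapes(12 * 64**2, 1, [16, 32, 64], 1):
    print(f"{stage:9s} channels={chan:3d} cells={cells}")
```

With `chanlist = [16, 32, 64]` the bottleneck `enc[2]` holds 64 channels on 3072 cells, and the output returns to the input resolution.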

Constructor#

from foscat.healpix_unet_torch import HealpixUNet

model = HealpixUNet(
    in_nside        = 64,
    n_chan_in        = 1,
    chanlist         = [16, 32, 64],
    cell_ids         = cell_ids,
    KERNELSZ         = 3,
    gauge_type       = "phi",
    G                = 1,
    task             = "regression",
    out_channels     = 1,
    final_activation = None,
    device           = None,
    down_type        = "max",
)

Parameters#

| Parameter | Type | Default | Description |
|---|---|---|---|
| in_nside | int | | HEALPix nside for the input domain (NESTED). |
| n_chan_in | int | | Number of input channels. |
| chanlist | list[int] | | Channel count at each encoder level, e.g. [16, 32, 64]. Depth = len(chanlist). |
| cell_ids | ndarray | | NESTED pixel indices of the regional domain at in_nside. Shape (N_cells,). |
| KERNELSZ | int | 3 | Stencil side length \(K\) (\(K \times K\) neighbours per pixel). |
| gauge_type | str | "phi" | Orientation convention for the gauge frame. "phi": aligned with the longitude direction (preferred for Earth-observation). "cosmo": standard cosmological convention. |
| G | int | 1 | Number of gauge orientations. G > 1 increases intermediate channel count by \(G\). |
| task | str | "regression" | Output head: "regression" (no final activation by default) or "segmentation" (softmax / sigmoid). |
| out_channels | int | 1 | Number of output channels (e.g. number of classes for segmentation). |
| final_activation | str or None | None | Override the default activation: "none", "sigmoid", or "softmax". |
| device | str, device or None | auto | Target device. Defaults to CUDA if available; falls back to CPU if FOSCAT ops cannot run on CUDA. |
| down_type | str | "max" | Downsampling strategy: "max" (max over 4 NESTED children) or "mean" (average). |
| prefer_foscat_gpu | bool | True | Try CUDA for FOSCAT ops and fall back to CPU if a dry-run fails. |
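
The two `down_type` options can be illustrated outside FOSCAT. In the NESTED scheme the four children of parent pixel `p` are `4p` to `4p + 3`, so the parent of a cell is its index integer-divided by 4. The sketch below is a minimal NumPy version of that pooling; `nested_downsample` is a hypothetical name, and the way it handles parents with fewer than four children present is an assumption of this sketch, not a statement about how FOSCAT treats domain edges.

```python
import numpy as np

def nested_downsample(values, cell_ids, down_type="max"):
    """Pool values over NESTED children; returns (pooled, parent_ids)."""
    # Parent index in NESTED ordering: children 4p..4p+3 share parent p.
    parents, inverse = np.unique(np.asarray(cell_ids) // 4, return_inverse=True)
    values = np.asarray(values, dtype=float)

    if down_type == "max":
        pooled = np.full(parents.shape, -np.inf)
        np.maximum.at(pooled, inverse, values)     # running max per parent
    elif down_type == "mean":
        sums = np.zeros(parents.shape)
        np.add.at(sums, inverse, values)           # sum per parent
        counts = np.bincount(inverse, minlength=parents.size)
        pooled = sums / counts                     # average over children present
    else:
        raise ValueError(f"unknown down_type: {down_type!r}")
    return pooled, parents

# Cells 8..11 share parent 2; cells 12 and 13 share parent 3 (incomplete group).
cell_ids = np.array([8, 9, 10, 11, 12, 13])
values = np.array([1.0, 5.0, 2.0, 4.0, 3.0, 7.0])

max_pooled, parent_ids = nested_downsample(values, cell_ids, "max")
mean_pooled, _ = nested_downsample(values, cell_ids, "mean")
```

Max pooling keeps the strongest response in each group of children, while mean pooling preserves the average; which one suits a task better is an empirical question.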

Methods#

forward(x) → Tensor

# x: (B, C_in, N_cells)
y = model(x)    # (B, out_channels, N_cells)

fit(x_train, y_train, ...) → dict

history = model.fit(
    x_train,           # (N, C_in, N_cells)
    y_train,           # (N, out_channels, N_cells)
    x_val   = None,
    y_val   = None,
    n_epoch     = 100,
    batch_size  = 16,
    lr          = 1e-3,
    weight_decay = 1e-5,
    view_epoch  = 10,
    loss_fn     = None,    # defaults to F.mse_loss
)
# returns {"train_loss": [...], "val_loss": [...]}
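
A short sketch of how the returned dictionary might be inspected follows. It assumes the two lists hold one loss value per recorded epoch, which the description above does not state explicitly; the numbers are illustrative, not results from a real run, and `best_epoch` is a hypothetical helper.

```python
def best_epoch(history, key="val_loss"):
    """Return (index, loss) of the smallest recorded loss under `key`."""
    losses = history[key]
    if not losses:
        raise ValueError(f"no values recorded under {key!r}")
    index = min(range(len(losses)), key=losses.__getitem__)
    return index, losses[index]

# Illustrative values only, shaped like the dictionary returned by fit().
history = {
    "train_loss": [0.90, 0.52, 0.35, 0.28],
    "val_loss":   [0.95, 0.60, 0.48, 0.51],
}

epoch, loss = best_epoch(history)
print(f"lowest validation loss {loss:.2f} at recorded epoch {epoch}")
```

A validation loss that rises while the training loss keeps falling, as in the last entry here, is the usual sign that training could have stopped earlier.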

predict(x, batch_size=16) → ndarray

Batched inference without gradient. Returns (N, out_channels, N_cells) on CPU.

Example — regression on a regional domain#

import numpy as np
import torch
from foscat.healpix_unet_torch import HealpixUNet
import healpy as hp

nside = 64
cell_ids = hp.query_disc(nside, hp.ang2vec(np.pi/2, 0), np.radians(30), nest=True)

model = HealpixUNet(
    in_nside=nside, n_chan_in=6, chanlist=[16, 32, 64],
    cell_ids=cell_ids, KERNELSZ=3, out_channels=1,
)

# Training data: X_atm shape (N, 6, N_cells), Y_sst shape (N, 1, N_cells)
history = model.fit(X_atm, Y_sst, x_val=X_val, y_val=Y_val,
                    n_epoch=100, lr=1e-3)

sst_pred = model.predict(X_test)   # (N_test, 1, N_cells)

GCNN — graph-convolutional neural network#

Module: foscat.GCNN

A graph-convolutional network that uses the FOSCAT scattering operator as the convolution layer. Suitable for regression tasks over the full sphere or a regional subset.

from foscat.GCNN import GCNN

model = GCNN(
    nparam   = 1,       # number of output scalars per pixel
    KERNELSZ = 3,
    NORIENT  = 4,
    chanlist = [1, 16, 32, 16, 1],
    in_nside = 64,
)

| Parameter | Type | Default | Description |
|---|---|---|---|
| nparam | int | 1 | Number of output channels per pixel. |
| KERNELSZ | int | 3 | Wavelet stencil side length. |
| NORIENT | int | 4 | Number of wavelet orientations. |
| chanlist | list[int] | [] | Channel sizes at each layer. The network depth equals len(chanlist) - 1. |
| in_nside | int | 1 | Input HEALPix nside. |
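
The depth rule for `chanlist` follows from pairing consecutive entries as (input, output) channel counts. The sketch below is plain Python, not a GCNN call, and `layer_channels` is a hypothetical helper used only to show the arithmetic.

```python
def layer_channels(chanlist):
    """Pair consecutive channel sizes into (in_channels, out_channels) per layer."""
    return list(zip(chanlist[:-1], chanlist[1:]))

layers = layer_channels([1, 16, 32, 16, 1])
for index, (c_in, c_out) in enumerate(layers):
    print(f"layer {index}: {c_in} -> {c_out}")
```

The five-entry list from the constructor example above therefore describes four layers.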


Choosing between architectures#

| Architecture | Best for | Notes |
|---|---|---|
| HealpixUNet | Dense prediction, regression, segmentation | Full U-Net with skip connections; most flexible |
| GCNN | Global scalar regression, lightweight models | Scattering-based graph conv; fewer parameters |
| CNN | Simple per-pixel classification | Flat CNN; use as baseline |
| healpix_vit_torch.HealpixViT | Long-range dependencies, global patterns | Vision Transformer on HEALPix tokens |


Hybrid scattering + neural workflows#

FOSCAT’s neural networks are differentiable and can be embedded in synthesis or component separation loops. For example, a trained HealpixUNet can serve as a learned morphological prior:

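import torch
from foscat.Synthesis import Loss, Synthesis

# trained_model: a HealpixUNet that maps noise → clean field

def neural_prior_loss(x, scat_op, args):
    model = args[0]
    # Add batch and channel axes: (N_cells,) → (1, 1, N_cells)
    x_in = torch.tensor(x).unsqueeze(0).unsqueeze(0).float()
    # detach() is needed before .numpy(); the prediction then acts as a
    # fixed target, so no gradient flows back through the network here
    x_clean = model(x_in).detach().squeeze().cpu().numpy()
    return scat_op.backend.bk_mean((x - x_clean) ** 2)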

loss = Loss(neural_prior_loss, scat_op, trained_model)
solver = Synthesis([loss])
result = solver.run(noisy_map, NUM_EPOCHS=200)