foscat.healpix_unet_torch#

GPU by default (when available), with graceful CPU fallback if Foscat ops are CPU-only.

  • ReLU + BatchNorm after each convolution (encoder & decoder)

  • Segmentation/Regression heads with optional final activation

  • PyTorch-ified: inherits from nn.Module, standard state_dict

  • Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU

Shape convention: (B, C, Npix)

Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
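As a quick illustration of the (B, C, Npix) convention: a full-sky HEALPix map at resolution nside has 12 * nside**2 pixels, so the expected input shape follows from in_nside. The helper names below are hypothetical and not part of foscat; for partial-sky inputs, Npix would presumably be the length of cell_ids instead.

```python
def full_sky_npix(nside: int) -> int:
    """Number of pixels of a full-sky HEALPix map at resolution nside."""
    # The nested scheme requires nside to be a power of two.
    if nside < 1 or nside & (nside - 1):
        raise ValueError("nside must be a positive power of two")
    return 12 * nside * nside


def input_shape(batch: int, channels: int, nside: int) -> tuple:
    """Shape of a full-sky input batch in the (B, C, Npix) convention."""
    return (batch, channels, full_sky_npix(nside))


shape = input_shape(batch=4, channels=2, nside=16)
print(shape)  # (4, 2, 3072)
```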

Classes#

HealpixUNet

U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.

TestUNET

Lightweight smoke tests for shape and parameter plumbing.

HealpixDataset

Fixed-grid dataset (common Npix for all samples).

VarLenHealpixDataset

Variable-length per-sample dataset.

Functions#

varlen_collate(batch)

Collate for variable-length samples: keep lists, do NOT stack.

fit(→ dict)

Train helper that supports fixed-grid tensors and variable-length lists.

Module Contents#

class foscat.healpix_unet_torch.HealpixUNet(*, in_nside: int, n_chan_in: int, chanlist: List[int], cell_ids: numpy.ndarray, KERNELSZ: int = 3, task: Literal['regression', 'segmentation'] = 'regression', out_channels: int = 1, final_activation: Literal['none', 'sigmoid', 'softmax'] | None = None, device: torch.device | str | None = None, prefer_foscat_gpu: bool = True, gauge_type: Literal['cosmo', 'phi'] | None = 'cosmo', G: int = 1, down_type: Literal['mean', 'max'] | None = 'max', dtype: Literal['float32', 'float64'] = 'float32', head_reduce: Literal['mean', 'learned'] = 'mean')[source]#

Bases: torch.nn.Module

U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.

Parameters:
  • in_nside (int) – Input HEALPix nside (nested scheme).

  • n_chan_in (int) – Number of input channels.

  • chanlist (list[int]) – Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].

  • cell_ids (np.ndarray) – Cell indices for the finest resolution (nside = in_nside) in nested scheme.

  • KERNELSZ (int, default 3) – Spatial kernel size K (K x K) for oriented convolution.

  • gauge_type ({'cosmo','phi'} | None, default 'cosmo') – Type of gauge. ‘cosmo’ uses the same definition as […]; ‘phi’ is defined at the pole and may be better suited to Earth observation that does not use the pole intensively.

  • G (int, default 1) – Number of gauges for the orientation definition.

  • task ({'regression','segmentation'}, default 'regression') – Chooses the head and default activation.

  • out_channels (int, default 1) – Number of output channels (e.g. num_classes for segmentation).

  • final_activation ({'none','sigmoid','softmax'} | None) – If None, uses sensible default per task: ‘none’ for regression, ‘softmax’ for segmentation (multi-class), ‘sigmoid’ for segmentation when out_channels==1.

  • device (str | torch.device | None, default: 'cuda' if available else 'cpu') – Preferred device. The module probes whether Foscat ops can run on CUDA; if not, it falls back to CPU and keeps all parameters and buffers on CPU for consistency.

  • down_type ({'mean','max'} | None, default 'max') – Pooling applied during downsampling; 'max' is the equivalent of max pooling.

  • prefer_foscat_gpu (bool, default True) – When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run. If the dry-run fails, everything falls back to CPU.
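The default rule described for final_activation can be written out as a small function. This is a minimal sketch of the documented behaviour, not the library's code; resolve_final_activation is a hypothetical name.

```python
from typing import Optional


def resolve_final_activation(task: str, out_channels: int,
                             final_activation: Optional[str] = None) -> str:
    """Return the activation name: an explicit value wins, else pick per task."""
    if final_activation is not None:
        return final_activation
    if task == "regression":
        return "none"
    if task == "segmentation":
        # Single output channel: sigmoid. Several classes: softmax.
        return "sigmoid" if out_channels == 1 else "softmax"
    raise ValueError(f"unknown task: {task!r}")


print(resolve_final_activation("segmentation", out_channels=5))  # softmax
```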

Notes

  • Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.

  • Downsampling uses foscat ud_grade_2; upsampling uses up_grade.

  • Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via SphericalStencil.Convol_torch.

  • Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
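To make the notes above concrete, the sketch below lists the resolution and the kernel-parameter shapes at each encoder level. It rests on assumptions the source does not confirm: that each downsampling step halves nside, that the first convolution of a level maps C_in to C_out, and that the second maps C_out to C_out. plan_encoder is a hypothetical helper.

```python
from typing import Dict, List


def plan_encoder(in_nside: int, n_chan_in: int, chanlist: List[int],
                 kernelsz: int = 3) -> List[Dict]:
    """List nside and kernel shapes [C_in, C_out, K*K] per encoder level."""
    plan = []
    nside, c_in = in_nside, n_chan_in
    for c_out in chanlist:
        plan.append({
            "nside": nside,
            "w1_shape": (c_in, c_out, kernelsz * kernelsz),   # first conv
            "w2_shape": (c_out, c_out, kernelsz * kernelsz),  # second conv
        })
        # Assumption: one downsampling step halves nside.
        nside, c_in = nside // 2, c_out
    return plan


plan = plan_encoder(in_nside=32, n_chan_in=1, chanlist=[16, 32, 64])
for level in plan:
    print(level)
```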

dtype = 'float32'#
gauge_type = 'cosmo'#
G = 1#
in_nside#
n_chan_in#
chanlist#
KERNELSZ = 3#
task = 'regression'#
out_channels = 1#
prefer_foscat_gpu = True#
device#
f#
l_cell_ids: List[numpy.ndarray]#
hconv_enc: List[foscat.SphericalStencil.SphericalStencil] = []#
hconv_dec: List[foscat.SphericalStencil.SphericalStencil] = []#
enc_w1#
enc_bn1#
enc_w2#
enc_bn2#
enc_nsides#
dec_w1#
dec_bn1#
dec_w2#
dec_bn2#
head_hconv#
head_w#
head_bn#
head_reduce#
runtime_device#
set_device(device: torch.device | str) torch.device[source]#

Request a device change. Probes Foscat and returns the runtime device actually used.

forward_any(x, cell_ids: numpy.ndarray | None = None)[source]#

If x is a Tensor (B,C,N): standard batched path (requires same N for all). If x is a list of Tensors: variable-length per-sample path, returns a list of outputs.

forward(x: torch.Tensor, cell_ids: numpy.ndarray | None = None) torch.Tensor[source]#

Forward pass.

Parameters:
  • x (torch.Tensor, shape (B, C_in, Npix)) – Input tensor on in_nside grid.

  • cell_ids (np.ndarray, shape (B, Npix), optional) – Cell indices to use instead of the ones given at construction. If None, the initial cell_ids are used.

predict(x: torch.Tensor, batch_size: int = 8, cell_ids: numpy.ndarray | None = None) torch.Tensor[source]#
to_tensor(x)[source]#
to_numpy(x)[source]#
extract_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0)[source]#

Extract raw convolution kernels for a given stage/level/conv.

Parameters:
  • stage ({"encoder", "decoder"}) – Which part of the network to inspect.

  • layer (int) – Pyramid level (0 = finest encoder level / bottommost decoder level).

  • conv (int) – 0 for the first conv at that level, 1 for the second conv.

Returns:

np.ndarray – Array of shape (in_c, out_c, K, K) containing the spatial kernels.
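Since the kernels are stored with a flattened K*K last axis, returning (in_c, out_c, K, K) amounts to unflattening that axis. The sketch below does this on nested lists and assumes row-major ordering, which the source does not state; both function names are hypothetical.

```python
from typing import List


def unflatten_kernel(flat: List[float], k: int) -> List[List[float]]:
    """Turn a flat K*K kernel into K rows of K values (row-major assumed)."""
    if len(flat) != k * k:
        raise ValueError("expected K*K values")
    return [flat[r * k:(r + 1) * k] for r in range(k)]


def unflatten_weights(w, k: int):
    """Map nested lists of shape [C_in][C_out][K*K] to [C_in][C_out][K][K]."""
    return [[unflatten_kernel(kernel, k) for kernel in per_in] for per_in in w]


# One input channel, two output channels, K = 3.
w = [[list(range(9)), list(range(9, 18))]]
kernels = unflatten_weights(w, 3)
print(kernels[0][1][0])  # [9, 10, 11]
```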

plot_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0, fixed: str = 'in', index: int = 0, max_tiles: int = 16)[source]#

Quick visualization of kernels on a grid using matplotlib.

Parameters:
  • stage ({"encoder", "decoder"}) – Which tower to visualize.

  • layer (int) – Level to visualize.

  • conv (int) – 0 or 1: first or second conv in the level.

  • fixed ({"in", "out"}) – If “in”, show kernels for a fixed input channel across many outputs. If “out”, show kernels for a fixed output channel across many inputs.

  • index (int) – Channel index to fix (according to fixed).

  • max_tiles (int) – Maximum number of tiles to display.
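The interplay of fixed, index and max_tiles can be sketched without any plotting. This is an illustration of the selection logic as the parameters describe it, not the library's implementation; select_tiles is a hypothetical name and strings stand in for K x K kernels.

```python
def select_tiles(kernels, fixed: str = "in", index: int = 0, max_tiles: int = 16):
    """Pick the kernels to display from nested lists [C_in][C_out]."""
    if fixed == "in":
        # Fixed input channel, iterate over output channels.
        tiles = kernels[index]
    elif fixed == "out":
        # Fixed output channel, iterate over input channels.
        tiles = [per_in[index] for per_in in kernels]
    else:
        raise ValueError("fixed must be 'in' or 'out'")
    return tiles[:max_tiles]


# 2 input channels, 3 output channels; "kIO" labels kernel (in=I, out=O).
labels = [[f"k{i}{o}" for o in range(3)] for i in range(2)]
print(select_tiles(labels, fixed="in", index=1))  # ['k10', 'k11', 'k12']
```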

class foscat.healpix_unet_torch.TestUNET(methodName='runTest')#

Bases: unittest.TestCase

Lightweight smoke tests for shape and parameter plumbing.

setUp()#

Hook method for setting up the test fixture before exercising it.

test_forward_shape()#
test_param_roundtrip_and_determinism()#
class foscat.healpix_unet_torch.HealpixDataset(x, y, cell_ids=None, dtype=torch.float32)[source]#

Bases: torch.utils.data.Dataset

Returns (x, y, cell_ids) per-sample if cell_ids is given, else (x, y). Shapes:

  • x: (C, Npix)

  • y: (C_out or 1, Npix)

  • cell_ids: (Npix,) per sample (or broadcast from (Npix,))

x#
y#
class foscat.healpix_unet_torch.HealpixDataset(x: torch.Tensor, y: torch.Tensor, cell_ids: numpy.ndarray | torch.Tensor | None = None, dtype: torch.dtype = torch.float32)[source]#

Bases: torch.utils.data.Dataset

Fixed-grid dataset (common Npix for all samples). Returns (x, y) if cell_ids is None, else (x, y, cell_ids).

  • x: (B, C, Npix)

  • y: (B, C_out or 1, Npix), or class indices depending on task

  • cell_ids: (Npix,) or (B, Npix)
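The return convention of the fixed-grid dataset can be illustrated with a plain-Python stand-in. This is a sketch, not the foscat class: it uses nested lists instead of tensors, and the way a shared (Npix,) cell_ids array is detected is an assumption made for the example.

```python
class FixedGridDatasetSketch:
    """Plain-Python stand-in for a fixed-grid dataset."""

    def __init__(self, x, y, cell_ids=None):
        if len(x) != len(y):
            raise ValueError("x and y must have the same number of samples")
        self.x, self.y, self.cell_ids = x, y, cell_ids
        # A flat (Npix,) list is shared by all samples; a nested (B, Npix)
        # list provides one row per sample.
        self.shared_ids = (cell_ids is not None
                           and not isinstance(cell_ids[0], (list, tuple)))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        if self.cell_ids is None:
            return self.x[i], self.y[i]
        ids = self.cell_ids if self.shared_ids else self.cell_ids[i]
        return self.x[i], self.y[i], ids


x = [[[0.0] * 4], [[1.0] * 4]]   # B=2, C=1, Npix=4
y = [[[0.5] * 4], [[1.5] * 4]]
ds = FixedGridDatasetSketch(x, y, cell_ids=[0, 1, 2, 3])
print(len(ds), len(ds[0]))  # 2 3
```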

class foscat.healpix_unet_torch.VarLenHealpixDataset(x_list: List[numpy.ndarray | torch.Tensor], y_list: List[numpy.ndarray | torch.Tensor], cids_list: List[numpy.ndarray | torch.Tensor] | None = None, dtype: torch.dtype = torch.float32)[source]#

Bases: torch.utils.data.Dataset

Variable-length per-sample dataset.

  • x_list[b]: (C, Npix_b) or (1, C, Npix_b)

  • y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) (regression/segmentation targets)

For multi-class segmentation with CrossEntropyLoss, you may pass class indices of shape (Npix_b,) or (1, Npix_b) (we’ll squeeze later).

cids_list[b]: (Npix_b,) or None

x#
y#
class foscat.healpix_unet_torch.VarLenHealpixDataset(x_list, y_list, cids_list=None, dtype=torch.float32)[source]#

Bases: torch.utils.data.Dataset

  • x_list: list of (C, Npix_b) tensors or arrays

  • y_list: list of (C_out or 1, Npix_b) tensors or arrays

  • cids_list: optional list of (Npix_b,) arrays

x#
y#
c = None#
foscat.healpix_unet_torch.varlen_collate(batch)[source]#

Collate for variable-length samples: keep lists, do NOT stack. Returns lists: xs, ys, cs (cs can be None).
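A collate function of this kind might look like the sketch below. It assumes each sample is an (x, y) or (x, y, cell_ids) tuple, which the source does not spell out, and it uses nested lists in place of tensors; varlen_collate_sketch is a hypothetical name.

```python
def varlen_collate_sketch(batch):
    """Keep per-sample items in lists instead of stacking them."""
    xs = [sample[0] for sample in batch]
    ys = [sample[1] for sample in batch]
    # cs is a list only when every sample carries its own cell_ids.
    has_ids = all(len(sample) == 3 and sample[2] is not None for sample in batch)
    cs = [sample[2] for sample in batch] if has_ids else None
    return xs, ys, cs


batch = [
    ([[0.1, 0.2, 0.3]], [[1, 0, 1]], [4, 5, 6]),  # Npix_0 = 3
    ([[0.5, 0.6]], [[0, 1]], [7, 8]),             # Npix_1 = 2
]
xs, ys, cs = varlen_collate_sketch(batch)
print(cs)  # [[4, 5, 6], [7, 8]]
```

Stacking is avoided because samples with different Npix_b cannot share one (B, C, Npix) tensor.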

foscat.healpix_unet_torch.fit(model, x_train: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray], y_train: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray], *, cell_ids_train: numpy.ndarray | torch.Tensor | List[numpy.ndarray | torch.Tensor] | None = None, n_epoch: int = 10, view_epoch: int = 10, batch_size: int = 16, x_valid: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray] = None, y_valid: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray] = None, save_model: bool = False, lr: float = 0.001, weight_decay: float = 0.0, clip_grad_norm: float | None = None, verbose: bool = True, optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM') dict[source]#
Train helper that supports:
  • Fixed-grid tensors (B,C,N) with optional (B,N) or (N,) cell_ids.

  • Variable-length lists: x=[(C,N_b)], y=[…], cell_ids=[(N_b,)], returning per-sample grids.

Optimizers:

  • ADAM: standard minibatch update.

  • LBFGS: uses a closure that sums losses over the current (variable-length) mini-batch.

Notes

  • For segmentation with multiple classes, pass integer class targets for y: fixed-grid: (B, N) int64; variable-length: each y[b] of shape (N_b,) or (1,N_b).

  • For regression, pass float targets with the same (C_out, N) channel layout as the model output.
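The idea of summing losses over a variable-length mini-batch can be shown with plain lists. The sketch below uses mean squared error per sample as an assumed loss (the source does not say which loss fit uses); in an actual LBFGS closure the same sum would be computed on tensors, followed by a backward pass.

```python
def mse(pred, target):
    """Mean squared error over one sample's pixels (flat lists)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def summed_batch_loss(preds, targets):
    """Sum per-sample losses; samples may have different lengths."""
    return sum(mse(p, t) for p, t in zip(preds, targets))


preds = [[1.0, 2.0, 3.0], [0.0, 0.0]]     # N_0 = 3, N_1 = 2
targets = [[1.0, 2.0, 5.0], [1.0, 1.0]]
loss = summed_batch_loss(preds, targets)
print(loss)
```

Each sample is reduced to a scalar before summing, so samples with different N_b contribute without any padding.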