foscat.healpix_unet_torch#
GPU by default (when available), with graceful CPU fallback if Foscat ops are CPU-only.
ReLU + BatchNorm after each convolution (encoder & decoder)
Segmentation/Regression heads with optional final activation
PyTorch-ified: inherits from nn.Module, standard state_dict
Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU
Shape convention: (B, C, Npix)
Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
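As a small illustration of the (B, C, Npix) shape convention, the sketch below computes Npix for a full-sphere HEALPix map, where Npix = 12 * nside**2. The helper names healpix_npix and batch_shape are hypothetical and are not part of foscat; when only a subset of cells is passed through cell_ids, Npix would presumably be the number of cells instead.

```python
def healpix_npix(nside: int) -> int:
    """Number of pixels of a full-sphere HEALPix map: 12 * nside**2."""
    # The nested scheme requires nside to be a positive power of two.
    if nside < 1 or nside & (nside - 1):
        raise ValueError("nside must be a positive power of two")
    return 12 * nside * nside


def batch_shape(batch: int, channels: int, nside: int) -> tuple:
    """Expected (B, C, Npix) shape of a full-sphere batch (hypothetical helper)."""
    return (batch, channels, healpix_npix(nside))


print(batch_shape(4, 3, 16))
```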
Classes#
HealpixUNet | U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
TestUNET | Lightweight smoke tests for shape and parameter plumbing.
HealpixDataset | Fixed-grid dataset (common Npix for all samples).
VarLenHealpixDataset | x_list: list of (C, Npix_b) tensors or arrays
Functions#
varlen_collate | Collate for variable-length samples: keep lists, do NOT stack.
fit | Train helper that supports:
Module Contents#
- class foscat.healpix_unet_torch.HealpixUNet(*, in_nside: int, n_chan_in: int, chanlist: List[int], cell_ids: numpy.ndarray, KERNELSZ: int = 3, task: Literal['regression', 'segmentation'] = 'regression', out_channels: int = 1, final_activation: Literal['none', 'sigmoid', 'softmax'] | None = None, device: torch.device | str | None = None, prefer_foscat_gpu: bool = True, gauge_type: Literal['cosmo', 'phi'] | None = 'cosmo', G: int = 1, down_type: Literal['mean', 'max'] | None = 'max', dtype: Literal['float32', 'float64'] = 'float32', head_reduce: Literal['mean', 'learned'] = 'mean')[source]#
Bases: torch.nn.Module
U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
- Parameters:
in_nside (int) – Input HEALPix nside (nested scheme).
n_chan_in (int) – Number of input channels.
chanlist (list[int]) – Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
cell_ids (np.ndarray) – Cell indices for the finest resolution (nside = in_nside) in nested scheme.
KERNELSZ (int, default 3) – Spatial kernel size K (K x K) for oriented convolution.
gauge_type (str) – Type of gauge. ‘cosmo’ uses the same definition as …; ‘phi’ is defined at the pole and may be better suited to Earth observation that does not make intensive use of the poles.
G (int, default 1) – Number of gauges for the orientation definition.
task ({'regression','segmentation'}, default 'regression') – Chooses the head and default activation.
out_channels (int, default 1) – Number of output channels (e.g. num_classes for segmentation).
final_activation ({'none','sigmoid','softmax'} | None) – If None, uses a sensible default per task: ‘none’ for regression, ‘softmax’ for multi-class segmentation, ‘sigmoid’ for segmentation when out_channels == 1.
device (str | torch.device | None, default: 'cuda' if available else 'cpu') – Preferred device. The module probes whether Foscat ops can run on CUDA; if not, it falls back to CPU and keeps all parameters/buffers on CPU for consistency.
down_type ({'mean','max'}, default 'max') – Pooling applied when downsampling; 'max' is the equivalent of max pooling.
prefer_foscat_gpu (bool, default True) – When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run. If the dry-run fails, everything falls back to CPU.
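The default rule described for final_activation can be written out as a short sketch. The function resolve_final_activation is a hypothetical stand-in that mirrors the documented behaviour; it is not the library's own code.

```python
from typing import Optional


def resolve_final_activation(task: str, out_channels: int,
                             final_activation: Optional[str] = None) -> str:
    """Pick the activation name following the documented per-task defaults."""
    if final_activation is not None:
        # An explicit choice always wins over the defaults.
        return final_activation
    if task == "regression":
        return "none"
    if task == "segmentation":
        # Single-channel segmentation is binary; otherwise multi-class.
        return "sigmoid" if out_channels == 1 else "softmax"
    raise ValueError(f"unknown task: {task!r}")


print(resolve_final_activation("segmentation", 5))
```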
Notes
Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
Downsampling uses foscat ud_grade_2; upsampling uses up_grade.
Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via SphericalStencil.Convol_torch.
Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
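To make the notes above more concrete, the sketch below lists the resolution of each encoder level and the number of weights in one kernel tensor of shape [C_in, C_out, K*K]. It assumes that each downsampling step halves nside, which is suggested by the name ud_grade_2 but not stated here; both helper names are hypothetical.

```python
def encoder_nsides(in_nside: int, depth: int) -> list:
    """nside at each encoder level, ASSUMING nside is halved per level."""
    nsides = []
    nside = in_nside
    for _ in range(depth):
        if nside < 1:
            raise ValueError("in_nside is too small for the requested depth")
        nsides.append(nside)
        nside //= 2
    return nsides


def kernel_param_count(c_in: int, c_out: int, k: int) -> int:
    """Number of weights in one kernel tensor of shape [C_in, C_out, K*K]."""
    return c_in * c_out * k * k


chanlist = [16, 32, 64]
print(encoder_nsides(16, len(chanlist)))
print(kernel_param_count(3, chanlist[0], 3))
```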
- dtype = 'float32'#
- gauge_type = 'cosmo'#
- G = 1#
- in_nside#
- n_chan_in#
- chanlist#
- KERNELSZ = 3#
- task = 'regression'#
- out_channels = 1#
- prefer_foscat_gpu = True#
- device#
- f#
- l_cell_ids: List[numpy.ndarray]#
- hconv_enc: List[foscat.SphericalStencil.SphericalStencil] = []#
- hconv_dec: List[foscat.SphericalStencil.SphericalStencil] = []#
- enc_w1#
- enc_bn1#
- enc_w2#
- enc_bn2#
- enc_nsides#
- dec_w1#
- dec_bn1#
- dec_w2#
- dec_bn2#
- head_hconv#
- head_w#
- head_bn#
- head_reduce#
- runtime_device#
- set_device(device: torch.device | str) torch.device[source]#
Request a device (or a device change). Foscat is probed, and the runtime device actually used is returned.
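The probe-then-fall-back behaviour can be pictured with the following sketch. It is a generic illustration of the pattern, not the module's implementation: choose_runtime_device and the probe callable are hypothetical, and the probe stands in for the dry-run of the Foscat operators on the requested device.

```python
from typing import Callable


def choose_runtime_device(requested: str, probe: Callable[[str], None]) -> str:
    """Return the requested device if the probe succeeds, else 'cpu'."""
    if requested == "cpu":
        return "cpu"
    try:
        # Hypothetical dry-run: raises if the ops cannot run on `requested`.
        probe(requested)
    except Exception:
        return "cpu"
    return requested


def failing_probe(device: str) -> None:
    raise RuntimeError(f"ops not available on {device}")


print(choose_runtime_device("cuda", failing_probe))
```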
- forward_any(x, cell_ids: numpy.ndarray | None = None)[source]#
If x is a Tensor (B,C,N): standard batched path (requires same N for all). If x is a list of Tensors: variable-length per-sample path, returns a list of outputs.
- forward(x: torch.Tensor, cell_ids: numpy.ndarray | None = None) torch.Tensor[source]#
Forward pass.
- Parameters:
x (torch.Tensor, shape (B, C_in, Npix)) – Input tensor on the in_nside grid.
cell_ids (np.ndarray, shape (B, Npix), optional) – Cell ids to use instead of the initial ones. If None, the initial cell_ids are used.
- predict(x: torch.Tensor, batch_size: int = 8, cell_ids: numpy.ndarray | None = None) torch.Tensor[source]#
- extract_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0)[source]#
Extract raw convolution kernels for a given stage/level/conv.
- Parameters:
- Returns:
np.ndarray – Array of shape (in_c, out_c, K, K) containing the spatial kernels.
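Since the kernels are stored with a flat last axis of length K*K and returned as (in_c, out_c, K, K), the conversion for a single kernel can be sketched as below. The row-major ordering is an assumption, and unflatten_kernel is a hypothetical helper using plain lists rather than arrays.

```python
def unflatten_kernel(flat, k: int) -> list:
    """Reshape a flat kernel of length K*K into K rows of K values.

    ASSUMPTION: the flat axis is stored in row-major order.
    """
    if len(flat) != k * k:
        raise ValueError("flat kernel must have exactly K*K entries")
    return [list(flat[row * k:(row + 1) * k]) for row in range(k)]


print(unflatten_kernel([1, 2, 3, 4, 5, 6, 7, 8, 9], 3))
```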
- plot_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0, fixed: str = 'in', index: int = 0, max_tiles: int = 16)[source]#
Quick visualization of kernels on a grid using matplotlib.
- Parameters:
stage ({"encoder", "decoder"}) – Which tower to visualize.
layer (int) – Level to visualize.
conv (int) – 0 or 1: first or second conv in the level.
fixed ({"in", "out"}) – If “in”, show kernels for a fixed input channel across many outputs. If “out”, show kernels for a fixed output channel across many inputs.
index (int) – Channel index to fix (according to fixed).
max_tiles (int) – Maximum number of tiles to display.
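The interplay of fixed, index and max_tiles can be sketched as a selection of (input channel, output channel) pairs. The function select_tiles is hypothetical and only illustrates the documented semantics; the actual plotting code may order or truncate tiles differently.

```python
def select_tiles(in_c: int, out_c: int, fixed: str = "in",
                 index: int = 0, max_tiles: int = 16) -> list:
    """Return the (in_channel, out_channel) pairs that would be shown."""
    if fixed == "in":
        if not 0 <= index < in_c:
            raise IndexError("index out of range for input channels")
        # Fixed input channel, sweep over output channels.
        return [(index, o) for o in range(min(out_c, max_tiles))]
    if fixed == "out":
        if not 0 <= index < out_c:
            raise IndexError("index out of range for output channels")
        # Fixed output channel, sweep over input channels.
        return [(i, index) for i in range(min(in_c, max_tiles))]
    raise ValueError("fixed must be 'in' or 'out'")


print(select_tiles(3, 32, fixed="in", index=1, max_tiles=4))
```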
- class foscat.healpix_unet_torch.TestUNET(methodName='runTest')#
Bases: unittest.TestCase
Lightweight smoke tests for shape and parameter plumbing.
- setUp()#
Hook method for setting up the test fixture before exercising it.
- test_forward_shape()#
- test_param_roundtrip_and_determinism()#
- class foscat.healpix_unet_torch.HealpixDataset(x, y, cell_ids=None, dtype=torch.float32)[source]#
Bases: torch.utils.data.Dataset
Returns (x, y, cell_ids) per-sample if cell_ids is given, else (x, y). Shapes:
x: (C, Npix)
y: (C_out or 1, Npix)
cell_ids: (Npix,) per-sample (or broadcasted from (Npix,))
- x#
- y#
- class foscat.healpix_unet_torch.HealpixDataset(x: torch.Tensor, y: torch.Tensor, cell_ids: numpy.ndarray | torch.Tensor | None = None, dtype: torch.dtype = torch.float32)[source]#
Bases: torch.utils.data.Dataset
Fixed-grid dataset (common Npix for all samples). Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
x: (B, C, Npix)
y: (B, C_out or 1, Npix) or class indices depending on task
cell_ids: (Npix,) or (B, Npix)
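The return convention of the fixed-grid dataset can be illustrated with a plain-Python stand-in. FixedGridDatasetSketch is hypothetical and uses nested lists instead of tensors; it only shows how a shared (Npix,) set of cell ids differs from per-sample (B, Npix) ids.

```python
class FixedGridDatasetSketch:
    """Stand-in: returns (x, y) or (x, y, cell_ids) per sample."""

    def __init__(self, x, y, cell_ids=None):
        if len(x) != len(y):
            raise ValueError("x and y must hold the same number of samples")
        self.x, self.y, self.cell_ids = x, y, cell_ids

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        if self.cell_ids is None:
            return self.x[i], self.y[i]
        # A flat list means one (Npix,) grid shared by every sample;
        # a list of lists means per-sample (B, Npix) cell ids.
        shared = not isinstance(self.cell_ids[0], (list, tuple))
        cids = self.cell_ids if shared else self.cell_ids[i]
        return self.x[i], self.y[i], cids


x = [[[0.0] * 4], [[1.0] * 4]]  # B=2, C=1, Npix=4
y = [[[0.0] * 4], [[2.0] * 4]]
ds = FixedGridDatasetSketch(x, y, cell_ids=[0, 1, 2, 3])
print(len(ds), ds[1][2])
```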
- class foscat.healpix_unet_torch.VarLenHealpixDataset(x_list: List[numpy.ndarray | torch.Tensor], y_list: List[numpy.ndarray | torch.Tensor], cids_list: List[numpy.ndarray | torch.Tensor] | None = None, dtype: torch.dtype = torch.float32)[source]#
Bases: torch.utils.data.Dataset
Variable-length per-sample dataset.
x_list[b]: (C, Npix_b) or (1, C, Npix_b)
y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) (regression/segmentation targets)
For multi-class segmentation with CrossEntropyLoss, you may pass class indices of shape (Npix_b,) or (1, Npix_b) (we’ll squeeze later).
cids_list[b]: (Npix_b,) or None
- x#
- y#
- class foscat.healpix_unet_torch.VarLenHealpixDataset(x_list, y_list, cids_list=None, dtype=torch.float32)[source]#
Bases: torch.utils.data.Dataset
x_list: list of (C, Npix_b) tensors or arrays
y_list: list of (C_out or 1, Npix_b) tensors or arrays
cids_list: optional list of (Npix_b,) arrays
- x#
- y#
- c = None#
- foscat.healpix_unet_torch.varlen_collate(batch)[source]#
Collate for variable-length samples: keep lists, do NOT stack. Returns lists: xs, ys, cs (cs can be None).
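A collate function of this kind might look like the sketch below. It is an illustration under assumptions, not the library's code: varlen_collate_sketch is a hypothetical name, samples are taken to be (x, y) or (x, y, cell_ids) tuples, and cs is set to None as soon as any sample lacks cell ids.

```python
def varlen_collate_sketch(batch):
    """Keep per-sample items in lists instead of stacking them."""
    xs = [sample[0] for sample in batch]
    ys = [sample[1] for sample in batch]
    # ASSUMPTION: cell ids are returned only if every sample provides them.
    has_ids = all(len(sample) == 3 and sample[2] is not None for sample in batch)
    cs = [sample[2] for sample in batch] if has_ids else None
    return xs, ys, cs


batch = [
    ([1.0, 2.0, 3.0], [0, 1, 0], [10, 11, 12]),  # Npix_b = 3
    ([4.0, 5.0], [1, 1], [20, 21]),              # Npix_b = 2
]
xs, ys, cs = varlen_collate_sketch(batch)
print(xs, ys, cs)
```

Stacking would fail here because the two samples have different lengths, which is why the lists are kept as they are.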
- foscat.healpix_unet_torch.fit(model, x_train: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray], y_train: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray], *, cell_ids_train: numpy.ndarray | torch.Tensor | List[numpy.ndarray | torch.Tensor] | None = None, n_epoch: int = 10, view_epoch: int = 10, batch_size: int = 16, x_valid: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray] = None, y_valid: torch.Tensor | numpy.ndarray | List[torch.Tensor | numpy.ndarray] = None, save_model: bool = False, lr: float = 0.001, weight_decay: float = 0.0, clip_grad_norm: float | None = None, verbose: bool = True, optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM') dict[source]#
- Train helper that supports:
Fixed-grid tensors (B,C,N) with optional (B,N) or (N,) cell_ids.
Variable-length lists: x=[(C,N_b)], y=[…], cell_ids=[(N_b,)], returning per-sample grids.
ADAM: standard minibatch update.
LBFGS: uses a closure that sums losses over the current (variable-length) mini-batch.
Notes
For segmentation with multiple classes, pass integer class targets for y: fixed-grid: (B, N) int64; variable-length: each y[b] of shape (N_b,) or (1,N_b).
For regression, pass float targets with the same (C_out, N) channeling.
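To illustrate what summing losses over a variable-length mini-batch means, the sketch below adds up one loss per sample, with each sample having its own number of pixels. The per-sample loss used by fit is not specified here, so mean squared error is an assumption chosen for the regression case; mse and batch_loss are hypothetical helpers on plain lists.

```python
def mse(pred, target) -> float:
    """Mean squared error of one sample (ASSUMED per-sample loss)."""
    if len(pred) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def batch_loss(preds, targets) -> float:
    """Sum of per-sample losses; samples may differ in length."""
    return sum(mse(p, t) for p, t in zip(preds, targets))


preds = [[1.0, 2.0, 3.0], [0.0, 0.0]]    # N_b = 3 and N_b = 2
targets = [[1.0, 2.0, 5.0], [1.0, 1.0]]
print(batch_loss(preds, targets))
```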