foscat.UNET
UNET for HEALPix (nested) using Foscat oriented convolutions.
This module defines a lightweight, U-Net–like encoder/decoder that operates on signals defined on the HEALPix sphere (nested scheme). It leverages Foscat’s HOrientedConvol for orientation-aware convolutions and funct utilities for upgrade/downgrade (change of nside) operations.
Key design choices
Flat parameter vector: all convolution kernels are stored in a single 1‑D vector self.x. The dictionaries self.wconv and self.t_wconv map layer indices to offsets within that vector.
HEALPix-aware down/up-sampling: down-sampling uses self.f.ud_grade_2, and up-sampling uses self.f.up_grade, both with per-level cell_ids to preserve locality and orientation information.
Skip connections: U‑Net skip connections are implemented by concatenating encoder features with downgraded/upsampled paths along the channel axis.
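The flat-vector layout can be sketched in plain Python. This is an illustration, not Foscat code: the helper names `build_offsets` and `layer_view` are hypothetical, and the per-layer weight shapes (c_in, c_out, KERNELSZ*KERNELSZ) are assumed. The real UNET may size each kernel differently.

```python
from math import prod

def build_offsets(layer_shapes):
    """Map each layer index to its start offset in one flat vector.

    layer_shapes: one (hypothetical) weight shape per convolution.
    Returns (offsets, total_size), where offsets plays the role of self.wconv.
    """
    offsets = {}
    pos = 0
    for i, shape in enumerate(layer_shapes):
        offsets[i] = pos      # start of layer i inside the flat vector
        pos += prod(shape)    # advance by the number of weights in layer i
    return offsets, pos

def layer_view(flat, offsets, layer_shapes, i):
    """Return the slice of the flat vector holding the weights of layer i."""
    start = offsets[i]
    return flat[start:start + prod(layer_shapes[i])]

# Assumed shapes: (c_in, c_out, KERNELSZ * KERNELSZ) with KERNELSZ = 3
shapes = [(1, 4, 9), (4, 4, 9), (4, 8, 9)]
offsets, total = build_offsets(shapes)
flat = [0.0] * total          # stand-in for the 1-D backend tensor self.x
```

With this layout an external optimizer only ever sees one 1-D array. The cost is the manual offset bookkeeping that the dictionaries self.wconv and self.t_wconv take care of.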
Shape convention
All tensors follow the Foscat backend shape (batch, channels, npix).
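With this convention, the skip connections amount to concatenation along axis 1 (channels). The sketch below uses nested Python lists as a stand-in for backend tensors. `concat_channels` is a hypothetical helper; in practice you would use the backend's own concatenation.

```python
def concat_channels(a, b):
    """Concatenate two (batch, channels, npix) nested lists along the channel axis."""
    if len(a) != len(b):
        raise ValueError("batch sizes differ")
    out = []
    for xa, xb in zip(a, b):
        npix = len(xa[0])
        # every channel row of both inputs must cover the same pixels
        if any(len(row) != npix for row in xa + xb):
            raise ValueError("npix differs between inputs")
        out.append([list(row) for row in xa] + [list(row) for row in xb])
    return out

enc = [[[1, 2, 3, 4], [5, 6, 7, 8]]]   # encoder feature: B=1, C=2, npix=4
dec = [[[0, 0, 0, 0]]]                 # upsampled decoder path: B=1, C=1, npix=4
skip = concat_channels(enc, dec)       # result: B=1, C=3, npix=4
```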
Dependencies
foscat.scat_cov as sc
foscat.SphericalStencil as hs
Example
>>> import numpy as np
>>> from foscat.UNET import UNET
>>> nside = 8
>>> npix = 12 * nside * nside
>>> # Your backend tensor should be created via the Foscat backend; here we show a placeholder np.array
>>> x = np.random.randn(1, 1, npix).astype(np.float32)
>>> # cell_ids should be provided for the highest resolution (nside)
>>> # and must be consistent with the nested scheme expected by Foscat.
>>> # Example placeholder (use the real one from your pipeline):
>>> cell_ids = np.arange(npix, dtype=np.int64)
>>> net = UNET(in_nside=nside, n_chan_in=1, cell_ids=cell_ids)
>>> y = net.eval(net.f.backend.bk_cast(x)) # forward pass
Notes
This implementation assumes cell_ids is provided for the input resolution in_nside, and derives the coarser cell_ids for each subsequent level from it.
Some constructor parameters are reserved for future use (see docstring).
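Deriving coarser cell_ids relies on a property of HEALPix NESTED ordering: the four children of parent pixel p are 4p, 4p+1, 4p+2 and 4p+3, so the parent of pixel i is i // 4. The sketch below shows only that index relation, with simple mean pooling and value repetition as stand-ins. Whether Foscat's ud_grade_2 and up_grade weight or interpolate the values this way is an assumption, and all helper names are hypothetical.

```python
from collections import defaultdict

def parent_ids(cell_ids):
    """Sorted unique parent ids one level coarser (nside -> nside // 2).

    NESTED scheme: the children of parent p are 4p .. 4p+3, so parent(i) = i // 4.
    """
    return sorted({int(i) // 4 for i in cell_ids})

def downgrade_mean(values, cell_ids):
    """Mean-pool values onto parent pixels; partial-sky cell_ids are allowed."""
    groups = defaultdict(list)
    for v, i in zip(values, cell_ids):
        groups[int(i) // 4].append(v)
    parents = sorted(groups)
    return parents, [sum(groups[p]) / len(groups[p]) for p in parents]

def upgrade_repeat(values, parents):
    """Nearest-neighbour upsampling: copy each parent value to its 4 children."""
    child_ids, child_vals = [], []
    for v, p in zip(values, parents):
        for k in range(4):
            child_ids.append(4 * p + k)
            child_vals.append(v)
    return child_ids, child_vals

# Full sky at nside = 2 (48 pixels) down to nside = 1 (12 pixels) and back
ids = list(range(12 * 2 * 2))
vals = [float(i) for i in ids]
coarse_ids, coarse_vals = downgrade_mean(vals, ids)
fine_ids, fine_vals = upgrade_repeat(coarse_vals, coarse_ids)
```

On a partial sky, upgrade_repeat generates all four children of each parent, even when some were absent from the original cell_ids. A real pipeline would therefore intersect the result with the finer-level ids.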
Classes
Module Contents
- class foscat.UNET.UNET(nparam: int = 1, KERNELSZ: int = 3, NORIENT: int = 4, chanlist: list | None = None, in_nside: int = 1, n_chan_in: int = 1, cell_ids: numpy.ndarray | None = None, SEED: int = 1234, filename: str | None = None)
U‑Net–like network on HEALPix (nested) using Foscat oriented convolutions.
The network is built as an encoder/decoder (down/upsampling) tower. Each level performs two oriented convolutions. All kernels are packed in a flat parameter vector self.x to simplify optimization with external solvers.
- Parameters:
  - nparam (int, optional) – Reserved for future use. Currently unused.
  - KERNELSZ (int, optional) – Spatial kernel size (k × k) used by oriented convolutions. Default is 3.
  - NORIENT (int, optional) – Reserved for future use (number of orientations). Currently unused.
  - chanlist (Optional[list[int]], optional) – Number of output channels per encoder level. If None, it defaults to [4 * 2**k for k in range(log2(in_nside))]. The length of this list defines the number of encoder/decoder levels.
  - in_nside (int, optional) – Input HEALPix nside. When chanlist is not given, it must be a power of two so that the implicit depth log2(in_nside) is an integer.
  - n_chan_in (int, optional) – Number of input channels at the finest resolution. Default is 1.
  - cell_ids (array-like of int, required) – Pixel identifiers at the input resolution (nested indexing). They are used to build the oriented convolutions and to derive the coarser grids. Must not be None, even though the signature defaults to None.
  - SEED (int, optional) – Reserved for future use (random initialization seed). Currently unused.
  - filename (Optional[str], optional) – Reserved for future use (checkpoint I/O). Currently unused.
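The default chanlist documented above can be reproduced directly to see how in_nside fixes the network depth. The pixel-count helper is hypothetical. It assumes a full sky and that each encoder level halves nside, as the name ud_grade_2 suggests. With partial-sky cell_ids the counts would be smaller.

```python
import math

def default_chanlist(in_nside):
    """Documented default: [4 * 2**k for k in range(log2(in_nside))]."""
    depth = int(math.log2(in_nside))
    if 2 ** depth != in_nside:
        raise ValueError("in_nside must be a power of two")
    return [4 * 2**k for k in range(depth)]

def npix_per_level(in_nside, n_levels):
    """Full-sky HEALPix pixel count 12 * nside**2, halving nside at each level."""
    return [12 * (in_nside >> level) ** 2 for level in range(n_levels + 1)]

chanlist = default_chanlist(8)                  # three encoder/decoder levels
sizes = npix_per_level(8, len(chanlist))        # pixels from nside 8 down to nside 1
```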
- chanlist
Channels per encoder level.
- Type:
list[int]
- wconv, t_wconv
Offsets into the flat parameter vector self.x for encoder/decoder convolutions respectively.
- Type:
Dict[int,int]
- hconv, t_hconv
Per-level oriented convolution operators for encoder/decoder.
- Type:
Dict[int,hs.SphericalPencil]
- l_cell_ids
Per-level cell ids for downsampled grids (encoder side).
- Type:
Dict[int,np.ndarray]
- m_cell_ids
Per-level cell ids for upsampled grids (decoder side). Mirrors the levels of l_cell_ids but is indexed in decoder traversal order.
- Type:
Dict[int,np.ndarray]
- x
Flat vector holding all convolution weights.
- Type:
backend tensor (1‑D)
Notes
The constructor prints informative messages about the architecture layout (channels and pixel counts) to ease debugging.
The implementation keeps the logic identical to the original code; only comments, docstrings and variable explanations are added.
- f
- KERNELSZ = 3
- n_cnn = 0
- l_cell_ids
- wconv
- hconv
- m_cell_ids
- t_wconv
- t_hconv
- x
- nside = 1
- n_chan_in = 1
- chanlist = None
- get_param()
Return the flat parameter vector that stores all convolution kernels.
- Returns:
backend tensor (1‑D) – The Foscat backend representation (e.g., NumPy/Torch/TF tensor) holding all convolution weights in a single vector.
- set_param(x)
Overwrite the flat parameter vector with externally provided values.
This is useful when optimizing parameters with an external optimizer or when restoring weights from a checkpoint (after proper conversion to the Foscat backend type).
- Parameters:
x (array-like (1‑D)) – New values for the flat parameter vector. Its size must match that of self.x.
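Here is a minimal sketch of how an external optimizer could drive a model through get_param and set_param alone. ToyFlatModel is a hypothetical stand-in with a quadratic loss. With the real UNET, x would be a backend tensor and the loss would come from eval, so backend conversions would be needed. The finite-difference gradient only keeps the example dependency-free. The size check mirrors the documented requirement, but whether UNET.set_param enforces it is not stated.

```python
class ToyFlatModel:
    """Hypothetical stand-in exposing the same get_param/set_param contract."""

    def __init__(self, n):
        self.x = [0.0] * n

    def get_param(self):
        return list(self.x)

    def set_param(self, x):
        if len(x) != len(self.x):
            raise ValueError(f"expected {len(self.x)} values, got {len(x)}")
        self.x = list(x)

def loss(model, target):
    # Quadratic loss standing in for a real training objective
    return sum((p - t) ** 2 for p, t in zip(model.get_param(), target))

def gradient_descent(model, target, lr=0.1, steps=100, eps=1e-6):
    """External optimizer: read params, estimate a forward-difference
    gradient, then write the updated params back with set_param."""
    for _ in range(steps):
        x = model.get_param()
        base = loss(model, target)
        grad = []
        for i in range(len(x)):
            probe = list(x)
            probe[i] += eps
            model.set_param(probe)
            grad.append((loss(model, target) - base) / eps)
        model.set_param([xi - lr * gi for xi, gi in zip(x, grad)])
    return loss(model, target)

model = ToyFlatModel(3)
target = [1.0, -2.0, 0.5]
final_loss = gradient_descent(model, target)
```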
- eval(data)
Run a forward pass through the encoder/decoder.
- Parameters:
data (backend tensor, shape (B, C, Npix)) – Input signal at resolution self.nside (finest grid). C must equal self.n_chan_in.
- Returns:
backend tensor, shape (B, C_out, Npix) – Network output at the input resolution. C_out is 1 at the top level, or 1 + chanlist[level] for intermediate decoder levels.
Notes
The forward pass comprises two stages:
(1) Encoder: for each level l, apply two oriented convolutions (“conv -> conv”), downsample to the next coarser grid, and concatenate with a downgraded copy of the running input (m_data).
(2) Decoder: for each level, upsample to the finer grid, concatenate with the stored encoder feature (skip connection), then apply two oriented convolutions (“conv -> conv”) to produce out_chan.
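The two stages can be traced on (B, C, Npix) shape tuples to check the bookkeeping. This is a schematic, not the actual eval. The conv/down/up/concat functions are hypothetical placeholders that change only the shape, and pixel counts assume a full sky. The decoder widths follow the Returns description above, reading "level" as the decoder level index, which is an assumption.

```python
def conv(shape, c_out):
    """Placeholder for an oriented convolution: only the channel count changes."""
    b, _, npix = shape
    return (b, c_out, npix)

def down(shape):
    """Placeholder for ud_grade_2 on a full sky: nside halves, so npix / 4."""
    b, c, npix = shape
    return (b, c, npix // 4)

def up(shape):
    """Placeholder for up_grade on a full sky: nside doubles, so npix * 4."""
    b, c, npix = shape
    return (b, c, npix * 4)

def concat(a, b):
    """Channel-axis concatenation; batch and npix must agree."""
    if a[0] != b[0] or a[2] != b[2]:
        raise ValueError(f"cannot concatenate {a} and {b}")
    return (a[0], a[1] + b[1], a[2])

def trace_forward(shape, chanlist):
    """Return the (B, C, Npix) output shape of the schematic encoder/decoder."""
    h, m_data, skips = shape, shape, []
    for c in chanlist:                            # encoder, finest level first
        h = conv(conv(h, c), c)                   # conv -> conv
        skips.append(h)                           # kept for the skip connection
        m_data = down(m_data)                     # downgraded copy of the running input
        h = concat(down(h), m_data)
    for level in reversed(range(len(chanlist))):  # decoder, coarsest level first
        h = concat(up(h), skips[level])           # upsample, then skip connection
        # Assumed widths: 1 at the top level, 1 + chanlist[level] below it
        c = 1 if level == 0 else 1 + chanlist[level]
        h = conv(conv(h, c), c)
    return h

out_shape = trace_forward((1, 1, 768), [4, 8, 16])   # nside = 8, full sky
```

Under these assumptions the bottleneck carries 1 + 16 = 17 channels at 12 pixels, and the output returns to a single channel at the input resolution.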
- to_tensor(x)
- to_numpy(x)
- class foscat.UNET.TestUNET(methodName='runTest')
Bases: unittest.TestCase
Lightweight smoke tests for shape and parameter plumbing.
- setUp()
Hook method for setting up the test fixture before exercising it.
- test_forward_shape()
- test_param_roundtrip_and_determinism()