foscat.unet_2_d_from_healpix_params

Attributes

nside

Classes

PlanarUNet

U-Net 2D (images HxW) mirroring the parameterization of the HealpixUNet.

Functions

fit(model, x_train, y_train, *, ...) → dict

Training loop mirroring healpix_unet_torch.fit, adapted for 2D images.

Module Contents

class foscat.unet_2_d_from_healpix_params.PlanarUNet(*, in_nside: int, n_chan_in: int, chanlist: List[int], KERNELSZ: int = 3, task: Literal['regression', 'segmentation'] = 'regression', out_channels: int = 1, final_activation: Literal['none', 'sigmoid', 'softmax'] | None = None, device: torch.device | str | None = None, down_type: Literal['mean', 'max'] | None = 'max', dtype: Literal['float32', 'float64'] = 'float32', head_reduce: Literal['mean', 'learned'] = 'mean')

Bases: torch.nn.Module

U-Net 2D (images HxW) mirroring the parameterization of the HealpixUNet.

Key compatibility points with HealpixUNet:
  • Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task, out_channels, final_activation, device, down_type, dtype, head_reduce.

  • Two convolutions per level (encoder and decoder), each followed by GroupNorm + ReLU.

  • Downsampling by a factor of 2 at each level; the decoder upsamples back symmetrically.

  • The head produces out_channels channels, with optional BatchNorm (BN) and a final activation.
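The factor-2 downsampling selected by the constructor's `down_type` ('max' or 'mean') presumably reduces each non-overlapping 2x2 block of the image to its maximum or its average, as `torch.nn.functional.max_pool2d` or `avg_pool2d` with a kernel of 2 would. The sketch below illustrates that reduction on a plain nested-list image; `downsample_2x` is a hypothetical helper, not part of foscat, and the exact pooling operator used inside `PlanarUNet` is an assumption.

```python
from statistics import fmean
from typing import List, Literal

Image = List[List[float]]


def downsample_2x(img: Image, down_type: Literal["mean", "max"] = "max") -> Image:
    """Hypothetical helper: reduce each non-overlapping 2x2 block to one value.

    Assumes even height and width, which the in_nside divisibility
    constraint guarantees at every level of the U-Net.
    """
    h, w = len(img), len(img[0])
    if h % 2 or w % 2:
        raise ValueError(f"image size {h}x{w} is not divisible by 2")
    reduce = max if down_type == "max" else fmean
    return [
        [
            reduce([img[i][j], img[i][j + 1], img[i + 1][j], img[i + 1][j + 1]])
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


img = [
    [1.0, 2.0, 5.0, 6.0],
    [3.0, 4.0, 7.0, 8.0],
]
print(downsample_2x(img, "max"))   # [[4.0, 8.0]]
print(downsample_2x(img, "mean"))  # [[2.5, 6.5]]
```

Max pooling keeps the strongest response in each block, while mean pooling preserves the block average, which matters when the map values are physical quantities rather than features.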

Differences from the spherical version:
  • Operates on regular 2D images of size (3*in_nside, 4*in_nside).

  • Standard Conv2d instead of custom spherical stencil.

  • No gauges (G=1 implicit) and no cell_ids.
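To make the level structure concrete, the sketch below traces the (channels, height, width) of the feature maps through a planar U-Net with the properties listed above. It assumes, hypothetically, that encoder level i applies its two convolutions with chanlist[i] output channels and then halves H and W, and that decoder level i upsamples back to that level's size, concatenates the matching skip, and convolves back to chanlist[i] channels. `trace_levels` is a hypothetical helper, and the real `PlanarUNet` may order these steps differently.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width), batch dimension omitted


def trace_levels(in_nside: int, n_chan_in: int, chanlist: List[int],
                 out_channels: int = 1) -> List[Tuple[str, Shape]]:
    """Hypothetical shape trace for a planar U-Net on a (3*n, 4*n) grid."""
    h, w = 3 * in_nside, 4 * in_nside
    trace: List[Tuple[str, Shape]] = [("input", (n_chan_in, h, w))]
    level_sizes: List[Tuple[int, int]] = []
    # Encoder: two convs to chanlist[i] channels, remember the size for the skip, then halve.
    for i, c in enumerate(chanlist):
        trace.append((f"enc{i}", (c, h, w)))
        level_sizes.append((h, w))
        h, w = h // 2, w // 2
    # Lowest resolution, reached after len(chanlist) downsamplings.
    trace.append(("bottom", (chanlist[-1], h, w)))
    # Decoder: upsample back to level i, concatenate skip i, two convs to chanlist[i] channels.
    for i in reversed(range(len(chanlist))):
        h, w = level_sizes[i]
        trace.append((f"dec{i}", (chanlist[i], h, w)))
    # Head: map to out_channels at the full input resolution.
    trace.append(("head", (out_channels, h, w)))
    return trace


for name, shape in trace_levels(in_nside=8, n_chan_in=2, chanlist=[16, 32]):
    print(f"{name:>6}: {shape}")
```

Whatever the internal ordering, the output keeps the input's spatial size, which is what lets the head return one value per input pixel.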

Shapes

Input:  (B, C_in, 3*in_nside, 4*in_nside)
Output: (B, C_out, 3*in_nside, 4*in_nside)

Constraints

in_nside must be divisible by 2**depth, where depth = len(chanlist).
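The constraint ensures that every level has an even height and width before it is downsampled: with H = 3*in_nside and W = 4*in_nside, and 3 being odd, H and W survive depth halvings exactly when in_nside is divisible by 2**depth. A minimal validator could check this up front; `check_planar_unet_params` is hypothetical and not part of foscat.

```python
from typing import List, Tuple


def check_planar_unet_params(in_nside: int, chanlist: List[int]) -> Tuple[int, int]:
    """Hypothetical validator for the documented shape constraint.

    Returns the expected image size (H, W) = (3*in_nside, 4*in_nside).
    """
    depth = len(chanlist)  # one downsampling per level
    if in_nside % (2 ** depth) != 0:
        raise ValueError(
            f"in_nside={in_nside} is not divisible by 2**{depth}={2 ** depth}"
        )
    return 3 * in_nside, 4 * in_nside


print(check_planar_unet_params(32, [16, 32, 64]))  # (96, 128)
try:
    check_planar_unet_params(12, [16, 32, 64])  # 12 is not a multiple of 8
except ValueError as err:
    print(err)
```

For example, in_nside=32 (the value of the module-level nside attribute) supports up to five levels, since 32 = 2**5.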

in_nside
n_chan_in
chanlist
KERNELSZ = 3
task = 'regression'
out_channels = 1
down_type = 'max'
dtype
head_reduce = 'mean'
device
skips_channels: List[int] = []
encoder
upconvs
decoder
head_conv
head_bn
head_mixer
to_tensor(x)
to_numpy(x)
forward(x: torch.Tensor) → torch.Tensor

x: input tensor of shape (B, C_in, H, W), with H = 3*in_nside and W = 4*in_nside. Returns a tensor of shape (B, out_channels, H, W).

predict(x: torch.Tensor, batch_size: int = 8) → torch.Tensor
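The batch_size argument of predict suggests that inference runs over the input in consecutive chunks of batch_size samples along the batch dimension, with the per-chunk outputs concatenated. The sketch below shows that chunking pattern in plain Python with a stand-in model; `batched_predict` and `fake_model` are hypothetical, and the real method presumably slices a torch.Tensor rather than a list.

```python
from typing import Callable, List, Sequence


def batched_predict(model: Callable[[List[float]], List[float]],
                    x: Sequence[float], batch_size: int = 8) -> List[float]:
    """Hypothetical sketch: apply `model` to consecutive chunks of `x`."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    out: List[float] = []
    for start in range(0, len(x), batch_size):
        chunk = list(x[start:start + batch_size])  # the last chunk may be smaller
        out.extend(model(chunk))
    return out


calls: List[int] = []


def fake_model(batch: List[float]) -> List[float]:
    calls.append(len(batch))         # record each batch size seen
    return [2.0 * v for v in batch]  # stand-in for a forward pass


y = batched_predict(fake_model, [float(i) for i in range(20)], batch_size=8)
print(calls)  # [8, 8, 4]
print(y[:3])  # [0.0, 2.0, 4.0]
```

Chunking bounds peak memory by batch_size rather than by the full input size, at the cost of a few more forward calls.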

foscat.unet_2_d_from_healpix_params.fit(model: torch.nn.Module, x_train: torch.Tensor | numpy.ndarray, y_train: torch.Tensor | numpy.ndarray, *, n_epoch: int = 10, view_epoch: int = 10, batch_size: int = 16, lr: float = 0.001, weight_decay: float = 0.0, clip_grad_norm: float | None = None, verbose: bool = True, optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM') → dict

Training loop mirroring healpix_unet_torch.fit, adapted for 2D images.

  • Inputs: fixed-size tensors or ndarrays of shape (B, C, H, W), with H = 3*nside and W = 4*nside

  • Loss: MSE (regression) / BCE (BCEWithLogits if final_activation='none') / CrossEntropy (multi-class)

  • Optimizer: ADAM or LBFGS (with a closure)

  • Logs: returns {"loss": history}
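The loss rule listed above can be written as a small dispatch function. This is a sketch under assumptions: it treats task='segmentation' with out_channels == 1 as binary and out_channels > 1 as multi-class, and it returns the names of the corresponding torch.nn loss classes rather than instances. `select_loss_name` is hypothetical, and the rule inside fit may differ, for instance in how final_activation=None is handled.

```python
from typing import Literal, Optional


def select_loss_name(task: Literal["regression", "segmentation"],
                     out_channels: int = 1,
                     final_activation: Optional[str] = None) -> str:
    """Hypothetical mapping from model settings to a torch.nn loss class name."""
    if task == "regression":
        return "MSELoss"
    if task != "segmentation":
        raise ValueError(f"unknown task: {task!r}")
    if out_channels > 1:
        # Assumed multi-class segmentation: one output channel per class.
        return "CrossEntropyLoss"
    # Binary segmentation: raw logits (final_activation='none') need the fused loss.
    if final_activation == "none":
        return "BCEWithLogitsLoss"
    return "BCELoss"


print(select_loss_name("regression"))                   # MSELoss
print(select_loss_name("segmentation", 1, "none"))      # BCEWithLogitsLoss
print(select_loss_name("segmentation", 1, "sigmoid"))   # BCELoss
print(select_loss_name("segmentation", 4, "softmax"))   # CrossEntropyLoss
```

BCEWithLogitsLoss combines the sigmoid and the binary cross-entropy in one numerically stable operation, which is why it pairs with a head that outputs raw logits.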

foscat.unet_2_d_from_healpix_params.nside = 32