foscat.unet_2_d_from_healpix_params
Attributes
nside
Classes
PlanarUNet | U-Net 2D (images HxW) mirroring the parameterization of the HealpixUNet.
Functions
fit | Training loop mirroring healpix_unet_torch.fit, adapted for 2D images.
Module Contents
- class foscat.unet_2_d_from_healpix_params.PlanarUNet(*, in_nside: int, n_chan_in: int, chanlist: List[int], KERNELSZ: int = 3, task: Literal['regression', 'segmentation'] = 'regression', out_channels: int = 1, final_activation: Literal['none', 'sigmoid', 'softmax'] | None = None, device: torch.device | str | None = None, down_type: Literal['mean', 'max'] | None = 'max', dtype: Literal['float32', 'float64'] = 'float32', head_reduce: Literal['mean', 'learned'] = 'mean')
Bases: torch.nn.Module
U-Net 2D (images HxW) mirroring the parameterization of the HealpixUNet.
- Key compatibility points with HealpixUNet:
  - Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task, out_channels, final_activation, device, down_type, dtype, head_reduce.
  - Two convolutions per level (encoder and decoder), each followed by GroupNorm + ReLU.
  - Downsampling by a factor of 2 at each level; the decoder upsamples back symmetrically.
  - The head produces out_channels channels, with optional BN and a final activation.
- Differences from the spherical version:
  - Operates on regular 2D images of size (3*in_nside, 4*in_nside).
  - Uses standard Conv2d instead of the custom spherical stencil.
  - No gauges (G=1 implicitly) and no cell_ids.
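The per-level behavior above can be checked with the standard output-size formula for a convolution or pooling layer along one axis, out = (n + 2p - k) // s + 1 (dilation 1). The sketch below assumes padding KERNELSZ // 2 for an odd KERNELSZ, which keeps H and W unchanged through both convolutions (consistent with the output having the input's size), and a kernel-2, stride-2 downsampling. The helper names are hypothetical and not part of foscat.

```python
from typing import Tuple


def conv2d_out_size(n: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output size along one axis for a Conv2d/pooling layer with dilation 1."""
    return (n + 2 * padding - kernel) // stride + 1


def level_spatial_trace(
    h: int, w: int, kernelsz: int = 3
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """One encoder level: two 'same' convolutions, then a factor-2 downsampling.

    Assumption: padding = kernelsz // 2 with an odd kernelsz, so H and W are preserved.
    """
    pad = kernelsz // 2
    for _ in range(2):  # two convolutions per level
        h = conv2d_out_size(h, kernelsz, pad)
        w = conv2d_out_size(w, kernelsz, pad)
    after_convs = (h, w)
    # Downsampling by 2: kernel 2, stride 2, no padding (as in a 2x2 max/mean pool).
    after_down = (conv2d_out_size(h, 2, 0, 2), conv2d_out_size(w, 2, 0, 2))
    return after_convs, after_down


print(level_spatial_trace(96, 128))  # ((96, 128), (48, 64))
```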
Shapes
Input: (B, C_in, 3*in_nside, 4*in_nside)
Output: (B, C_out, 3*in_nside, 4*in_nside)
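For a concrete in_nside the shapes follow directly. Note that the image holds 3*in_nside * 4*in_nside = 12*in_nside**2 pixels, the same count as a HEALPix map at that nside. planar_shapes is a hypothetical helper for illustration only.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def planar_shapes(batch: int, c_in: int, c_out: int, in_nside: int) -> Tuple[Shape, Shape]:
    """Return (input_shape, output_shape) for the documented H=3*in_nside, W=4*in_nside layout."""
    h, w = 3 * in_nside, 4 * in_nside
    # Same pixel count as a HEALPix map at this nside: 3n * 4n == 12 * n**2.
    assert h * w == 12 * in_nside ** 2
    return (batch, c_in, h, w), (batch, c_out, h, w)


inp, out = planar_shapes(batch=4, c_in=2, c_out=1, in_nside=32)
print(inp, out)  # (4, 2, 96, 128) (4, 1, 96, 128)
```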
Constraints
in_nside must be divisible by 2**depth, where depth == len(chanlist).
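A hypothetical validation helper, assuming one factor-2 downsampling per entry of chanlist (as the constraint implies). It also lists the spatial size after each downsampling; when in_nside == 2**depth, the coarsest map is 3x4.

```python
from typing import List, Tuple


def check_in_nside(in_nside: int, chanlist: List[int]) -> List[Tuple[int, int]]:
    """Validate in_nside against depth == len(chanlist); return (H, W) after each downsampling.

    Hypothetical helper, not part of foscat.
    """
    depth = len(chanlist)
    if in_nside % (2 ** depth) != 0:
        raise ValueError(
            f"in_nside={in_nside} must be divisible by 2**{depth}={2 ** depth}"
        )
    h, w = 3 * in_nside, 4 * in_nside
    sizes = []
    for _ in range(depth):
        # Exact halving is guaranteed by the divisibility check above.
        h, w = h // 2, w // 2
        sizes.append((h, w))
    return sizes


print(check_in_nside(32, [16, 32, 64]))  # [(48, 64), (24, 32), (12, 16)]
```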
- in_nside
- n_chan_in
- chanlist
- KERNELSZ = 3
- task = 'regression'
- out_channels = 1
- down_type = 'max'
- dtype
- head_reduce = 'mean'
- device
- encoder
- upconvs
- decoder
- head_conv
- head_bn
- head_mixer
- forward(x: torch.Tensor) → torch.Tensor
x: (B, C_in, H, W) with H=3*in_nside, W=4*in_nside
- predict(x: torch.Tensor, batch_size: int = 8) → torch.Tensor
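This page documents no behavior for predict beyond its signature. A common pattern, and presumably what batch_size controls, is to run the model over consecutive chunks of the batch dimension and concatenate the results; in PyTorch this would typically sit inside torch.no_grad() and use torch.cat. The stdlib sketch below shows only the chunking logic; predict_in_batches is a hypothetical name, not foscat's implementation.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def predict_in_batches(
    fn: Callable[[Sequence[T]], List[R]], xs: Sequence[T], batch_size: int = 8
) -> List[R]:
    """Apply fn to consecutive chunks of xs and concatenate the results in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    out: List[R] = []
    for start in range(0, len(xs), batch_size):
        # The last chunk may be smaller than batch_size.
        out.extend(fn(xs[start:start + batch_size]))
    return out


print(predict_in_batches(lambda chunk: [len(chunk)], list(range(10)), batch_size=4))  # [4, 4, 2]
```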
- foscat.unet_2_d_from_healpix_params.fit(model: torch.nn.Module, x_train: torch.Tensor | numpy.ndarray, y_train: torch.Tensor | numpy.ndarray, *, n_epoch: int = 10, view_epoch: int = 10, batch_size: int = 16, lr: float = 0.001, weight_decay: float = 0.0, clip_grad_norm: float | None = None, verbose: bool = True, optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM') → dict
Training loop mirroring healpix_unet_torch.fit, adapted for 2D images.
Inputs (fixed size): tensors or ndarrays of shape (B, C, H, W), with H=3*nside and W=4*nside.
Loss: MSE (regression) / BCE (BCEWithLogits if final_activation='none') / CrossEntropy (multi-class)
Optimizer: ADAM or LBFGS (with a closure)
Logs: returns {"loss": history}
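One reading of the loss rule above, as a hypothetical helper that returns the name of the matching torch.nn loss class (MSELoss, BCELoss, BCEWithLogitsLoss, CrossEntropyLoss). It assumes that binary segmentation means out_channels == 1 and takes final_activation already resolved to 'none', 'sigmoid' or 'softmax'; how fit treats final_activation=None is not documented here.

```python
def select_loss_name(task: str, out_channels: int, final_activation: str) -> str:
    """Name of the torch.nn loss class implied by the rule (hypothetical helper)."""
    if task == "regression":
        return "MSELoss"
    if task != "segmentation":
        raise ValueError(f"unknown task: {task!r}")
    if out_channels == 1:
        # Binary case (assumed): raw logits from the head call for the logits-aware BCE.
        return "BCEWithLogitsLoss" if final_activation == "none" else "BCELoss"
    # Multi-class segmentation.
    return "CrossEntropyLoss"


print(select_loss_name("segmentation", 1, "none"))  # BCEWithLogitsLoss
```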
- foscat.unet_2_d_from_healpix_params.nside = 32