foscat.healpix_unet_torch
=========================

.. py:module:: foscat.healpix_unet_torch

.. autoapi-nested-parse::

   HEALPix U-Net (nested) with Foscat + PyTorch niceties
   -----------------------------------------------------

   GPU by default (when available), with graceful CPU fallback if Foscat ops are CPU-only.

   - ReLU + BatchNorm after each convolution (encoder & decoder)
   - Segmentation/Regression heads with optional final activation
   - PyTorch-ified: inherits from nn.Module, standard state_dict
   - Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU

   Shape convention: (B, C, Npix)

   Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)


Classes
-------

.. autoapisummary::

   foscat.healpix_unet_torch.HealpixUNet
   foscat.healpix_unet_torch.TestUNET
   foscat.healpix_unet_torch.HealpixDataset
   foscat.healpix_unet_torch.VarLenHealpixDataset


Functions
---------

.. autoapisummary::

   foscat.healpix_unet_torch.varlen_collate
   foscat.healpix_unet_torch.fit


Module Contents
---------------

.. py:class:: HealpixUNet(*, in_nside: int, n_chan_in: int, chanlist: List[int], cell_ids: numpy.ndarray, KERNELSZ: int = 3, task: Literal['regression', 'segmentation'] = 'regression', out_channels: int = 1, final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None, device: Optional[torch.device | str] = None, prefer_foscat_gpu: bool = True, gauge_type: Optional[Literal['cosmo', 'phi']] = 'cosmo', G: int = 1, down_type: Optional[Literal['mean', 'max']] = 'max', dtype: Literal['float32', 'float64'] = 'float32', head_reduce: Literal['mean', 'learned'] = 'mean')

   Bases: :py:obj:`torch.nn.Module`

   U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.

   :Parameters:

       * **in_nside** (:py:class:`int`) -- Input HEALPix nside (nested scheme).
       * **n_chan_in** (:py:class:`int`) -- Number of input channels.
       * **chanlist** (:py:class:`list[int]`) -- Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
       * **cell_ids** (:py:class:`np.ndarray`) -- Cell indices for the finest resolution (nside = in_nside) in nested scheme.
       * **KERNELSZ** (:py:class:`int`, *default* ``3``) -- Spatial kernel size K (K x K) for oriented convolution.
       * **gauge_type** (:py:class:`str`) -- Type of gauge. ``'cosmo'`` uses the same definition as https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html, while ``'phi'`` is defined at the pole and could be better for Earth observation that does not use the poles intensively.
       * **G** (:py:class:`int`, *default* ``1``) -- Number of gauges for the orientation definition.
       * **task** (``{'regression','segmentation'}``, *default* ``'regression'``) -- Chooses the head and default activation.
       * **out_channels** (:py:class:`int`, *default* ``1``) -- Number of output channels (e.g. num_classes for segmentation).
       * **final_activation** (``{'none','sigmoid','softmax'}`` | None) -- If None, uses a sensible default per task: 'none' for regression, 'softmax' for segmentation (multi-class), 'sigmoid' for segmentation when out_channels==1.
       * **device** (:py:class:`str | torch.device | None`, *default*: ``'cuda'`` if available else ``'cpu'``) -- Preferred device. The module will probe whether Foscat ops can run on CUDA; if not, it will fall back to CPU and keep all parameters/buffers on CPU for consistency.
       * **down_type** (``{'mean','max'}``, *default* ``'max'``) -- Pooling applied when downsampling; ``'max'`` is the equivalent of max pooling.
       * **prefer_foscat_gpu** (:py:class:`bool`, *default* :py:obj:`True`) -- When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run. If the dry-run fails, everything falls back to CPU.

   .. admonition:: Notes

      - Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
      - Downsampling uses foscat ``ud_grade_2``; upsampling uses ``up_grade``.
      - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``SphericalStencil.Convol_torch``.
      - Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.

   .. py:attribute:: dtype
      :value: 'float32'

   .. py:attribute:: gauge_type
      :value: 'cosmo'

   .. py:attribute:: G
      :value: 1

   .. py:attribute:: in_nside

   .. py:attribute:: n_chan_in

   .. py:attribute:: chanlist

   .. py:attribute:: KERNELSZ
      :value: 3

   .. py:attribute:: task
      :value: 'regression'

   .. py:attribute:: out_channels
      :value: 1

   .. py:attribute:: prefer_foscat_gpu
      :value: True

   .. py:attribute:: device

   .. py:attribute:: f

   .. py:attribute:: l_cell_ids
      :type: List[numpy.ndarray]

   .. py:attribute:: hconv_enc
      :type: List[foscat.SphericalStencil.SphericalStencil]
      :value: []

   .. py:attribute:: hconv_dec
      :type: List[foscat.SphericalStencil.SphericalStencil]
      :value: []

   .. py:attribute:: enc_w1

   .. py:attribute:: enc_bn1

   .. py:attribute:: enc_w2

   .. py:attribute:: enc_bn2

   .. py:attribute:: enc_nsides

   .. py:attribute:: dec_w1

   .. py:attribute:: dec_bn1

   .. py:attribute:: dec_w2

   .. py:attribute:: dec_bn2

   .. py:attribute:: head_hconv

   .. py:attribute:: head_w

   .. py:attribute:: head_bn

   .. py:attribute:: head_reduce

   .. py:attribute:: runtime_device

   .. py:method:: set_device(device: torch.device | str) -> torch.device

      Request a device change; probes Foscat and returns the actual runtime device used.

   .. py:method:: forward_any(x, cell_ids: Optional[numpy.ndarray] = None)

      If `x` is a Tensor (B,C,N): standard batched path (requires the same N for all samples).
      If `x` is a list of Tensors: variable-length per-sample path, returns a list of outputs.

   .. py:method:: forward(x: torch.Tensor, cell_ids: Optional[numpy.ndarray] = None) -> torch.Tensor

      Forward pass.

      :Parameters:

          * **x** (:py:class:`torch.Tensor`, shape ``(B, C_in, Npix)``) -- Input tensor on the `in_nside` grid.
          * **cell_ids** (:py:class:`np.ndarray`, shape ``(B, Npix)``, *optional*) -- Cell indices to use instead of the initial ones. If None, the initial cell_ids are used.

   .. py:method:: predict(x: torch.Tensor, batch_size: int = 8, cell_ids: Optional[numpy.ndarray] = None) -> torch.Tensor

   .. py:method:: to_tensor(x)

   .. py:method:: to_numpy(x)

   .. py:method:: extract_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0)

      Extract raw convolution kernels for a given stage/level/conv.

      :Parameters:

          * **stage** (``{"encoder", "decoder"}``) -- Which part of the network to inspect.
          * **layer** (:py:class:`int`) -- Pyramid level (0 = finest encoder level / bottommost decoder level).
          * **conv** (:py:class:`int`) -- 0 for the first conv at that level, 1 for the second conv.

      :returns: :py:class:`np.ndarray` -- Array of shape (in_c, out_c, K, K) containing the spatial kernels.

   .. py:method:: plot_kernels(stage: str = 'encoder', layer: int = 0, conv: int = 0, fixed: str = 'in', index: int = 0, max_tiles: int = 16)

      Quick visualization of kernels on a grid using matplotlib.

      :Parameters:

          * **stage** (``{"encoder", "decoder"}``) -- Which tower to visualize.
          * **layer** (:py:class:`int`) -- Level to visualize.
          * **conv** (:py:class:`int`) -- 0 or 1: first or second conv in the level.
          * **fixed** (``{"in", "out"}``) -- If "in", show kernels for a fixed input channel across many outputs. If "out", show kernels for a fixed output channel across many inputs.
          * **index** (:py:class:`int`) -- Channel index to fix (according to `fixed`).
          * **max_tiles** (:py:class:`int`) -- Maximum number of tiles to display.


.. py:class:: TestUNET(methodName='runTest')

   Bases: :py:obj:`unittest.TestCase`

   Lightweight smoke tests for shape and parameter plumbing.

   .. py:method:: setUp()

      Hook method for setting up the test fixture before exercising it.

   .. py:method:: test_forward_shape()

   .. py:method:: test_param_roundtrip_and_determinism()


.. py:class:: HealpixDataset(x, y, cell_ids=None, dtype=torch.float32)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Returns (x, y, cell_ids) per-sample if cell_ids is given, else (x, y).

   Shapes:

   - x: (C, Npix)
   - y: (C_out or 1, Npix)
   - cell_ids: (Npix,) per-sample (or broadcast from (Npix,))

   .. py:attribute:: x

   .. py:attribute:: y


.. py:class:: HealpixDataset(x: torch.Tensor, y: torch.Tensor, cell_ids: Optional[Union[numpy.ndarray, torch.Tensor]] = None, dtype: torch.dtype = torch.float32)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Fixed-grid dataset (common Npix for all samples).
   Returns (x, y) if cell_ids is None, else (x, y, cell_ids).

   - x: (B, C, Npix)
   - y: (B, C_out or 1, Npix) or class indices depending on task
   - cell_ids: (Npix,) or (B, Npix)


.. py:class:: VarLenHealpixDataset(x_list: List[Union[numpy.ndarray, torch.Tensor]], y_list: List[Union[numpy.ndarray, torch.Tensor]], cids_list: Optional[List[Union[numpy.ndarray, torch.Tensor]]] = None, dtype: torch.dtype = torch.float32)

   Bases: :py:obj:`torch.utils.data.Dataset`

   Variable-length per-sample dataset.

   - x_list[b]: (C, Npix_b) or (1, C, Npix_b)
   - y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) (regression/segmentation targets).
     For multi-class segmentation with CrossEntropyLoss, you may pass class indices of shape (Npix_b,) or (1, Npix_b) (we’ll squeeze later).
   - cids_list[b]: (Npix_b,) or None

   .. py:attribute:: x

   .. py:attribute:: y


.. py:class:: VarLenHealpixDataset(x_list, y_list, cids_list=None, dtype=torch.float32)

   Bases: :py:obj:`torch.utils.data.Dataset`

   - x_list: list of (C, Npix_b) tensors or arrays
   - y_list: list of (C_out or 1, Npix_b) tensors or arrays
   - cids_list: optional list of (Npix_b,) arrays

   .. py:attribute:: x

   .. py:attribute:: y

   .. py:attribute:: c
      :value: None


.. py:function:: varlen_collate(batch)

   Collate for variable-length samples: keep lists, do NOT stack.
   Returns lists: xs, ys, cs (cs can be None).


.. py:function:: fit(model, x_train: Union[torch.Tensor, numpy.ndarray, List[Union[torch.Tensor, numpy.ndarray]]], y_train: Union[torch.Tensor, numpy.ndarray, List[Union[torch.Tensor, numpy.ndarray]]], *, cell_ids_train: Optional[Union[numpy.ndarray, torch.Tensor, List[Union[numpy.ndarray, torch.Tensor]]]] = None, n_epoch: int = 10, view_epoch: int = 10, batch_size: int = 16, x_valid: Union[torch.Tensor, numpy.ndarray, List[Union[torch.Tensor, numpy.ndarray]]] = None, y_valid: Union[torch.Tensor, numpy.ndarray, List[Union[torch.Tensor, numpy.ndarray]]] = None, save_model: bool = False, lr: float = 0.001, weight_decay: float = 0.0, clip_grad_norm: Optional[float] = None, verbose: bool = True, optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM') -> dict

   Train helper that supports:

   - Fixed-grid tensors (B,C,N) with optional (B,N) or (N,) cell_ids.
   - Variable-length lists: x=[(C,N_b)], y=[...], cell_ids=[(N_b,)], returning per-sample grids.

   ADAM: standard minibatch update.
   LBFGS: uses a closure that sums losses over the current (variable-length) mini-batch.

   .. admonition:: Notes

      - For segmentation with multiple classes, pass integer class targets for y:
        fixed-grid: (B, N) int64; variable-length: each y[b] of shape (N_b,) or (1,N_b).
      - For regression, pass float targets with the same (C_out, N) channeling.
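
The "keep lists, do NOT stack" behaviour described for ``varlen_collate`` can be illustrated with a minimal sketch. This is not the foscat implementation: the name ``varlen_collate_sketch`` is hypothetical, and it assumes each sample is a tuple ``(x, y)`` or ``(x, y, cell_ids)``, as the dataset classes above describe. Plain Python lists stand in for tensors so the example runs without torch or foscat.

```python
from typing import Any, List, Optional, Sequence, Tuple


def varlen_collate_sketch(
    batch: Sequence[Tuple[Any, ...]],
) -> Tuple[List[Any], List[Any], Optional[List[Any]]]:
    """Hypothetical stand-in for a variable-length collate function.

    Samples may have different Npix, so they cannot be stacked into one
    (B, C, Npix) array. The samples are returned as parallel lists instead.
    """
    xs = [sample[0] for sample in batch]
    ys = [sample[1] for sample in batch]
    # Assumption: all samples in a batch have the same arity, so the first
    # sample tells us whether cell ids are present.
    has_cell_ids = len(batch) > 0 and len(batch[0]) == 3
    cs = [sample[2] for sample in batch] if has_cell_ids else None
    return xs, ys, cs


# Two samples with one channel each and different pixel counts (3 and 2).
batch = [
    ([[0.1, 0.2, 0.3]], [[1.0, 0.0, 1.0]], [4, 5, 6]),
    ([[0.4, 0.5]], [[0.0, 1.0]], [10, 11]),
]
xs, ys, cs = varlen_collate_sketch(batch)
pixel_counts = [len(x[0]) for x in xs]  # per-sample Npix is preserved
```

A function of this shape is typically passed as the ``collate_fn`` argument of ``torch.utils.data.DataLoader``; whether ``fit`` wires ``varlen_collate`` in exactly this way is not stated in this reference.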