foscat.UNET
===========

.. py:module:: foscat.UNET

.. autoapi-nested-parse::

   UNET for HEALPix (nested) using Foscat oriented convolutions.

   This module defines a lightweight, U-Net–like encoder/decoder that operates on
   signals defined on the HEALPix sphere (nested scheme). It uses Foscat's
   `HOrientedConvol` for orientation-aware convolutions and the `funct` utilities
   for upgrade/downgrade (change of nside) operations.

   Key design choices
   ------------------

   • **Flat parameter vector**: all convolution kernels are stored in a single
     1‑D vector `self.x`. The dictionaries `self.wconv` and `self.t_wconv` map
     layer indices to *offsets* within that vector.
   • **HEALPix-aware down/up-sampling**: down-sampling uses `self.f.ud_grade_2`,
     and up-sampling uses `self.f.up_grade`, both with per-level `cell_ids` to
     preserve locality and orientation information.
   • **Skip connections**: U‑Net skip connections are implemented by concatenating
     encoder features with downgraded/upsampled paths along the channel axis.

   Shape convention
   ----------------

   All tensors follow the Foscat backend shape `(batch, channels, npix)`.

   Dependencies
   ------------

   - foscat.scat_cov as `sc`
   - foscat.SphericalStencil as `hs`

   .. admonition:: Example

      >>> import numpy as np
      >>> from foscat.UNET import UNET
      >>> nside = 8
      >>> npix = 12 * nside * nside
      >>> # Your backend tensor should be created via the foscat backend; a placeholder np.array is used here
      >>> x = np.random.randn(1, 1, npix).astype(np.float32)
      >>> # cell_ids must be provided for the highest resolution (nside)
      >>> # and must be consistent with the nested scheme expected by Foscat.
      >>> # Example placeholder (use the real one from your pipeline):
      >>> cell_ids = np.arange(npix, dtype=np.int64)
      >>> net = UNET(in_nside=nside, n_chan_in=1, cell_ids=cell_ids)
      >>> y = net.eval(net.f.backend.bk_cast(x))  # forward pass

   .. admonition:: Notes

      - This implementation assumes `cell_ids` is provided for the input resolution `in_nside`.
        The coarser `cell_ids` of every level are derived from it.
      - Some constructor parameters are reserved for future use (see the class docstring).


Classes
-------

.. autoapisummary::

   foscat.UNET.UNET
   foscat.UNET.TestUNET


Module Contents
---------------

.. py:class:: UNET(nparam: int = 1, KERNELSZ: int = 3, NORIENT: int = 4, chanlist: Optional[list] = None, in_nside: int = 1, n_chan_in: int = 1, cell_ids: Optional[numpy.ndarray] = None, SEED: int = 1234, filename: Optional[str] = None)

   U‑Net–like network on HEALPix (nested) using Foscat oriented convolutions.

   The network is built as an encoder/decoder (down/upsampling) tower. Each
   level performs two oriented convolutions. All kernels are packed in a flat
   parameter vector `self.x` to simplify optimization with external solvers.

   :Parameters:
      * **nparam** (:py:class:`int`, *optional*) -- Reserved for future use.
        Currently unused.
      * **KERNELSZ** (:py:class:`int`, *optional*) -- Spatial kernel size (k × k)
        used by oriented convolutions. Default is 3.
      * **NORIENT** (:py:class:`int`, *optional*) -- Reserved for future use
        (number of orientations). Currently unused.
      * **chanlist** (:py:class:`Optional[list[int]]`, *optional*) -- Number of
        output channels per encoder level. If ``None``, it defaults to
        ``[4 * 2**k for k in range(log2(in_nside))]``. The length of this list
        defines the number of encoder/decoder levels.
      * **in_nside** (:py:class:`int`, *optional*) -- Input HEALPix nside. Must be a
        power of two when ``chanlist`` is not given, because the default depth is
        ``log2(in_nside)``.
      * **n_chan_in** (:py:class:`int`, *optional*) -- Number of input channels at
        the finest resolution. Default is 1.
      * **cell_ids** (array-like of :py:class:`int`, *required*) -- Pixel
        identifiers at the input resolution (nested indexing). They are used to
        build the oriented convolutions and to derive the coarser grids.
        **Must not be** ``None``.
      * **SEED** (:py:class:`int`, *optional*) -- Reserved for future use (random initialization seed).
        Currently unused.
      * **filename** (:py:class:`Optional[str]`, *optional*) -- Reserved for future
        use (checkpoint I/O). Currently unused.

   .. attribute:: f

      Foscat helper exposing the backend and the grade/convolution utilities.

      :type: :py:class:`object`

   .. attribute:: KERNELSZ

      Effective kernel size used by all convolutions.

      :type: :py:class:`int`

   .. attribute:: chanlist

      Channels per encoder level.

      :type: :py:class:`list[int]`

   .. attribute:: wconv, t_wconv

      Offsets into the flat parameter vector `self.x` for the encoder and decoder
      convolutions, respectively.

      :type: ``Dict[int, int]``

   .. attribute:: hconv, t_hconv

      Per-level oriented convolution operators for the encoder and decoder.

      :type: ``Dict[int, hs.SphericalPencil]``

   .. attribute:: l_cell_ids

      Per-level cell ids for downsampled grids (encoder side).

      :type: ``Dict[int, np.ndarray]``

   .. attribute:: m_cell_ids

      Per-level cell ids for upsampled grids (decoder side). Mirrors the levels of
      ``l_cell_ids`` but is indexed in decoder traversal order.

      :type: ``Dict[int, np.ndarray]``

   .. attribute:: x

      Flat vector holding *all* convolution weights.

      :type: :py:class:`backend tensor (1‑D)`

   .. attribute:: nside

      Input nside (finest resolution).

      :type: :py:class:`int`

   .. attribute:: n_chan_in

      Number of input channels.

      :type: :py:class:`int`

   .. admonition:: Notes

      - The constructor prints informative messages about the architecture layout
        (channels and pixel counts) to ease debugging.
      - The implementation keeps the logic identical to the original code; only
        comments, docstrings and variable explanations are added.


   .. py:attribute:: f


   .. py:attribute:: KERNELSZ
      :value: 3


   .. py:attribute:: n_cnn
      :value: 0


   .. py:attribute:: l_cell_ids


   .. py:attribute:: wconv


   .. py:attribute:: hconv


   .. py:attribute:: m_cell_ids


   .. py:attribute:: t_wconv


   .. py:attribute:: t_hconv


   .. py:attribute:: x


   .. py:attribute:: nside
      :value: 1


   .. py:attribute:: n_chan_in
      :value: 1

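
   The default ``chanlist`` and the grid size at each level can be tabulated from
   the documented rules alone. This is a minimal sketch, not foscat code: it
   assumes that level ``k`` convolves at ``nside = in_nside // 2**k`` (one
   ``ud_grade_2`` halving per level) and that a full-sky map has
   ``12 * nside**2`` pixels. ``layer_plan`` is a hypothetical helper, and unlike
   the constructor it always requires a power-of-two ``in_nside``.

   .. code-block:: python

      import math


      def layer_plan(in_nside, chanlist=None):
          """Hypothetical helper: (level, nside, npix, channels) per encoder level."""
          if in_nside < 1 or in_nside & (in_nside - 1):
              raise ValueError("in_nside must be a power of two")
          depth = int(math.log2(in_nside))
          if chanlist is None:
              # Documented default: 4 * 2**k output channels at level k.
              chanlist = [4 * 2**k for k in range(depth)]
          if len(chanlist) > depth:
              # Assumption: the grid cannot be halved below nside = 1.
              raise ValueError("chanlist has more levels than log2(in_nside)")
          plan = []
          for level, nchan in enumerate(chanlist):
              nside = in_nside >> level  # assumed: halved once per level
              plan.append((level, nside, 12 * nside * nside, nchan))
          return plan


      for row in layer_plan(8):
          print(row)

   For ``in_nside = 8`` this prints three levels: ``(0, 8, 768, 4)``,
   ``(1, 4, 192, 8)`` and ``(2, 2, 48, 16)``.
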

   .. py:attribute:: chanlist
      :value: None


   .. py:method:: get_param()

      Return the flat parameter vector that stores all convolution kernels.

      :returns: :py:class:`backend tensor (1‑D)` -- The Foscat backend representation
                (e.g., NumPy/Torch/TF tensor) holding all convolution weights in a
                single vector.


   .. py:method:: set_param(x)

      Overwrite the flat parameter vector with externally provided values.

      This is useful when optimizing parameters with an external optimizer or when
      restoring weights from a checkpoint (after proper conversion to the Foscat
      backend type).

      :Parameters: **x** (:py:class:`array-like (1‑D)`) -- New values for the flat
                   parameter vector. Its size must match that of `self.x`.


   .. py:method:: eval(data)

      Run a forward pass through the encoder/decoder.

      :Parameters: **data** (*backend tensor, shape (B, C, Npix)*) -- Input signal at
                   resolution `self.nside` (finest grid). `C` must equal
                   `self.n_chan_in`.

      :returns: *backend tensor, shape (B, C_out, Npix)* -- Network output at the input
                resolution. `C_out` is `1` at the top level, or `1 + chanlist[level]`
                for intermediate decoder levels.

      .. admonition:: Notes

         The forward pass comprises two stages:

         (1) **Encoder**: for each level `l`, apply two oriented convolutions
             ("conv -> conv"), downsample to the next coarser grid, and concatenate
             with a downgraded copy of the running input (`m_data`).
         (2) **Decoder**: for each level, upsample to the finer grid, concatenate
             with the stored encoder feature (skip connection), then apply two
             oriented convolutions ("conv -> conv") to produce `out_chan`.


   .. py:method:: to_tensor(x)


   .. py:method:: to_numpy(x)


.. py:class:: TestUNET(methodName='runTest')

   Bases: :py:obj:`unittest.TestCase`

   Lightweight smoke tests for shape and parameter plumbing.


   .. py:method:: setUp()

      Hook method for setting up the test fixture before exercising it.


   .. py:method:: test_forward_shape()

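
      The shape bookkeeping behind one encoder/decoder level can be reproduced with
      NumPy alone. The sketch relies on one property of the nested scheme: the
      parent of pixel ``p`` one level coarser is ``p // 4``. It assumes mean pooling
      for the downgrade and value repetition for the upgrade, which may differ from
      what ``self.f.ud_grade_2`` and ``self.f.up_grade`` actually compute; all helper
      names are hypothetical.

      .. code-block:: python

         import numpy as np


         def coarser_cell_ids(cell_ids):
             """Nested parents of cell_ids one level down (sorted, unique)."""
             return np.unique(np.asarray(cell_ids) // 4)


         def downgrade_mean(data, cell_ids):
             """Average (B, C, N) data over the children of each nested parent."""
             coarse_ids, inverse = np.unique(np.asarray(cell_ids) // 4,
                                             return_inverse=True)
             b, c, _ = data.shape
             out = np.zeros((b, c, coarse_ids.size))
             # Accumulate every child into its parent slot, then divide by
             # the number of children present.
             np.add.at(out, (slice(None), slice(None), inverse), data)
             counts = np.bincount(inverse, minlength=coarse_ids.size)
             return out / counts, coarse_ids


         def upgrade_repeat(coarse, coarse_ids, fine_ids):
             """Copy each parent value to its children listed in fine_ids."""
             pos = np.searchsorted(coarse_ids, np.asarray(fine_ids) // 4)
             return coarse[:, :, pos]


         nside = 4
         cell_ids = np.arange(12 * nside * nside)      # full sky, nested
         x = np.random.default_rng(0).standard_normal((2, 3, cell_ids.size))
         low, low_ids = downgrade_mean(x, cell_ids)    # (2, 3, 48)
         up = upgrade_repeat(low, low_ids, cell_ids)   # (2, 3, 192)
         skip = np.concatenate([x, up], axis=1)        # (2, 6, 192)

      Concatenating along ``axis=1`` is what grows the channel count at a skip
      connection, while the pixel axis always matches the grid of its level.
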

   .. py:method:: test_param_roundtrip_and_determinism()
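
      A roundtrip-and-determinism check on a flat parameter vector can be sketched
      with a toy container modelled on ``self.x`` and the ``wconv`` offset
      dictionary. This is an illustration under stated assumptions, not foscat's
      implementation: ``FlatParams`` and the kernel shapes are hypothetical, and
      reproducible initialization comes from a seeded ``numpy.random.default_rng``.

      .. code-block:: python

         import numpy as np


         class FlatParams:
             """Hypothetical container: all kernels packed in one 1-D vector."""

             def __init__(self, shapes, seed=1234):
                 self.shapes = list(shapes)
                 self.offsets = {}  # layer index -> start offset in self.x
                 total = 0
                 for i, shape in enumerate(self.shapes):
                     self.offsets[i] = total
                     total += int(np.prod(shape))
                 rng = np.random.default_rng(seed)  # seeded -> reproducible init
                 self.x = rng.standard_normal(total)

             def get_param(self):
                 return self.x

             def set_param(self, x):
                 x = np.asarray(x, dtype=self.x.dtype).ravel()
                 if x.size != self.x.size:
                     raise ValueError(f"expected {self.x.size} values, got {x.size}")
                 self.x = x.copy()

             def kernel(self, i):
                 """View layer i's kernel by slicing the flat vector at its offset."""
                 start = self.offsets[i]
                 size = int(np.prod(self.shapes[i]))
                 return self.x[start:start + size].reshape(self.shapes[i])


         # Assumed kernel shapes (in_ch, out_ch, k * k) for two 3x3 convolutions.
         shapes = [(1, 4, 9), (4, 4, 9)]
         a, b = FlatParams(shapes), FlatParams(shapes)

      Two instances built with the same seed hold identical vectors, and
      ``set_param(get_param())`` leaves the values unchanged; ``kernel(i)`` shows how
      an offset recovers one layer's weights from the flat vector.
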