foscat.healpix_vit_torch
========================

.. py:module:: foscat.healpix_vit_torch


Attributes
----------

.. autoapisummary::

   foscat.healpix_vit_torch.in_nside


Classes
-------

.. autoapisummary::

   foscat.healpix_vit_torch.HealpixViT


Module Contents
---------------

.. py:class:: HealpixViT(*, in_nside: int, n_chan_in: int, level_dims: List[int], depth: int, num_heads: int, cell_ids: numpy.ndarray, task: Literal['regression', 'segmentation', 'global'] = 'regression', out_channels: int = 1, mlp_ratio: float = 4.0, KERNELSZ: int = 3, gauge_type: Literal['cosmo', 'phi'] = 'cosmo', G: int = 1, prefer_foscat_gpu: bool = True, cls_token: bool = False, pos_embed: Literal['learned', 'none'] = 'learned', head_type: Literal['mean', 'cls'] = 'mean', dropout: float = 0.0, dtype: Literal['float32', 'float64'] = 'float32')

   Bases: :py:obj:`torch.nn.Module`

   HEALPix Vision Transformer (Foscat-based) with *variable channel widths per level* and a U-Net-like spherical decoder.

   Key idea
   --------
   - The encoder uses a list of channel dimensions ``level_dims = [C_fine, C_l1, ..., C_token]`` that evolve *with depth* (e.g., 128 -> 192 -> 256).
   - At each encoder level (before a ``Down()``), a spherical convolution maps C_i -> C_{i+1}. ``Down()`` then reduces the HEALPix resolution by one level (nside -> nside/2).
   - The Transformer runs on the token grid with embedding dimension C_token.
   - The decoder upsamples *one level at a time*: after each ``Up()`` it concatenates the upsampled features (C_{i+1}) with the corresponding skip (C_i) and applies a spherical convolution that fuses (C_{i+1} + C_i) -> C_i.
   - The final head maps C_fine -> out_channels at the finest grid.
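
   The constraints on ``level_dims`` can be checked before the model is built. The sketch below is a hypothetical helper (``check_level_dims`` is not part of foscat) that mirrors the requirements documented for this class and derives ``token_down``, the token-grid nside and the embedding dimension. Two of its checks are assumptions not stated in this documentation: that each ``Down()`` halves nside, so ``in_nside`` must be divisible by ``2**token_down``, and that attention is built on ``torch.nn.MultiheadAttention``, which requires ``embed_dim`` to be divisible by ``num_heads``.

   .. code-block:: python

      from typing import Dict, List


      def check_level_dims(in_nside: int, level_dims: List[int], G: int,
                           out_channels: int, num_heads: int) -> Dict[str, int]:
          """Hypothetical validator mirroring the documented HealpixViT requirements."""
          if len(level_dims) < 1:
              raise ValueError("level_dims must contain at least the fine-grid width")
          # Documented: len(level_dims) == token_down + 1.
          token_down = len(level_dims) - 1

          # Documented: every level width and out_channels must be divisible by G.
          for i, c in enumerate(level_dims):
              if c % G != 0:
                  raise ValueError(f"level_dims[{i}]={c} is not divisible by G={G}")
          if out_channels % G != 0:
              raise ValueError(f"out_channels={out_channels} is not divisible by G={G}")

          # Assumption: each Down() halves nside, so in_nside must survive
          # token_down successive halvings.
          if in_nside % (2 ** token_down) != 0:
              raise ValueError(f"in_nside={in_nside} cannot be halved {token_down} times")

          embed_dim = level_dims[-1]
          # Assumption: torch.nn.MultiheadAttention requires embed_dim % num_heads == 0.
          if embed_dim % num_heads != 0:
              raise ValueError(f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}")

          return {
              "token_down": token_down,
              "token_nside": in_nside // (2 ** token_down),
              "embed_dim": embed_dim,
          }

   For example, ``check_level_dims(64, [128, 192, 256], G=1, out_channels=1, num_heads=8)`` returns ``token_down=2``, ``token_nside=16`` and ``embed_dim=256``, while ``G=3`` raises because 128 is not divisible by 3.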
   Shapes (dense tasks)
   --------------------
   Here ``L = token_down = len(level_dims) - 1``::

      Input : (B, Cin, Nfine)
      → patch-embed (Cin -> C_fine) at the finest grid
      → for i in [0..L-1]:
            EncConv(C_i -> C_{i+1}) → Down()   (store skip_i = C_i at grid i)
      → tokens at grid L with dim C_token
      → Transformer on tokens (C_token)
      → for i in [L-1..0]:
            Up() to grid i → concat(skip_i, up) [C_i + C_{i+1}] → DecConv → C_i
      → Head: C_fine -> out_channels at the finest grid

   Requirements
   ------------
   - ``len(level_dims)`` must equal ``token_down + 1``, where:

     - ``level_dims[0]`` is the channel width at the finest grid after patch embedding;
     - ``level_dims[-1]`` is the Transformer embedding dimension.

   - Each value in ``level_dims`` must be divisible by ``G`` (the number of gauges).
   - ``out_channels`` must be divisible by ``G``.

   Parameters (main)
   -----------------
   in_nside : int
      Input HEALPix nside (NESTED ordering).
   n_chan_in : int
      Number of input channels at the finest grid (Cin).
   level_dims : List[int]
      Channel width per level, from the finest grid to the token grid.
   depth : int
      Number of Transformer encoder layers.
   num_heads : int
      Number of self-attention heads.
   cell_ids : numpy.ndarray
      NESTED pixel indices at the finest level (Nfine = 12*nside^2 for full-sphere coverage).
   task : {"regression", "segmentation", "global"}
      Task type.
   out_channels : int
      Number of output channels for dense tasks.
   KERNELSZ : int
      Spherical kernel size for the Foscat convolutions.
   gauge_type : {"cosmo", "phi"}
      Gauge type.
   G : int
      Number of gauges.


   .. py:attribute:: in_nside


   .. py:attribute:: n_chan_in


   .. py:attribute:: level_dims


   .. py:attribute:: depth


   .. py:attribute:: num_heads


   .. py:attribute:: task
      :value: 'regression'


   .. py:attribute:: out_channels
      :value: 1


   .. py:attribute:: mlp_ratio


   .. py:attribute:: KERNELSZ
      :value: 3


   .. py:attribute:: gauge_type
      :value: 'cosmo'


   .. py:attribute:: G
      :value: 1


   .. py:attribute:: prefer_foscat_gpu
      :value: True


   .. py:attribute:: cls_token_enabled
      :value: False


   .. py:attribute:: pos_embed_type
      :value: 'learned'


   .. py:attribute:: head_type
      :value: 'mean'


   .. py:attribute:: dropout


   .. py:attribute:: dtype
      :value: 'float32'


   .. py:attribute:: token_down


   .. py:attribute:: embed_dim


   .. py:attribute:: cell_ids_fine


   .. py:attribute:: f


   .. py:attribute:: hconv_levels
      :type: List[foscat.SphericalStencil.SphericalStencil]
      :value: []


   .. py:attribute:: level_cell_ids
      :type: List[numpy.ndarray]


   .. py:attribute:: token_nside


   .. py:attribute:: token_cell_ids


   .. py:attribute:: hconv_token


   .. py:attribute:: hconv_head


   .. py:attribute:: patch_w


   .. py:attribute:: patch_bn


   .. py:attribute:: enc_w
      :type: torch.nn.ParameterList


   .. py:attribute:: enc_bn
      :type: torch.nn.ModuleList


   .. py:attribute:: n_tokens


   .. py:attribute:: encoder


   .. py:attribute:: token_proj


   .. py:attribute:: dec_w
      :type: torch.nn.ParameterList


   .. py:attribute:: dec_bn
      :type: torch.nn.ModuleList


   .. py:attribute:: runtime_device


   .. py:method:: forward(x: torch.Tensor, runtime_ids: Optional[numpy.ndarray] = None) -> torch.Tensor

      x : torch.Tensor
         Input of shape (B, Cin, Nfine), NESTED ordering.
      runtime_ids : numpy.ndarray, optional
         Finest-level NESTED ids to decode onto (defaults to the training ids).


   .. py:method:: predict(x: Union[torch.Tensor, numpy.ndarray], batch_size: int = 8) -> torch.Tensor


.. py:data:: in_nside
   :value: 4
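
To make the dense-task shape walkthrough of ``HealpixViT`` concrete, the sketch below traces channel-first ``(B, C, N)`` shapes through a forward pass. It is illustrative only: ``nested_parents`` and ``trace_shapes`` are hypothetical helpers, not foscat functions. It assumes that each coarser grid consists of the NESTED parents of the finer grid (in the NESTED scheme the four children of pixel ``p`` are ``4*p`` to ``4*p + 3``, so the parent of ``c`` is ``c // 4``) and, for simplicity, keeps the channel-first layout through the Transformer even though attention may transpose internally.

.. code-block:: python

   from typing import List, Tuple

   import numpy as np


   def nested_parents(ids: np.ndarray) -> np.ndarray:
       # NESTED scheme: children of parent p are 4*p .. 4*p + 3, so parent = child // 4.
       return np.unique(np.asarray(ids, dtype=np.int64) // 4)


   def trace_shapes(B: int, n_chan_in: int, in_nside: int, level_dims: List[int],
                    out_channels: int,
                    cell_ids: np.ndarray) -> List[Tuple[str, Tuple[int, int, int]]]:
       """Hypothetical helper: (B, C, N) shapes of a dense HealpixViT forward pass."""
       L = len(level_dims) - 1  # token_down
       ids = [np.unique(np.asarray(cell_ids, dtype=np.int64))]
       for _ in range(L):  # assumed ids at grids 1..L: parents of the finer grid
           ids.append(nested_parents(ids[-1]))
       n = [len(level) for level in ids]
       nside = [in_nside // 2 ** i for i in range(L + 1)]

       trace = [("input", (B, n_chan_in, n[0])),
                ("patch-embed", (B, level_dims[0], n[0]))]
       for i in range(L):
           # skip_i keeps C_i at grid i; EncConv maps C_i -> C_{i+1}; Down() coarsens.
           trace.append((f"skip_{i} @ nside {nside[i]}", (B, level_dims[i], n[i])))
           trace.append((f"enc_conv_{i}", (B, level_dims[i + 1], n[i])))
           trace.append((f"down -> nside {nside[i + 1]}", (B, level_dims[i + 1], n[i + 1])))
       trace.append(("transformer", (B, level_dims[L], n[L])))
       for i in reversed(range(L)):
           # Up() to grid i keeps C_{i+1}; concat with skip_i; DecConv fuses back to C_i.
           trace.append((f"up -> nside {nside[i]}", (B, level_dims[i + 1], n[i])))
           trace.append((f"concat_{i}", (B, level_dims[i] + level_dims[i + 1], n[i])))
           trace.append((f"dec_conv_{i}", (B, level_dims[i], n[i])))
       trace.append(("head", (B, out_channels, n[0])))
       return trace

With ``in_nside = 4`` (the value of the module-level ``in_nside``), ``level_dims = [8, 16, 32]`` and full-sky ``cell_ids = np.arange(192)``, the grids hold 192, 48 and 12 pixels; passing only the 16 pixels of base pixel 0 (``np.arange(16)``) shrinks them to 16, 4 and 1.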