foscat.healpix_vit_torch
Attributes
in_nside
Classes
HealpixViT | HEALPix Vision Transformer (Foscat-based) with variable channel widths per level and a U-Net-like spherical decoder.
Module Contents
- class foscat.healpix_vit_torch.HealpixViT(*, in_nside: int, n_chan_in: int, level_dims: List[int], depth: int, num_heads: int, cell_ids: numpy.ndarray, task: Literal['regression', 'segmentation', 'global'] = 'regression', out_channels: int = 1, mlp_ratio: float = 4.0, KERNELSZ: int = 3, gauge_type: Literal['cosmo', 'phi'] = 'cosmo', G: int = 1, prefer_foscat_gpu: bool = True, cls_token: bool = False, pos_embed: Literal['learned', 'none'] = 'learned', head_type: Literal['mean', 'cls'] = 'mean', dropout: float = 0.0, dtype: Literal['float32', 'float64'] = 'float32')
Bases: torch.nn.Module
HEALPix Vision Transformer (Foscat-based) with variable channel widths per level and a U-Net-like spherical decoder.
Key idea
- The encoder uses a list of channel dimensions level_dims = [C_fine, C_l1, …, C_token] that change with depth (e.g., 128 -> 192 -> 256).
- At each encoder level, before Down(), a spherical convolution maps C_i -> C_{i+1}. Down() then reduces the HEALPix resolution by one level (nside -> nside/2).
- The Transformer runs on the token grid with embedding dimension C_token.
- The decoder upsamples one level at a time. After each Up(), it concatenates the upsampled features (C_{i+1}) with the corresponding skip (C_i) and applies a spherical convolution that fuses (C_{i+1} + C_i) -> C_i.
- The final head maps C_fine -> out_channels at the finest grid.
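The channel bookkeeping described above can be spelled out with a small helper. This is a hypothetical sketch (channel_plan is not a foscat function): it only lists the C_in -> C_out pair of each encoder and decoder convolution implied by level_dims, using the 128 -> 192 -> 256 example.

```python
from typing import List, Tuple


def channel_plan(level_dims: List[int]) -> Tuple[list, list]:
    """Return (encoder, decoder) conv shapes implied by level_dims.

    Hypothetical helper, not part of foscat: it only spells out the
    C_in -> C_out pairs described in the 'Key idea' list.
    """
    L = len(level_dims) - 1  # number of Down() steps (token_down)
    # Encoder: at level i, a spherical conv maps C_i -> C_{i+1}, then Down().
    encoder = [(level_dims[i], level_dims[i + 1]) for i in range(L)]
    # Decoder: after Up() to level i, concat(skip_i, up) has C_i + C_{i+1}
    # channels, and a spherical conv fuses it back to C_i.
    decoder = [(level_dims[i] + level_dims[i + 1], level_dims[i])
               for i in reversed(range(L))]
    return encoder, decoder


enc, dec = channel_plan([128, 192, 256])
print(enc)  # [(128, 192), (192, 256)]
print(dec)  # [(448, 192), (320, 128)]
```

Note that the decoder's first fusion is the widest, because it combines the token features (C_token) with the deepest skip.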
Shapes (dense tasks)
- Input (B, Cin, Nfine)
  → patch-embed (Cin -> C_fine) at the finest grid
  → for i in [0..L-1]: EncConv(C_i -> C_{i+1}) → Down() (store skip_i = C_i at grid i)
  → tokens at grid L with dimension C_token
  → Transformer on the tokens (C_token)
  → for i in [L-1..0]: Up() to grid i → concat(skip_i, up) [C_i + C_{i+1}] → DecConv → C_i
  → head: C_fine -> out_channels at the finest grid
Here L = len(level_dims) - 1 = token_down is the number of Down() steps.
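The pipeline above can be traced with plain NumPy to check the shapes at every stage. This is a shape-only sketch under loudly stated assumptions: the spherical convolutions are replaced by random 1x1 channel mixings, Down()/Up() by nested-ordering average pooling and parent-to-children copying, and the Transformer is skipped. None of these stand-ins are foscat's actual operators, and the helper names (mix, down, up, trace) are hypothetical. The reshape trick in down() only works for a full sphere in sorted nested order.

```python
import numpy as np

rng = np.random.default_rng(0)


def mix(x: np.ndarray, c_out: int) -> np.ndarray:
    """Stand-in for a spherical conv: random 1x1 channel mixing, (B, C, N) -> (B, c_out, N)."""
    w = rng.standard_normal((c_out, x.shape[1])) / np.sqrt(x.shape[1])
    return np.einsum("oc,bcn->bon", w, x)


def down(x: np.ndarray) -> np.ndarray:
    """Average each group of 4 nested children (4p..4p+3) into parent p."""
    b, c, n = x.shape
    return x.reshape(b, c, n // 4, 4).mean(axis=-1)


def up(x: np.ndarray) -> np.ndarray:
    """Copy each parent value onto its 4 nested children."""
    return np.repeat(x, 4, axis=-1)


def trace(x: np.ndarray, level_dims: list, out_channels: int):
    """Return (token shape, output shape) for a full-sky nested input of shape (B, Cin, Nfine)."""
    n_down = len(level_dims) - 1                 # L = token_down
    h = mix(x, level_dims[0])                    # patch-embed: Cin -> C_fine
    skips = []
    for i in range(n_down):
        skips.append(h)                          # skip_i: C_i channels at grid i
        h = down(mix(h, level_dims[i + 1]))      # EncConv C_i -> C_{i+1}, then Down()
    tokens = h                                   # (B, C_token, n_tokens); Transformer omitted
    for i in reversed(range(n_down)):
        h = np.concatenate([skips[i], up(h)], axis=1)  # concat -> C_i + C_{i+1}
        h = mix(h, level_dims[i])                      # DecConv -> C_i
    return tokens.shape, mix(h, out_channels).shape    # head: C_fine -> out_channels


nside = 4
x = rng.standard_normal((3, 2, 12 * nside**2))   # B=3, Cin=2, Nfine=192
print(trace(x, [16, 24, 32], out_channels=1))     # ((3, 32, 12), (3, 1, 192))
```

With nside = 4 and two Down() steps, the token grid has nside 1, i.e. 12 tokens, while the output returns to the 192 fine pixels.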
Requirements
- len(level_dims) must equal token_down + 1, where:
  - level_dims[0] = channels at the finest grid after patch embedding
  - level_dims[-1] = Transformer embedding dimension
- Each value in level_dims must be divisible by G (the number of gauges).
- out_channels must be divisible by G.
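A pre-flight check of these requirements could look like the sketch below. The function check_healpix_vit_config is hypothetical, not part of foscat. The first two checks restate the documented rules; the last two are assumptions not stated in this page: that attention behaves like torch.nn.MultiheadAttention (whose embedding dimension must be divisible by num_heads), and that each Down() halves nside.

```python
from typing import List


def check_healpix_vit_config(level_dims: List[int], out_channels: int, G: int,
                             num_heads: int, in_nside: int) -> int:
    """Hypothetical pre-flight check for the documented constraints; returns token_down."""
    token_down = len(level_dims) - 1              # len(level_dims) == token_down + 1
    bad = [c for c in level_dims if c % G != 0]
    if bad:
        raise ValueError(f"level_dims entries {bad} are not divisible by G={G}")
    if out_channels % G != 0:
        raise ValueError(f"out_channels={out_channels} is not divisible by G={G}")
    # Assumption: attention in the style of torch.nn.MultiheadAttention,
    # which requires the embedding dimension to be divisible by num_heads.
    if level_dims[-1] % num_heads != 0:
        raise ValueError(f"C_token={level_dims[-1]} is not divisible by num_heads={num_heads}")
    # Assumption: each Down() halves nside, so the token grid needs nside >= 1.
    if in_nside < 2 ** token_down:
        raise ValueError(f"in_nside={in_nside} is too small for {token_down} Down() steps")
    return token_down


print(check_healpix_vit_config([128, 192, 256], out_channels=4, G=2, num_heads=8, in_nside=64))  # 2
```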
Parameters (main)
- in_nside : input HEALPix nside (nested ordering)
- n_chan_in : input channels at the finest grid (Cin)
- level_dims : list of ints, channel width per level, from fine to token
- depth : number of Transformer encoder layers
- num_heads : number of self-attention heads
- cell_ids : finest-level nested indices (Nfine = 12*nside^2 for full-sky coverage)
- task : 'regression' | 'segmentation' | 'global'
- out_channels : output channels for dense tasks
- KERNELSZ : spherical kernel size for Foscat convolutions
- gauge_type : 'cosmo' | 'phi'
- G : number of gauges
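The class keeps one id array per level (level_cell_ids). How foscat derives them is not documented here; the sketch below assumes the standard HEALPix NESTED rule that the parent of pixel p one level coarser is p // 4. The helper coarsen_nested_ids is hypothetical and works for both full-sky and partial-sky cell_ids.

```python
import numpy as np


def coarsen_nested_ids(cell_ids: np.ndarray, levels: int) -> list:
    """Hypothetical: ids at each grid from fine to coarse, using the NESTED rule parent = child // 4."""
    ids = np.unique(np.asarray(cell_ids, dtype=np.int64))
    out = [ids]
    for _ in range(levels):
        ids = np.unique(ids // 4)   # each coarse pixel appears once
        out.append(ids)
    return out


nside = 4
full = np.arange(12 * nside**2)       # full sphere: Nfine = 192
print([len(ids) for ids in coarsen_nested_ids(full, 2)])     # [192, 48, 12]

patch = np.arange(16, 32)             # nside=4 pixels under nside=1 pixel 1
print(coarsen_nested_ids(patch, 2)[-1])                       # [1]
```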
- in_nside
- n_chan_in
- level_dims
- depth
- num_heads
- task = 'regression'
- out_channels = 1
- mlp_ratio
- KERNELSZ = 3
- gauge_type = 'cosmo'
- G = 1
- prefer_foscat_gpu = True
- cls_token_enabled = False
- pos_embed_type = 'learned'
- head_type = 'mean'
- dropout
- dtype = 'float32'
- token_down
- embed_dim
- cell_ids_fine
- f
- hconv_levels: List[foscat.SphericalStencil.SphericalStencil] = []
- level_cell_ids: List[numpy.ndarray]
- token_nside
- token_cell_ids
- hconv_token
- hconv_head
- patch_w
- patch_bn
- enc_w: torch.nn.ParameterList
- enc_bn: torch.nn.ModuleList
- n_tokens
- encoder
- token_proj
- dec_w: torch.nn.ParameterList
- dec_bn: torch.nn.ModuleList
- runtime_device
- forward(x: torch.Tensor, runtime_ids: numpy.ndarray | None = None) → torch.Tensor
x: (B, Cin, Nfine), nested ordering
runtime_ids: optional fine-level ids to decode onto (defaults to the training ids)
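runtime_ids are fine-level nested indices. One convenient way to build such a set for a regional patch relies on a standard property of HEALPix NESTED ordering: the descendants of pixel p, k levels finer, form the contiguous range [p * 4**k, (p + 1) * 4**k). The helper below is hypothetical, and whether forward accepts an arbitrary subset built this way is an assumption to verify against the source.

```python
import numpy as np


def nested_descendants(parent: int, levels_finer: int) -> np.ndarray:
    """Fine-level NESTED ids covering one coarse pixel.

    In NESTED ordering the descendants of pixel p, k levels finer,
    are exactly the contiguous range [p * 4**k, (p + 1) * 4**k).
    """
    step = 4 ** levels_finer
    return np.arange(parent * step, (parent + 1) * step, dtype=np.int64)


ids = nested_descendants(5, 2)          # coarse nside=1 pixel 5, seen at nside=4
print(ids.min(), ids.max(), ids.size)   # 80 95 16
```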
- predict(x: torch.Tensor | numpy.ndarray, batch_size: int = 8) → torch.Tensor
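This page does not say how predict uses batch_size. A common pattern, shown here as an assumption rather than foscat's code, is to run the model on chunks of at most batch_size samples along the batch axis and concatenate the results. The sketch uses NumPy and a toy stand-in for the network; batched_apply is a hypothetical name.

```python
from typing import Callable

import numpy as np


def batched_apply(fn: Callable[[np.ndarray], np.ndarray], x: np.ndarray,
                  batch_size: int = 8) -> np.ndarray:
    """Apply fn to x in chunks along axis 0 and stitch the results back together."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    outs = [fn(x[start:start + batch_size]) for start in range(0, x.shape[0], batch_size)]
    return np.concatenate(outs, axis=0)


# Toy stand-in for the network: (B, Cin, N) -> (B, 1, N) channel mean.
x = np.random.default_rng(0).standard_normal((20, 3, 192))
y = batched_apply(lambda b: b.mean(axis=1, keepdims=True), x, batch_size=8)
print(y.shape)  # (20, 1, 192)
```

Chunking bounds peak memory to one batch_size slice at a time while leaving the final output identical to a single full-batch call, provided the model has no batch-dependent behavior at inference time.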
- foscat.healpix_vit_torch.in_nside = 4