foscat.healpix_vit_torch

Attributes

Classes

HealpixViT

HEALPix Vision Transformer (Foscat-based) with variable channel widths per level

Module Contents

class foscat.healpix_vit_torch.HealpixViT(*, in_nside: int, n_chan_in: int, level_dims: List[int], depth: int, num_heads: int, cell_ids: numpy.ndarray, task: Literal['regression', 'segmentation', 'global'] = 'regression', out_channels: int = 1, mlp_ratio: float = 4.0, KERNELSZ: int = 3, gauge_type: Literal['cosmo', 'phi'] = 'cosmo', G: int = 1, prefer_foscat_gpu: bool = True, cls_token: bool = False, pos_embed: Literal['learned', 'none'] = 'learned', head_type: Literal['mean', 'cls'] = 'mean', dropout: float = 0.0, dtype: Literal['float32', 'float64'] = 'float32')

Bases: torch.nn.Module

HEALPix Vision Transformer (Foscat-based) with variable channel widths per level and a U-Net-like spherical decoder.

Key idea

  • The encoder uses a list of channel widths level_dims = [C_fine, C_1, …, C_token] that change with depth (e.g., 128 -> 192 -> 256).

  • At each encoder level (before a Down()), we apply a spherical convolution that maps C_i -> C_{i+1}. Down() then reduces the HEALPix resolution by one level.

  • The Transformer operates on the token grid with embedding dimension C_token.

  • The decoder upsamples one level at a time; after each Up() it concatenates the upsampled features (C_{i+1}) with the corresponding skip (C_i) and applies a spherical convolution to fuse (C_{i+1} + C_i) -> C_i.

  • Final head maps C_fine -> out_channels at the finest grid.
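The channel bookkeeping described above can be sketched in plain Python. This is a minimal illustration, not Foscat code: the helper channel_plan is hypothetical and only lists the (in, out) widths implied by level_dims; it does not build any convolution.

```python
from typing import Dict, List, Tuple


def channel_plan(level_dims: List[int], n_chan_in: int,
                 out_channels: int) -> Dict[str, List[Tuple[int, int]]]:
    """Return (in_channels, out_channels) pairs for each convolution stage."""
    L = len(level_dims) - 1  # number of Down()/Up() steps (token_down)
    return {
        # patch embedding: Cin -> C_fine at the finest grid
        "patch": [(n_chan_in, level_dims[0])],
        # encoder conv at level i: C_i -> C_{i+1}, applied before Down()
        "encoder": [(level_dims[i], level_dims[i + 1]) for i in range(L)],
        # decoder fuse conv at level i: (C_{i+1} + C_i) -> C_i, listed coarse to fine
        "decoder": [(level_dims[i + 1] + level_dims[i], level_dims[i])
                    for i in reversed(range(L))],
        # final head: C_fine -> out_channels
        "head": [(level_dims[0], out_channels)],
    }


plan = channel_plan([128, 192, 256], n_chan_in=2, out_channels=1)
for name, convs in plan.items():
    print(name, convs)
```

With level_dims = [128, 192, 256], the decoder fuses 256 + 192 -> 192 and then 192 + 128 -> 128, which is why every skip width must match its level entry.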

Shapes (dense tasks)

Input (B, Cin, Nfine), with L = token_down = len(level_dims) - 1:

  → patch embedding (Cin -> C_fine) at the finest grid
  → for i in [0..L-1]: EncConv(C_i -> C_{i+1}) → Down(); skip_i (C_i channels at grid i) is kept for the decoder
  → tokens at grid L with dimension C_token
  → Transformer on the tokens (C_token)
  → for i in [L-1..0]: Up() to grid i → concat(skip_i, up) [C_i + C_{i+1}] → DecConv → C_i
  → head: C_fine -> out_channels at the finest grid
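A shape trace makes the pipeline above concrete. This sketch assumes full-sky coverage (N = 12 * nside^2 at every level) and that each Down() halves nside, i.e. one HEALPix level; it uses the channel-first (B, C, N) layout of this page and a hypothetical helper, shape_trace, rather than any Foscat call.

```python
from typing import List, Tuple


def npix(nside: int) -> int:
    """Number of HEALPix pixels on the full sphere at a given nside."""
    return 12 * nside * nside


def shape_trace(batch: int, n_chan_in: int, in_nside: int, level_dims: List[int],
                out_channels: int) -> List[Tuple[str, Tuple[int, int, int]]]:
    """Return (stage, (B, C, N)) pairs for a full-sky dense forward pass."""
    L = len(level_dims) - 1  # token_down
    # Assumption: each Down() halves nside, so in_nside must be divisible by 2**L.
    if in_nside % 2 ** L != 0:
        raise ValueError("in_nside must be divisible by 2**token_down")
    trace = [("input", (batch, n_chan_in, npix(in_nside))),
             ("patch_embed", (batch, level_dims[0], npix(in_nside)))]
    for i in range(L):
        # EncConv C_i -> C_{i+1}, then Down() from grid i to grid i+1
        trace.append((f"enc{i}+down",
                      (batch, level_dims[i + 1], npix(in_nside // 2 ** (i + 1)))))
    trace.append(("transformer", (batch, level_dims[-1], npix(in_nside // 2 ** L))))
    for i in reversed(range(L)):
        n_i = npix(in_nside // 2 ** i)
        # Up() to grid i, then concat(skip_i, up): C_i + C_{i+1} channels
        trace.append((f"up+concat{i}", (batch, level_dims[i] + level_dims[i + 1], n_i)))
        # DecConv fuses back to C_i
        trace.append((f"dec{i}", (batch, level_dims[i], n_i)))
    trace.append(("head", (batch, out_channels, npix(in_nside))))
    return trace


for stage, shape in shape_trace(batch=4, n_chan_in=2, in_nside=16,
                                level_dims=[128, 192, 256], out_channels=1):
    print(f"{stage:12s} {shape}")
```

For in_nside = 16 and two Down() steps, the Transformer sees 192 tokens (nside 4) instead of 3072 fine pixels.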

Requirements

  • level_dims fixes the number of resolution levels:

    len(level_dims) = token_down + 1 (so token_down = len(level_dims) - 1)
    level_dims[0] = channels at the finest grid after patch embedding
    level_dims[-1] = Transformer embedding dimension

  • Each value in level_dims must be divisible by G (number of gauges).

  • out_channels must be divisible by G.
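The requirements above can be checked before building the model. A minimal sketch with a hypothetical helper; the divisibility checks mirror the documented rules, while the nside check assumes (it is not stated above) that each Down() halves nside.

```python
from typing import List, Tuple


def check_requirements(level_dims: List[int], out_channels: int, G: int,
                       in_nside: int) -> Tuple[int, int]:
    """Check the documented constraints and return (token_down, token_nside)."""
    if not level_dims:
        raise ValueError("level_dims must not be empty")
    token_down = len(level_dims) - 1  # len(level_dims) = token_down + 1
    bad = [c for c in level_dims if c % G != 0]
    if bad:
        raise ValueError(f"level_dims entries {bad} are not divisible by G={G}")
    if out_channels % G != 0:
        raise ValueError(f"out_channels={out_channels} is not divisible by G={G}")
    # Assumption (not stated in the requirements): each Down() halves nside.
    if in_nside % 2 ** token_down != 0:
        raise ValueError(f"in_nside={in_nside} cannot be halved {token_down} times")
    return token_down, in_nside // 2 ** token_down


print(check_requirements([128, 192, 256], out_channels=2, G=2, in_nside=16))
```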

Parameters (main)

in_nside : input HEALPix nside (nested ordering)
n_chan_in : number of input channels at the finest grid (Cin)
level_dims : list of ints, channel width per level from the finest grid to the token grid
depth : number of Transformer encoder layers
num_heads : number of self-attention heads
cell_ids : finest-level nested pixel indices (Nfine = 12*in_nside^2)
task : "regression" | "segmentation" | "global"
out_channels : output channels for dense tasks
KERNELSZ : spherical kernel size for Foscat convolutions
gauge_type : "cosmo" | "phi"
G : number of gauges
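A possible configuration, using only keyword names from the signature above; the values are illustrative, not recommended defaults. It builds full-sky nested cell_ids for in_nside = 16 and three channel widths, so token_down = 2. The num_heads choice assumes standard multi-head attention, which needs the embedding dimension to be divisible by the number of heads.

```python
import numpy as np

in_nside = 16
# Full-sky coverage in nested ordering: Nfine = 12 * in_nside**2 pixels.
cell_ids = np.arange(12 * in_nside**2, dtype=np.int64)

# Keyword names follow the documented HealpixViT signature; values are illustrative.
vit_kwargs = dict(
    in_nside=in_nside,
    n_chan_in=2,
    level_dims=[128, 192, 256],  # C_fine -> C_1 -> C_token, so token_down = 2
    depth=4,
    num_heads=8,                 # 256 / 8 = 32 dims per head (assumed requirement)
    cell_ids=cell_ids,
    task="segmentation",
    out_channels=3,              # divisible by G=1
    KERNELSZ=3,
    gauge_type="cosmo",
    G=1,
)

print(len(vit_kwargs["cell_ids"]), vit_kwargs["level_dims"])
```

With foscat installed, the model would then be built with HealpixViT(**vit_kwargs) from foscat.healpix_vit_torch.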

in_nside
n_chan_in
level_dims
depth
num_heads
task = 'regression'
out_channels = 1
mlp_ratio
KERNELSZ = 3
gauge_type = 'cosmo'
G = 1
prefer_foscat_gpu = True
cls_token_enabled = False
pos_embed_type = 'learned'
head_type = 'mean'
dropout
dtype = 'float32'
token_down
embed_dim
cell_ids_fine
f
hconv_levels: List[foscat.SphericalStencil.SphericalStencil] = []
level_cell_ids: List[numpy.ndarray]
token_nside
token_cell_ids
hconv_token
hconv_head
patch_w
patch_bn
enc_w: torch.nn.ParameterList
enc_bn: torch.nn.ModuleList
n_tokens
encoder
token_proj
dec_w: torch.nn.ParameterList
dec_bn: torch.nn.ModuleList
runtime_device
forward(x: torch.Tensor, runtime_ids: numpy.ndarray | None = None) → torch.Tensor

x : input tensor of shape (B, Cin, Nfine), nested ordering
runtime_ids : optional fine-level nested ids to decode onto (defaults to the training cell_ids)
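If runtime_ids selects a subset of fine pixels, the coarser grids it touches follow from the HEALPix nested hierarchy, where pixel p at nside has parent p // 4 at nside / 2. The sketch below derives the per-level id sets from a set of fine ids; the helper name is hypothetical, and whether foscat computes its level_cell_ids exactly this way is an assumption.

```python
from typing import List, Sequence


def coarsen_nested_ids(fine_ids: Sequence[int], levels: int) -> List[List[int]]:
    """Return sorted unique nested ids per level, from finest to coarsest.

    In HEALPix NESTED ordering, pixel p at nside has parent p // 4 at nside / 2.
    """
    out = [sorted(set(fine_ids))]
    for _ in range(levels):
        out.append(sorted({p // 4 for p in out[-1]}))
    return out


print(coarsen_nested_ids([17, 0, 1, 2, 3, 4, 5], levels=2))
```

Because parents are shared, a partial fine-level footprint shrinks by up to a factor of 4 per level, which is what keeps the token grid small.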

predict(x: torch.Tensor | numpy.ndarray, batch_size: int = 8) → torch.Tensor
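The batch_size argument suggests that predict runs inference over the batch axis in chunks; that behaviour is not documented here, so the following is only a framework-agnostic sketch of the pattern with a hypothetical helper. A torch version would typically wrap each call in torch.no_grad() and join the results with torch.cat.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def predict_in_batches(model: Callable[[Sequence[T]], List[R]],
                       samples: Sequence[T], batch_size: int = 8) -> List[R]:
    """Run model on consecutive slices of at most batch_size samples and concatenate."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    outputs: List[R] = []
    for start in range(0, len(samples), batch_size):
        # Each call sees at most batch_size samples, bounding peak memory.
        outputs.extend(model(samples[start:start + batch_size]))
    return outputs


calls: List[int] = []


def fake_model(batch: Sequence[int]) -> List[int]:
    calls.append(len(batch))  # record the size of each mini-batch
    return [2 * s for s in batch]


result = predict_in_batches(fake_model, list(range(20)), batch_size=8)
print(calls)
```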
foscat.healpix_vit_torch.in_nside = 4