# Source code for foscat.healpix_vit_torch

# healpix_vit_varlevels.py
# HEALPix ViT with level-wise (variable) channel widths and U-Net-style spherical decoder
from __future__ import annotations
from typing import List, Optional, Literal, Tuple, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import foscat.scat_cov as sc
import foscat.SphericalStencil as ho


class HealpixViT(nn.Module):
    """
    HEALPix Vision Transformer (Foscat-based) with per-level channel widths
    and a U-Net-like spherical decoder.

    Key idea
    --------
    - The encoder uses a list of channel widths
      ``level_dims = [C_fine, C_l1, ..., C_token]`` that evolve with depth
      (e.g. 128 -> 192 -> 256).
    - At each encoder level (before a ``Down()``) a spherical convolution
      maps C_i -> C_{i+1}; ``Down()`` then reduces the HEALPix resolution by
      one level.
    - The Transformer runs on the token grid with embedding dim C_token.
    - The decoder upsamples one level at a time; after each ``Up()`` it
      concatenates the upsampled features (C_{i+1}) with the matching skip
      (C_i) and fuses (C_{i+1} + C_i) -> C_i with a spherical convolution.
    - The final head maps C_fine -> out_channels on the finest grid.

    Shapes (dense tasks)
    --------------------
    Input : (B, Cin, Nfine)
      -> patch embedding (Cin -> C_fine) on the finest grid
      -> for i in [0..L-1]: store skip_i (C_i), EncConv(C_i -> C_{i+1}), Down()
      -> tokens on grid L with dim C_token
      -> Transformer on tokens
      -> for i in [L-1..0]: Up() to grid i, concat(up, skip_i), DecConv -> C_i
      -> head: C_fine -> out_channels on the finest grid

    Requirements
    ------------
    - ``len(level_dims) == token_down + 1``; ``level_dims[0]`` is the width
      after patch embedding and ``level_dims[-1]`` the Transformer width.
    - Every value in ``level_dims`` must be divisible by ``G``.
    - ``level_dims[-1]`` must be divisible by ``num_heads``.
    - For dense tasks ``out_channels`` must be divisible by ``G``.

    Parameters
    ----------
    in_nside : input HEALPix nside (nested ordering).
    n_chan_in : number of input channels on the finest grid (Cin).
    level_dims : channel width per level, from fine to token.
    depth : number of Transformer encoder layers.
    num_heads : number of self-attention heads.
    cell_ids : finest-level nested pixel indices.
    task : "regression" | "segmentation" | "global".
    out_channels : number of output channels.
    mlp_ratio : feed-forward width as a multiple of the embedding dim.
    KERNELSZ : spherical kernel size for the Foscat convolutions.
    gauge_type : "cosmo" | "phi", forwarded to ``SphericalStencil``.
    G : number of gauges.
    prefer_foscat_gpu : try to run on CUDA when it is available; the model
        falls back to CPU if the probe convolution fails.
    cls_token : prepend a learned class token to the token sequence.
    pos_embed : "learned" adds a learned positional embedding, "none" skips it.
    head_type : "cls" uses the class token for the global task (only when
        ``cls_token`` is enabled); otherwise tokens are mean-pooled.
    dropout : dropout probability (Transformer layers and decoder).
    dtype : "float32" | "float64"; selects the NumPy dtype of the helper
        arrays only, the torch model is always kept in float32.

    Raises
    ------
    ValueError
        If ``level_dims`` is empty, a width is not divisible by ``G``, the
        token width is not divisible by ``num_heads``, ``cell_ids`` is None,
        or (dense tasks) ``out_channels`` is not divisible by ``G``.
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        level_dims: List[int],  # e.g., [128, 192, 256] (fine -> token)
        depth: int,
        num_heads: int,
        cell_ids: np.ndarray,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        mlp_ratio: float = 4.0,
        KERNELSZ: int = 3,
        gauge_type: Literal["cosmo", "phi"] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        cls_token: bool = False,
        pos_embed: Literal["learned", "none"] = "learned",
        head_type: Literal["mean", "cls"] = "mean",
        dropout: float = 0.0,
        dtype: Literal["float32", "float64"] = "float32",
    ) -> None:
        super().__init__()

        # ---- config ----
        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.level_dims = list(level_dims)
        self.depth = int(depth)
        self.num_heads = int(num_heads)
        self.task = task
        self.out_channels = int(out_channels)
        self.mlp_ratio = float(mlp_ratio)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.cls_token_enabled = bool(cls_token)
        self.pos_embed_type = pos_embed
        self.head_type = head_type
        self.dropout = float(dropout)
        self.dtype = dtype

        if len(self.level_dims) < 1:
            raise ValueError("level_dims must have at least one element (fine level).")
        self.token_down = len(self.level_dims) - 1
        self.embed_dim = int(self.level_dims[-1])  # Transformer dim

        for d in self.level_dims:
            if d % self.G != 0:
                raise ValueError(f"Each level dim must be divisible by G={self.G}; got {d}.")
        if self.embed_dim % self.num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads.")

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            self.np_dtype = np.float64
            self.torch_dtype = torch.float32  # keep model in fp32

        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided (nested ordering).")
        self.cell_ids_fine = np.asarray(cell_ids)

        # Default activation
        if self.task == "segmentation":
            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
        else:
            self.final_activation = "none"

        # Foscat wrapper
        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # ---- Build operators per level (fine -> ... -> token) and compute ids ----
        # One SphericalStencil per encoder level; a dummy single-channel map
        # is pushed through Down() only to obtain the coarser cell ids.
        self.hconv_levels = []
        self.level_cell_ids = [self.cell_ids_fine]
        current_nside = self.in_nside

        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )
        for _ in range(self.token_down):
            hc = ho.SphericalStencil(
                current_nside,
                self.KERNELSZ,
                n_gauges=self.G,
                gauge_type=self.gauge_type,
                cell_ids=self.level_cell_ids[-1],
                dtype=self.torch_dtype,
            )
            self.hconv_levels.append(hc)
            dummy, next_ids = hc.Down(
                dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True
            )
            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
            current_nside //= 2

        # With token_down == 0 the loop never runs and this is in_nside.
        self.token_nside = current_nside
        self.token_cell_ids = self.level_cell_ids[-1]

        # Token and fine-level operators (for convenience)
        self.hconv_token = ho.SphericalStencil(
            self.token_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.token_cell_ids,
            dtype=self.torch_dtype,
        )
        self.hconv_head = ho.SphericalStencil(
            self.in_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.cell_ids_fine,
            dtype=self.torch_dtype,
        )

        # NOTE: parameters are created in the same order as before so that a
        # fixed RNG seed still yields the same initial weights.

        # ---------------- Patch embedding (Cin -> C_fine) ----------------
        fine_dim = self.level_dims[0]
        self.patch_w = self._make_conv_weight(self.n_chan_in, fine_dim // self.G)
        self.patch_bn = nn.GroupNorm(
            num_groups=self._group_count(fine_dim), num_channels=fine_dim
        )

        # ---------------- Encoder convs per level (C_i -> C_{i+1}) ----------------
        self.enc_w = nn.ParameterList()
        self.enc_bn = nn.ModuleList()
        for i in range(self.token_down):
            c_in = self.level_dims[i]
            c_out = self.level_dims[i + 1]
            self.enc_w.append(self._make_conv_weight(c_in, c_out // self.G))
            self.enc_bn.append(
                nn.GroupNorm(num_groups=self._group_count(c_out), num_channels=c_out)
            )

        # ---------------- Transformer at token grid ----------------
        self.n_tokens = int(self.token_cell_ids.shape[0])
        if self.cls_token_enabled:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
            n_pe = self.n_tokens + 1
        else:
            self.cls_token = None
            n_pe = self.n_tokens

        if self.pos_embed_type == "learned":
            self.pos_embed = nn.Parameter(torch.zeros(1, n_pe, self.embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio),
            dropout=self.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=self.depth)

        # Projection at token grid (keep C_token)
        self.token_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # ---------------- Decoder convs per level ( (C_{i+1}+C_i) -> C_i ) ----------------
        # Weights are stored in decoding order: index 0 is the step from the
        # token level up to level L-1, the last index reaches the fine grid.
        self.dec_w = nn.ParameterList()
        self.dec_bn = nn.ModuleList()
        for i in range(self.token_down - 1, -1, -1):
            c_fuse = self.level_dims[i + 1] + self.level_dims[i]  # up + skip
            c_out = self.level_dims[i]
            self.dec_w.append(self._make_conv_weight(c_fuse, c_out // self.G))
            self.dec_bn.append(
                nn.GroupNorm(num_groups=self._group_count(c_out), num_channels=c_out)
            )

        # ---------------- Final head (C_fine -> out_channels) ----------------
        if self.task == "global":
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            self.C_fine = self.level_dims[0]
            if self.out_channels % self.G != 0:
                raise ValueError(
                    f"out_channels={self.out_channels} must be divisible by G={self.G}"
                )
            # Same (C_in, C_out // G, K*K) layout as every other conv weight in
            # this class.  The previous (C_fine // G, out_channels, K*K) shape
            # is identical for G == 1 but inconsistent with the rest for G > 1.
            self.head_w = self._make_conv_weight(self.C_fine, self.out_channels // self.G)
            self.head_bn = (
                nn.GroupNorm(
                    num_groups=self._group_count(self.out_channels),
                    num_channels=self.out_channels,
                )
                if self.task == "segmentation"
                else None
            )

        # ---------------- Device probe ----------------
        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)

    # ---------------- construction helpers ----------------
    @staticmethod
    def _group_count(num_channels: int, max_groups: int = 8) -> int:
        """Return the largest group count <= ``max_groups`` dividing ``num_channels``.

        ``nn.GroupNorm`` requires ``num_channels % num_groups == 0``.  For
        widths <= ``max_groups`` or multiples of ``max_groups`` this equals
        ``min(max_groups, num_channels)``; other widths (e.g. 12) get a valid
        divisor instead of making GroupNorm raise.
        """
        upper = min(max_groups, max(int(num_channels), 1))
        for groups in range(upper, 0, -1):
            if num_channels % groups == 0:
                return groups
        return 1

    def _make_conv_weight(self, c_in: int, c_out_g: int) -> nn.Parameter:
        """Create a Kaiming-initialised weight of shape (c_in, c_out_g, KERNELSZ**2)."""
        w = nn.Parameter(torch.empty(c_in, c_out_g, self.KERNELSZ * self.KERNELSZ))
        nn.init.kaiming_uniform_(w.view(c_in * c_out_g, -1), a=np.sqrt(5))
        return w

    # ---------------- device helpers ----------------
    def _move_hc(self, hc: "ho.SphericalStencil", device: torch.device) -> None:
        """Move every tensor attribute (or list/tuple of tensors) of ``hc`` to ``device``."""
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                # Deliberately best-effort: an attribute that cannot be moved
                # is left untouched; the probe convolution decides whether
                # the device is usable.
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Move the model and its stencils to ``preferred`` if a test convolution works.

        The CUDA path is only tried when ``preferred`` is a CUDA device and
        ``prefer_foscat_gpu`` is set.  On any failure the error is kept in
        ``self._gpu_probe_error`` and everything is moved to CPU.  The chosen
        device is stored in ``self._foscat_device`` and returned.
        """
        stencils = self.hconv_levels + [self.hconv_token, self.hconv_head]

        if preferred.type == "cuda" and self.prefer_foscat_gpu:
            try:
                super().to(preferred)
                for hc in stencils:
                    self._move_hc(hc, preferred)
                # dry run: one patch-embedding convolution on a zero map
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if self.hconv_levels else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.patch_w, cell_ids=self.cell_ids_fine)
                _ = y_try.sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception as e:
                self._gpu_probe_error = repr(e)

        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in stencils:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu

    def _to_numpy_ids(self, ids):
        """Return ids as a NumPy array on CPU (handles torch.Tensor on CUDA)."""
        if torch.is_tensor(ids):
            return ids.detach().cpu().numpy()
        return np.asarray(ids)

    # ---------------- helpers ----------------
    def _as_tensor_batch(self, x):
        """Unwrap a single-element list into a batched tensor; pass tensors through.

        A one-element list holding a 2-D tensor gets a leading batch axis.
        Lists with any other length raise ``ValueError``.
        """
        if isinstance(x, list):
            if len(x) == 1:
                t = x[0]
                return t.unsqueeze(0) if t.dim() == 2 else t
            raise ValueError("Variable-length list not supported here; pass a tensor.")
        return x

    def _conv(self, hc, x, w, cell_ids) -> torch.Tensor:
        """Run ``hc.Convol_torch`` and normalise its result to a batched tensor."""
        out = hc.Convol_torch(x, w, cell_ids=cell_ids)
        if not torch.is_tensor(out):
            out = torch.as_tensor(out, device=self.runtime_device)
        return self._as_tensor_batch(out)

    # ---------------- forward ----------------
    def forward(
        self,
        x: torch.Tensor,
        runtime_ids: Optional[np.ndarray] = None,
    ) -> torch.Tensor:
        """
        Run the network.

        Parameters
        ----------
        x : (B, Cin, Nfine) tensor, nested ordering.
        runtime_ids : optional fine-level cell ids to use instead of the ids
            given at construction (NumPy array or torch tensor).

        Returns
        -------
        torch.Tensor
            Dense tasks: the head output on the finest grid, with sigmoid
            (one channel) or channel-wise softmax applied for segmentation.
            Global task: ``global_head`` applied to the class token (when
            ``head_type == "cls"`` and the class token is enabled) or to the
            mean over tokens.

        Raises
        ------
        TypeError
            If ``x`` is not a ``torch.Tensor``.
        ValueError
            If ``x`` is not 3-D or its channel count differs from ``n_chan_in``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("x must be a torch.Tensor")
        if x.dim() != 3:
            raise ValueError("Input must be (B, Cin, Npix)")
        if x.shape[1] != self.n_chan_in:
            raise ValueError(f"Expected {self.n_chan_in} channels, got {x.shape[1]}")

        if runtime_ids is None:
            fine_ids_runtime = self.cell_ids_fine
        else:
            fine_ids_runtime = self._to_numpy_ids(runtime_ids)

        x = x.to(self.runtime_device)

        # -------- Patch embedding Cin -> C_fine --------
        # Uses the same fine-level ids as the encoder, decoder and head, so a
        # call with runtime_ids is consistent end to end.
        hc_fine0 = self.hconv_levels[0] if self.hconv_levels else self.hconv_head
        z = self._conv(hc_fine0, x, self.patch_w, fine_ids_runtime)  # (B, C_fine, Nfine)
        z = F.gelu(self.patch_bn(z))

        # -------- Encoder path: per level, EncConv(C_i->C_{i+1}) then Down() --------
        skips: List[torch.Tensor] = []
        ids_list: List[np.ndarray] = []
        l_data = z
        l_cell_ids = fine_ids_runtime
        current_nside = self.in_nside

        for i, hc in enumerate(self.hconv_levels):
            # save skip BEFORE going down (channels = C_i, grid = current level)
            skips.append(self._as_tensor_batch(l_data))
            ids_list.append(self._to_numpy_ids(l_cell_ids))

            # conv to next channels C_{i+1} at same grid
            l_data = self._conv(hc, l_data, self.enc_w[i], l_cell_ids)
            l_data = F.gelu(self.enc_bn[i](l_data))

            # Down one level
            l_data, l_cell_ids = hc.Down(
                l_data, cell_ids=l_cell_ids, nside=current_nside, max_poll=True
            )
            l_data = self._as_tensor_batch(l_data)
            current_nside //= 2

        # We are now at token grid with channels = C_token
        x_tok = l_data  # (B, C_token, Ntok)
        token_ids = l_cell_ids  # ids at token level
        assert x_tok.shape[1] == self.embed_dim, "Token channels mismatch with embed_dim."

        # -------- Transformer on tokens --------
        seq = x_tok.permute(0, 2, 1)  # (B, Ntok, E)
        if self.cls_token_enabled:
            cls = self.cls_token.expand(seq.size(0), -1, -1)
            seq = torch.cat([cls, seq], dim=1)
        if self.pos_embed is not None:
            seq = seq + self.pos_embed[:, : seq.shape[1], :]
        seq = self.encoder(seq)  # (B, Ntok(+1), E)

        # drop the class token for everything that works on the token grid
        tokens = seq[:, 1:, :] if self.cls_token_enabled else seq

        if self.task == "global":
            if self.head_type == "cls" and self.cls_token_enabled:
                pooled = seq[:, 0, :]
            else:
                pooled = tokens.mean(dim=1)
            # Use the registered (trainable) head; creating a new nn.Linear
            # here would give different random weights on every call.
            return self.global_head(pooled)

        tok_feat = self.token_proj(tokens).permute(0, 2, 1)  # (B, C_token, Ntok)

        # -------- Build runtime id chain (fine -> ... -> token) --------
        ids_chain = [np.asarray(fine_ids_runtime)]
        nside_tmp = self.in_nside
        _dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype)
        )
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(
                _dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True
            )
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        tok_ids_np = self._to_numpy_ids(token_ids)
        assert tok_feat.shape[-1] == tok_ids_np.shape[0], "Token count mismatch."
        assert np.array_equal(tok_ids_np, ids_chain[-1]), "Token ids mismatch with runtime chain."

        # -------- Decoder: Up step-by-step with fusion conv --------
        y = tok_feat  # (B, C_token, Ntok)
        # dec_idx walks self.dec_w / self.dec_bn, which are stored token -> fine.
        for dec_idx, level in enumerate(range(len(ids_chain) - 1, 0, -1)):
            coarse_ids = ids_chain[level]  # current y grid
            target_ids = ids_chain[level - 1]  # grid we upsample to
            source_ns = self.in_nside // (2 ** level)

            # operator of the target level; the finest level uses the head stencil
            op_fine = self.hconv_head if level == 1 else self.hconv_levels[level - 1]

            # Up one level
            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=target_ids, nside=source_ns)
            if not torch.is_tensor(y_up):
                y_up = torch.as_tensor(y_up, device=self.runtime_device)
            y_up = self._as_tensor_batch(y_up)

            # Skip stored by the encoder at the target level
            skip_i = self._as_tensor_batch(skips[level - 1]).to(y_up.device)
            assert np.array_equal(
                np.asarray(ids_list[level - 1]), np.asarray(target_ids)
            ), "Skip ids misaligned."

            # Concat and fuse: (C_up + C_skip) -> C_skip
            y_cat = torch.cat([y_up, skip_i], dim=1)
            y = self._conv(op_fine, y_cat, self.dec_w[dec_idx], target_ids)
            y = F.gelu(self.dec_bn[dec_idx](y))
            if self.dropout > 0:
                y = F.dropout(y, p=self.dropout, training=self.training)

        # y is now (B, C_fine, Nfine)
        # -------- Final head to out_channels --------
        y = self._conv(self.hconv_head, y, self.head_w, fine_ids_runtime)

        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)

        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y

    @torch.no_grad()
    def predict(self, x: Union[torch.Tensor, np.ndarray], batch_size: int = 8) -> torch.Tensor:
        """Run ``forward`` in eval mode over ``x`` in chunks of ``batch_size``.

        NumPy input is converted to a float32 tensor.  The model is switched
        to eval mode and left there.  Per-chunk outputs are concatenated
        along the batch axis.
        """
        self.eval()
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        outs = []
        for start in range(0, x.shape[0], batch_size):
            xb = x[start : start + batch_size].to(self.runtime_device)
            outs.append(self.forward(xb))
        return torch.cat(outs, dim=0)
# -------------------------- Smoke test --------------------------
if __name__ == "__main__":
    # nside=4 -> Npix=192; two Down() levels -> token_nside=1
    in_nside = 4
    npix = 12 * in_nside * in_nside
    cell_ids = np.arange(npix, dtype=np.int64)

    B, Cin = 2, 3
    x = torch.randn(B, Cin, npix)

    # Channel widths per level (fine -> token), divisible by G=1 here
    level_dims = [64, 96, 128]

    # The class defined in this module is HealpixViT; the previous name
    # HealpixViTVarLevels does not exist here and raised NameError.
    model = HealpixViT(
        in_nside=in_nside,
        n_chan_in=Cin,
        level_dims=level_dims,  # len=3 => token_down=2
        depth=2,
        num_heads=4,
        cell_ids=cell_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        cls_token=False,
        dropout=0.1,
    ).eval()

    with torch.no_grad():
        y = model(x)
    print("Output:", y.shape)  # (B, Cout, Nfine)