# Source code for foscat.SphereDownGeo


import torch
import torch.nn as nn
import numpy as np
import healpy as hp


[docs] class SphereDownGeo(nn.Module): """ Geometric HEALPix downsampling operator (NESTED indexing). This module reduces resolution by a factor 2: nside_out = nside_in // 2 Input conventions ----------------- - If in_cell_ids is None: x is expected to be full-sphere: [B, C, N_in] output is [B, C, K_out] with K_out = len(cell_ids_out) (or N_out if None). - If in_cell_ids is provided (fine pixels at nside_in, NESTED): x can be either: * compact: [B, C, K_in] where K_in = len(in_cell_ids), aligned with in_cell_ids order * full-sphere: [B, C, N_in] (also supported) output is [B, C, K_out] where cell_ids_out is derived as unique(in_cell_ids // 4), unless you explicitly pass cell_ids_out (then it will be intersected with the derived set). Modes ----- - mode="smooth": linear downsampling y = M @ x (M sparse) - mode="maxpool": non-linear max over available children (fast) """ def __init__( self, nside_in: int, mode: str = "smooth", radius_deg: float | None = None, sigma_deg: float | None = None, weight_norm: str = "l1", cell_ids_out: np.ndarray | list[int] | None = None, in_cell_ids: np.ndarray | list[int] | torch.Tensor | None = None, use_csr=True, device=None, dtype: torch.dtype = torch.float32, ): super().__init__() if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = device self.dtype = dtype self.nside_in = int(nside_in) assert (self.nside_in & (self.nside_in - 1)) == 0, "nside_in must be a power of 2." self.nside_out = self.nside_in // 2 assert self.nside_out >= 1, "nside_out must be >= 1." self.N_in = 12 * self.nside_in * self.nside_in self.N_out = 12 * self.nside_out * self.nside_out self.mode = str(mode).lower() assert self.mode in ("smooth", "maxpool"), "mode must be 'smooth' or 'maxpool'." self.weight_norm = str(weight_norm).lower() assert self.weight_norm in ("l1", "l2"), "weight_norm must be 'l1' or 'l2'." 
# ---- Handle reduced-domain inputs (fine pixels) ---- self.in_cell_ids = self._validate_in_cell_ids(in_cell_ids) self.has_in_subset = self.in_cell_ids is not None if self.has_in_subset: # derive parents derived_out = np.unique(self.in_cell_ids // 4).astype(np.int64) if cell_ids_out is None: self.cell_ids_out = derived_out else: req_out = self._validate_cell_ids_out(cell_ids_out) # keep only those compatible with derived_out (otherwise they'd be all-zero) self.cell_ids_out = np.intersect1d(req_out, derived_out, assume_unique=False) if self.cell_ids_out.size == 0: raise ValueError( "After intersecting cell_ids_out with unique(in_cell_ids//4), " "no coarse pixel remains. Check your inputs." ) else: self.cell_ids_out = self._validate_cell_ids_out(cell_ids_out) self.K_out = int(self.cell_ids_out.size) # Column basis for smooth matrix: # - full sphere: columns are 0..N_in-1 # - subset: columns are 0..K_in-1 aligned to self.in_cell_ids self.K_in = int(self.in_cell_ids.size) if self.has_in_subset else self.N_in if self.mode == "smooth": if radius_deg is None: # default: include roughly the 4 children footprint # (healpy pixel size ~ sqrt(4pi/N), coarse pixel is 4x area) radius_deg = 2.0 * hp.nside2resol(self.nside_out, arcmin=True) / 60.0 if sigma_deg is None: sigma_deg = max(radius_deg / 2.0, 1e-6) self.radius_deg = float(radius_deg) self.sigma_deg = float(sigma_deg) self.radius_rad = self.radius_deg * np.pi / 180.0 self.sigma_rad = self.sigma_deg * np.pi / 180.0 M = self._build_down_matrix() # shape (K_out, K_in or N_in) self.M = M if use_csr: self.M = self.M.to_sparse_csr().to(self.device) self.M_size = M.size() else: # Precompute children indices for maxpool # For subset mode, store mapping from each parent to indices in compact vector, # with -1 for missing children. 
children = np.stack( [4 * self.cell_ids_out + i for i in range(4)], axis=1, ).astype(np.int64) # [K_out, 4] in fine pixel ids (full indexing) if self.has_in_subset: # map each child pixel id to position in in_cell_ids (compact index) pos = self._positions_in_sorted(self.in_cell_ids, children.reshape(-1)) children_compact = pos.reshape(self.K_out, 4).astype(np.int64) # -1 if missing self.register_buffer( "children_compact", torch.tensor(children_compact, dtype=torch.long, device=self.device), ) else: self.register_buffer( "children_full", torch.tensor(children, dtype=torch.long, device=self.device), ) # expose ids as torch buffers for convenience self.register_buffer( "cell_ids_out_t", torch.tensor(self.cell_ids_out.astype(np.int64), dtype=torch.long, device=self.device), ) if self.has_in_subset: self.register_buffer( "in_cell_ids_t", torch.tensor(self.in_cell_ids.astype(np.int64), dtype=torch.long, device=self.device), ) # ---------------- validation helpers ---------------- def _validate_cell_ids_out(self, cell_ids_out): """Return a 1D np.int64 array of coarse cell ids (nside_out).""" if cell_ids_out is None: return np.arange(self.N_out, dtype=np.int64) arr = np.asarray(cell_ids_out, dtype=np.int64).reshape(-1) if arr.size == 0: raise ValueError("cell_ids_out is empty: provide at least one coarse pixel id.") arr = np.unique(arr) if arr.min() < 0 or arr.max() >= self.N_out: raise ValueError(f"cell_ids_out must be in [0, {self.N_out-1}] for nside_out={self.nside_out}.") return arr def _validate_in_cell_ids(self, in_cell_ids): """Return a 1D np.int64 array of fine cell ids (nside_in) or None.""" if in_cell_ids is None: return None if torch.is_tensor(in_cell_ids): arr = in_cell_ids.detach().cpu().numpy() else: arr = np.asarray(in_cell_ids) arr = np.asarray(arr, dtype=np.int64).reshape(-1) if arr.size == 0: raise ValueError("in_cell_ids is empty: provide at least one fine pixel id or None.") arr = np.unique(arr) if arr.min() < 0 or arr.max() >= self.N_in: raise 
ValueError(f"in_cell_ids must be in [0, {self.N_in-1}] for nside_in={self.nside_in}.") return arr @staticmethod def _positions_in_sorted(sorted_ids: np.ndarray, query_ids: np.ndarray) -> np.ndarray: """ For each query_id, return its index in sorted_ids if present, else -1. sorted_ids must be sorted ascending unique. """ q = np.asarray(query_ids, dtype=np.int64) idx = np.searchsorted(sorted_ids, q) ok = (idx >= 0) & (idx < sorted_ids.size) & (sorted_ids[idx] == q) out = np.full(q.shape, -1, dtype=np.int64) out[ok] = idx[ok] return out # ---------------- weights and matrix build ---------------- def _normalize_weights(self, w: np.ndarray) -> np.ndarray: w = np.asarray(w, dtype=np.float64) if w.size == 0: return w w = np.maximum(w, 0.0) if self.weight_norm == "l1": s = w.sum() if s <= 0.0: return np.ones_like(w) / max(w.size, 1) return w / s # l2 s2 = (w * w).sum() if s2 <= 0.0: return np.ones_like(w) / max(np.sqrt(w.size), 1.0) return w / np.sqrt(s2) def _build_down_matrix(self) -> torch.Tensor: nside_in = self.nside_in nside_out = self.nside_out sigma = float(self.sigma_rad) p_out = self.cell_ids_out.astype(np.int64) # [K] K = p_out.size offs = np.arange(4, dtype=np.int64) # [4] # --- (A) Neighbourhood choice on the coarse side # Option 1 (minimal, very fast): parent only -> 4 children #parents = p_out[:, None] # [K,1] # Option 2 (plus “lisse”) : parent + 8 voisins -> 9 parents -> 36 enfants neigh8 = hp.get_all_neighbours(nside_out, p_out, nest=True) # [8,K] (healpy renvoie souvent [8,K]) parents = np.concatenate([p_out[None, :], neigh8], axis=0).T # [K,9] idx=np.where(parents==-1) parents[idx[0],idx[1]]=parents[idx[0],idx[1]-1] # --- enfants fins (NESTED) : child_id = 4*parent + {0,1,2,3} children = (4 * parents[..., None] + offs[None, None, :]).reshape(K, -1) # [K, 4] ou [K,36] # If neighbour option active: invalidate children of parents=-1 #mask_child = np.repeat(mask_parent, 4, axis=1) # [K,36] #children_flat = children[mask_child] 
#print(mask_child.shape,children.shape,K) rows_flat = np.repeat(np.arange(K, dtype=np.int64), children.shape[1])#[mask_child.ravel()] # Option minimal (sans voisins) : children_flat = children.reshape(-1) # [K*4] #rows_flat = np.repeat(np.arange(K, dtype=np.int64), children.shape[1]) # --- Subset: map vers indices compacts if self.has_in_subset: in_ids = self.in_cell_ids # sorted/unique idx = np.searchsorted(in_ids, children_flat) in_range = idx < in_ids.size # idx est toujours >=0 pour searchsorted idx2 = idx[in_range] child2 = children_flat[in_range] ok2 = (in_ids[idx2] == child2) # comparaison safe ok = np.zeros_like(in_range, dtype=bool) ok[in_range] = ok2 cols_flat = idx[ok] rows_flat2 = rows_flat[ok] child_ids_kept = children_flat[ok] else: cols_flat = children_flat rows_flat2 = rows_flat child_ids_kept = children_flat if rows_flat2.size == 0: indices = torch.zeros((2, 0), dtype=torch.long, device=self.device) vals_t = torch.zeros((0,), dtype=self.dtype, device=self.device) return torch.sparse_coo_tensor(indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype).coalesce() # --- Gaussian weights (vectorised) # centres coarse: vec0 [K,3] vx0, vy0, vz0 = hp.pix2vec(nside_out, p_out, nest=True) vec0 = np.stack([vx0, vy0, vz0], axis=1) # [K,3] # vector of kept children vx, vy, vz = hp.pix2vec(nside_in, child_ids_kept, nest=True) vec = np.stack([vx, vy, vz], axis=1) # [nnz,3] # dot(vec, vec0[row]) dots = np.einsum("ij,ij->i", vec, vec0[rows_flat2]) dots = np.clip(dots, -1.0, 1.0) ang = np.arccos(dots) w = np.exp(-2.0 * (ang / max(sigma, 1e-30)) ** 2) # --- normalisation par ligne (row-wise), sans boucle # On recompose un tableau dense “temporaire” [K, m] via indexation m = children.shape[1] # position within row (0..m-1) jpos = np.tile(np.arange(m, dtype=np.int64), K) if self.has_in_subset: jpos = jpos.reshape(-1)[ok] # aligned with rows_flat2/child_ids_kept # si option voisins + mask_child : jpos = jpos[mask_child.reshape(-1)][ok] etc. 
W = np.zeros((K, m), dtype=np.float64) W[rows_flat2, jpos] = w if self.weight_norm == "l1": s = W.sum(axis=1, keepdims=True) s[s <= 0] = 1.0 W /= s else: # l2 s2 = np.sqrt((W * W).sum(axis=1, keepdims=True)) s2[s2 <= 0] = 1.0 W /= s2 # extract w normalised to the same nnz w_norm = W[rows_flat2, jpos].astype(np.float32) # --- sparse rows_t = torch.tensor(rows_flat2, dtype=torch.long, device=self.device) cols_t = torch.tensor(cols_flat, dtype=torch.long, device=self.device) vals_t = torch.tensor(w_norm, dtype=self.dtype, device=self.device) indices = torch.stack([rows_t, cols_t], dim=0) return torch.sparse_coo_tensor(indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype).coalesce() ''' def _build_down_matrix(self) -> torch.Tensor: """Construct sparse matrix M (K_out, K_in or N_in) for the selected coarse pixels.""" nside_in = self.nside_in nside_out = self.nside_out radius_rad = self.radius_rad sigma_rad = self.sigma_rad rows: list[int] = [] cols: list[int] = [] vals: list[float] = [] # For subset columns, we use self.in_cell_ids as the basis subset_cols = self.has_in_subset in_ids = self.in_cell_ids # np.ndarray or None for r, p_out in enumerate(self.cell_ids_out.tolist()): theta0, phi0 = hp.pix2ang(nside_out, int(p_out), nest=True) vec0 = hp.ang2vec(theta0, phi0) neigh = hp.query_disc(nside_in, vec0, radius_rad, inclusive=True, nest=True) neigh = np.asarray(neigh, dtype=np.int64) if subset_cols: # keep only valid fine pixels # neigh is not sorted; intersect1d expects sorted neigh_sorted = np.sort(neigh) keep = np.intersect1d(neigh_sorted, in_ids, assume_unique=False) neigh = keep # Fallback: if radius query returns nothing in subset mode, at least try the 4 children if neigh.size == 0: children = (4 * int(p_out) + np.arange(4, dtype=np.int64)) if subset_cols: pos = self._positions_in_sorted(in_ids, children) ok = pos >= 0 if np.any(ok): neigh = children[ok] else: # nothing to connect -> row stays zero continue else: neigh = children 
theta, phi = hp.pix2ang(nside_in, neigh, nest=True) vec = hp.ang2vec(theta, phi) # angular distance via dot product dots = np.clip(np.dot(vec, vec0), -1.0, 1.0) ang = np.arccos(dots) w = np.exp(- 2.0*(ang / sigma_rad) ** 2) w = self._normalize_weights(w) if subset_cols: pos = self._positions_in_sorted(in_ids, neigh) # all should be present due to filtering, but guard anyway ok = pos >= 0 neigh_pos = pos[ok] w = w[ok] if neigh_pos.size == 0: continue for c, v in zip(neigh_pos.tolist(), w.tolist()): rows.append(r) cols.append(int(c)) vals.append(float(v)) else: for c, v in zip(neigh.tolist(), w.tolist()): rows.append(r) cols.append(int(c)) vals.append(float(v)) if len(rows) == 0: # build an all-zero sparse tensor indices = torch.zeros((2, 0), dtype=torch.long, device=self.device) vals_t = torch.zeros((0,), dtype=self.dtype, device=self.device) return torch.sparse_coo_tensor( indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype ).coalesce() rows_t = torch.tensor(rows, dtype=torch.long, device=self.device) cols_t = torch.tensor(cols, dtype=torch.long, device=self.device) vals_t = torch.tensor(vals, dtype=self.dtype, device=self.device) indices = torch.stack([rows_t, cols_t], dim=0) M = torch.sparse_coo_tensor( indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype, ).coalesce() return M ''' # ---------------- forward ----------------
[docs] def forward(self, x: torch.Tensor): """ Parameters ---------- x : torch.Tensor If has_in_subset: - [B,C,K_in] (compact, aligned with in_cell_ids) OR [B,C,N_in] (full sphere) Else: - [B,C,N_in] (full sphere) Returns ------- y : torch.Tensor [B,C,K_out] cell_ids_out : torch.Tensor [K_out] coarse pixel ids (nside_out), aligned with y last dimension. """ if x.dim() != 3: raise ValueError("x must be [B, C, N]") B, C, N = x.shape if self.has_in_subset: if N not in (self.K_in, self.N_in): raise ValueError( f"x last dim must be K_in={self.K_in} (compact) or N_in={self.N_in} (full), got {N}" ) else: if N != self.N_in: raise ValueError(f"x last dim must be N_in={self.N_in}, got {N}") if self.mode == "smooth": # If x is full-sphere but M is subset-based, gather compact inputs if self.has_in_subset and N == self.N_in: x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device)) else: x_use = x # sparse mm expects 2D: (K_out, K_in) @ (K_in, B*C) x2 = x_use.reshape(B * C, -1).transpose(0, 1).contiguous() y2 = torch.sparse.mm(self.M, x2) y = y2.transpose(0, 1).reshape(B, C, self.K_out).contiguous() return y, self.cell_ids_out_t.to(x.device) # maxpool if self.has_in_subset and N == self.N_in: x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device)) else: x_use = x if self.has_in_subset: # children_compact: [K_out, 4] indices in 0..K_in-1 or -1 ch = self.children_compact.to(x.device) # [K_out,4] # gather with masking # We build y by iterating 4 children with max y = None for j in range(4): idx = ch[:, j] # [K_out] mask = idx >= 0 # start with very negative so missing children don't win tmp = torch.full((B, C, self.K_out), -torch.inf, device=x.device, dtype=x.dtype) if mask.any(): tmp[:, :, mask] = x_use.index_select(dim=2, index=idx[mask]).reshape(B, C, -1) y = tmp if y is None else torch.maximum(y, tmp) # If a parent had no valid children at all, it is -inf -> set to 0 y = torch.where(torch.isfinite(y), y, torch.zeros_like(y)) return y, 
self.cell_ids_out_t.to(x.device) else: ch = self.children_full.to(x.device) # [K_out,4] full indices # gather children and max xch = x_use.index_select(dim=2, index=ch.reshape(-1)).reshape(B, C, self.K_out, 4) y = xch.max(dim=3).values return y, self.cell_ids_out_t.to(x.device)