# Source code for foscat.alm_loc


import numpy as np
import healpy as hp

from foscat.alm import alm as _alm
import torch

class alm_loc(_alm):
    """
    Local/partial-sky variant of foscat.alm.alm.

    Key design choice (to match alm.py exactly when full-sky is provided):
    - Reuse *all* Legendre/normalization machinery from the parent class (alm),
      i.e. shift_ph(), compute_legendre_m(), ratio_mm, A/B recurrences, etc.
      This is critical for matching alm.map2alm() numerically.

    Differences vs alm.map2alm():
    - Input map is [..., n] with explicit (nside, cell_ids)
    - Only rings touched by cell_ids are processed.
    - For rings with full coverage, we run the exact same FFT+tiling logic as
      alm.comp_tf() (but only for those rings) -> bitwise comparable up to
      backend FFT differences.
    - For rings with partial coverage, we compute a *partial DFT* for
      m=0..mmax, using the same phase convention as alm.comp_tf():
      FFT kernel uses exp(-i 2pi (m mod Nring) j / Nring)
      then apply the per-ring shift exp(-i m phi0) via self.matrix_shift_ph
    """

    def __init__(self, backend=None, lmax=24, limit_range=1e10):
        # nside is passed explicitly to every method, so the parent is
        # initialised without one.
        super().__init__(backend=backend, lmax=lmax, nside=None, limit_range=limit_range)

    # --------- helpers: ring layout identical to alm.ring_th/ring_ph ----------
    @staticmethod
    def _ring_starts_sizes(nside: int):
        """
        Return the first RING-ordered pixel index and the pixel count of each ring.

        The layout is nside-1 polar rings of 4, 8, ... pixels, then 2*nside+1
        rings of 4*nside pixels, then the mirrored polar rings.

        Returns:
          starts: int64 array, index of the first pixel of each ring
          sizes:  int32 array, number of pixels in each ring
        """
        sizes = (
            [4 * (k + 1) for k in range(nside - 1)]
            + [4 * nside] * (2 * nside + 1)
            + [4 * (nside - 1 - k) for k in range(nside - 1)]
        )
        sizes = np.asarray(sizes, dtype=np.int64)
        # A ring starts where all previous rings end: exclusive cumulative sum.
        starts = np.cumsum(sizes) - sizes
        return starts, sizes.astype(np.int32)

    def _to_ring_ids(self, nside: int, cell_ids: np.ndarray, nest: bool) -> np.ndarray:
        """Return cell_ids in RING ordering (converted with healpy when nest is true)."""
        if nest:
            return hp.nest2ring(nside, cell_ids)
        return cell_ids

    def _group_by_ring(self, nside: int, ring_ids: np.ndarray):
        """
        Locate every pixel on its ring.

        Returns:
          ring_idx: ring number (0..4*nside-2) per pixel
          pos: position along ring (0..Nring-1) per pixel
          order: sort order grouping by ring then pos
          starts,sizes: ring layout

        ring_idx and pos are returned in the input order; apply ``order`` to
        obtain the grouped layout.
        """
        starts, sizes = self._ring_starts_sizes(nside)
        # ring index = last start <= ring_id
        ring_idx = np.searchsorted(starts, ring_ids, side="right") - 1
        ring_idx = ring_idx.astype(np.int32)
        pos = (ring_ids - starts[ring_idx]).astype(np.int32)
        # lexsort uses its last key as the primary key: ring first, then pos.
        order = np.lexsort((pos, ring_idx))
        return ring_idx, pos, order, starts, sizes

    # ------------------ local Fourier transform per ring ---------------------
    def comp_tf_loc(self, im, nside: int, cell_ids, nest: bool = False,
                    realfft: bool = True, mmax=None):
        """
        Fourier-transform the supplied pixels ring by ring.

        Args:
          im: map values, last axis aligned with cell_ids.
          nside: HEALPix resolution of cell_ids.
          cell_ids: pixel indices covered by ``im``.
          nest: True when cell_ids are in NESTED ordering.
          realfft: on fully covered rings use self.rfft2fft instead of the
            backend complex FFT.
          mmax: highest m to return; defaults to min(self.lmax, 3*nside-1).

        Returns:
          rings_used: 1D np.ndarray of ring indices present
          ft: backend tensor of shape [..., nrings_used, mmax+1] (complex)
              where last axis is m, ring axis matches rings_used order.

        Raises:
          ValueError: if cell_ids is empty.
        """
        nside = int(nside)
        cell_ids = np.asarray(cell_ids, dtype=np.int64)
        if cell_ids.size == 0:
            # Without this guard the failure surfaces later as an opaque
            # backend error when concatenating an empty list of rings.
            raise ValueError("cell_ids is empty: at least one pixel is required")
        if mmax is None:
            mmax = min(self.lmax, 3 * nside - 1)
        mmax = int(mmax)

        # Ensure parent caches for this nside exist (matrix_shift_ph, A/B, ratio_mm, etc.)
        self.shift_ph(nside)

        ring_ids = self._to_ring_ids(nside, cell_ids, nest)
        ring_idx, pos, order, starts, sizes = self._group_by_ring(nside, ring_ids)
        ring_idx = ring_idx[order]
        pos = pos[order]

        i_im = self.backend.bk_cast(im)
        i_im = self.backend.bk_gather(i_im, order, axis=-1)  # reorder last axis

        rings_used, start_ptr, counts = np.unique(
            ring_idx, return_index=True, return_counts=True
        )

        # Loop invariants.
        m_vec = np.arange(mmax + 1, dtype=np.int64)
        shift_table = self.matrix_shift_ph[nside]

        # Build output per ring as list then concat
        out_per_ring = []
        for r, s0, cnt in zip(rings_used.tolist(), start_ptr.tolist(), counts.tolist()):
            Nring = int(sizes[r])
            p = pos[s0:s0 + cnt]
            v = self.backend.bk_gather(
                i_im, np.arange(s0, s0 + cnt, dtype=np.int64), axis=-1
            )

            # A ring is "full" only when its positions are exactly 0..Nring-1.
            # Comparing the count alone would send a ring with repeated pixels
            # through the FFT with samples at the wrong positions.
            full_ring = cnt == Nring and np.array_equal(
                p, np.arange(Nring, dtype=p.dtype)
            )

            if full_ring:
                # Full ring: exact same FFT+tiling logic as alm.comp_tf for 1 ring
                if realfft:
                    tmp = self.rfft2fft(v)
                else:
                    tmp = self.backend.bk_fft(v)

                l_n = tmp.shape[-1]
                if l_n < mmax + 1:
                    # The spectrum is periodic in m with period l_n: tile it.
                    repeat_n = (mmax // l_n) + 1
                    tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)
                tmp = tmp[..., :mmax + 1]
            else:
                # Partial ring: partial DFT for required m, using same aliasing
                # as the FFT branch. Every supplied sample adds one term.
                m_mod = m_vec % Nring
                # angles: 2pi * pos * m_mod / Nring
                ang = (
                    (2.0 * np.pi / Nring)
                    * p.astype(np.float64)[:, None]
                    * m_mod[None, :].astype(np.float64)
                )
                ker = np.exp(-1j * ang).astype(np.complex128)  # [cnt, m]
                ker_bk = self.backend.bk_cast(ker)

                # v is [..., cnt]; we want [..., m] = sum_cnt v*ker
                tmp = self.backend.bk_reduce_sum(
                    self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
                    axis=-2,
                )  # [..., m]

            # Apply per-ring shift exp(-i m phi0) exactly like alm.comp_tf
            tmp = tmp * shift_table[r, :mmax + 1]  # [m]
            out_per_ring.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [...,1,m]

        ft = self.backend.bk_concat(out_per_ring, axis=-2)  # [..., nrings, m]
        return np.asarray(rings_used, dtype=np.int32), ft

    # ---------------------------- map -> alm --------------------------------
    def map2alm_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
        """
        Compute alm coefficients from the pixels listed in cell_ids.

        Args:
          im: map values, last axis aligned with cell_ids; a 1-D input is
            treated as a single map and returned without a batch axis.
          nside: HEALPix resolution of cell_ids.
          cell_ids: pixel indices covered by ``im``.
          nest: True when cell_ids are in NESTED ordering.
          lmax: highest multipole; defaults to min(self.lmax, 3*nside-1).

        Returns:
          Backend tensor [..., n_alm]: the per-m blocks for m = 0..lmax
          concatenated along the last axis. Each block is scaled by
          sqrt(2l+1) for l = m..lmax.
        """
        nside = int(nside)
        if lmax is None:
            lmax = min(self.lmax, 3 * nside - 1)
        lmax = int(lmax)

        # Ensure a batch dimension like alm.map2alm expects
        ndim = im.ndim if hasattr(im, "ndim") else len(im.shape)
        added_batch = ndim == 1
        if added_batch:
            im = im[None, :]

        rings_used, ft = self.comp_tf_loc(
            im, nside=nside, cell_ids=cell_ids, nest=nest, realfft=True, mmax=lmax
        )

        # cos(theta) on used rings
        co_th = np.cos(self.ring_th(nside)[rings_used])

        # sqrt(2l+1) for l = 0..lmax; the block for a given m uses the tail [m:].
        full_scale = np.sqrt(2.0 * np.arange(lmax + 1, dtype=np.float64) + 1.0)

        # ft is [..., R, m]
        blocks = []
        for m in range(lmax + 1):
            # IMPORTANT: reuse alm.compute_legendre_m and its normalization exactly
            plm = self.compute_legendre_m(co_th, m, lmax, nside) / (12 * nside**2)  # [L,R]
            plm_bk = self.backend.bk_cast(plm)

            ft_m = ft[..., :, m]  # [..., R]
            tmp = self.backend.bk_reduce_sum(
                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk, axis=-1
            )  # [..., L]

            # Convert the scale to a backend tensor (torch) on the right device.
            scale_t = self.backend.bk_cast(full_scale[m:])
            # Reshape for broadcast against the leading axes: [1, ..., 1, L].
            shape = (1,) * (tmp.ndim - 1) + (scale_t.shape[0],)
            blocks.append(tmp * scale_t.reshape(shape))

        # One concatenation instead of re-copying the growing result per m.
        if len(blocks) == 1:
            alm_out = blocks[0]
        else:
            alm_out = self.backend.bk_concat(blocks, axis=-1)

        if added_batch:
            alm_out = alm_out[0]
        return alm_out

    # ---------------------------- alm -> Cl ---------------------------------
    def anafast_loc(self, im, nside: int, cell_ids, nest: bool = False, lmax=None):
        """
        Compute the power spectrum of the supplied pixels from map2alm_loc.

        Args:
          im, nside, cell_ids, nest: forwarded to map2alm_loc.
          lmax: highest multipole; defaults to min(self.lmax, 3*nside-1).

        Returns:
          torch.float64 tensor of shape batch + (lmax+1,): one spectrum per
          map, no reduction over batch axes. Built with torch directly, so
          map2alm_loc must return a torch tensor.
        """
        if lmax is None:
            lmax = min(self.lmax, 3 * nside - 1)
        lmax = int(lmax)

        alm = self.map2alm_loc(im, nside=nside, cell_ids=cell_ids, nest=nest, lmax=lmax)

        # cl has same batch dims as alm, plus ell dim
        batch_shape = alm.shape[:-1]
        cl = torch.zeros(batch_shape + (lmax + 1,), dtype=torch.float64, device=alm.device)

        idx = 0
        for m in range(lmax + 1):
            # The block for this m holds l = m..lmax.
            n_ell = lmax - m + 1
            a = alm[..., idx:idx + n_ell]  # shape: batch + (n_ell,)
            idx += n_ell

            p = self.backend.bk_real(a * self.backend.bk_conjugate(a))  # batch + (n_ell,)

            # Real-field folding: m > 0 also stands for -m, so it counts twice.
            cl[..., m:] += p if m == 0 else 2.0 * p

        # divide by (2l+1), broadcast over batch dims
        denom = 2 * torch.arange(lmax + 1, dtype=cl.dtype, device=alm.device) + 1
        denom = denom.reshape((1,) * len(batch_shape) + (lmax + 1,))
        return cl / denom