# Source code for foscat.alm_latlon

"""
alm_latlon.py
=============
Spherical-harmonic transform for maps defined on an arbitrary
colatitude / longitude grid organised into rings.

Differences with respect to foscat.alm.alm
-----------------------------------------
- No dependency on HEALPix for pixel positioning.
- Rings may have arbitrary colatitudes (not HEALPix colatitudes)
  and arbitrary longitudes (not necessarily uniform).
- Longitude step: direct DFT when ring φ values are irregular;
  FFT + phase shift (like alm.comp_tf) when φ are uniformly spaced.
- Colatitude step: same Legendre recurrence as alm.compute_legendre_m,
  evaluated at the cosines of the provided colatitudes.
- Quadrature weights: trapezoidal in θ (sin θ · Δθ) × uniform in φ (2π/N_r)
  by default, or user-supplied weight array.

Main API
--------------
build_rings_from_latlon(lat, lon, atol, convention)
    Groups a flat list of (lat, lon) into rings of equal colatitude.

compute_weights(ring_theta, ring_phi_list, ring_counts, quadrature)
    Computes the quadrature weights per pixel (steradians).

comp_tf_latlon(im, ring_phi_list, ring_counts, pixel_weights, mmax)
    Ring-weighted DFT → ft[..., R, mmax+1].

map2alm_latlon(im, ring_theta, ring_phi_list, ring_counts, lmax, weights, quadrature)
    Map → alm.

anafast_latlon(im, ring_theta, ring_phi_list, ring_counts, lmax, weights, quadrature)
    Map → Cl.

alm2map_latlon(alm, ring_theta, ring_phi_list, ring_counts, lmax)
    alm → map (synthesis).

Minimal example
---------------
    import numpy as np
    from foscat.alm_latlon import alm_latlon

    # Regular grid (ntheta=64 rings × nphi=128 pixels per ring)
    ntheta, nphi = 64, 128
    theta_1d = np.linspace(np.pi / (2*ntheta), np.pi - np.pi/(2*ntheta), ntheta)
    phi_1d   = np.linspace(0, 2*np.pi*(1 - 1/nphi), nphi)

    lat  = np.repeat(theta_1d, nphi)        # colatitude of each pixel
    lon  = np.tile(phi_1d,     ntheta)      # longitude of each pixel

    ring_theta, ring_phi_list, ring_counts, sort_idx = \\
        alm_latlon.build_rings_from_latlon(lat, lon)

    obj  = alm_latlon(lmax=32)
    im   = np.random.randn(ntheta * nphi)

    alm_coeffs = obj.map2alm_latlon(
        im[sort_idx], ring_theta, ring_phi_list, ring_counts
    )
    cl = obj.anafast_latlon(
        im[sort_idx], ring_theta, ring_phi_list, ring_counts
    )
"""

import numpy as np
import torch

from foscat.alm import alm as _alm


class alm_latlon(_alm):
    """
    Spherical-harmonic transform for maps sampled on an arbitrary
    latitude/longitude grid whose pixels are organised into rings.
    """

    def __init__(self, backend=None, lmax=24, limit_range=1e10):
        """
        Forward the configuration to the parent ``alm`` constructor.

        Parameters
        ----------
        backend     : backend object handed to the parent class (default None)
        lmax        : maximum multipole (default 24)
        limit_range : forwarded unchanged to the parent constructor
        """
        # nside=None: no HEALPix grid, maxlog computed from lmax
        parent_kwargs = dict(
            backend=backend,
            lmax=lmax,
            nside=None,
            limit_range=limit_range,
        )
        super().__init__(**parent_kwargs)

    # ================================================================== #
    #  Build rings from a flat (lat, lon) array
    # ================================================================== #
[docs] @staticmethod def build_rings_from_latlon(lat, lon, atol=1e-10, convention='colatitude_rad'): """ Group a flat list of pixels into rings of equal colatitude. Parameters ---------- lat : array [N] angular coordinate of each pixel (see convention). lon : array [N] longitudinal coordinate of each pixel (see convention). atol : tolerance in radians for grouping two pixels into the same ring. convention : str format of the input coordinates. 'colatitude_rad' (default) lat = colatitude θ in RADIANS 0 → π lon = longitude φ in RADIANS 0 → 2π 'colatitude_deg' lat = colatitude θ in DEGREES 0° → 180° lon = longitude φ in DEGREES 0° → 360° 'geographic_rad' lat = geographic latitude in RADIANS −π/2 → +π/2 lon = longitude in RADIANS −π → +π or 0 → 2π 'geographic_deg' lat = geographic latitude in DEGREES −90° → +90° lon = longitude in DEGREES −180° → +180° or 0° → 360° All conventions are converted internally to colatitude + longitude in radians before processing. Returns ------- ring_theta : ndarray [R] colatitude θ (radians) per ring ring_phi_list : list[ndarray [N_r]] longitudes φ (radians) per ring ring_counts : ndarray int64 [R] number of pixels per ring sort_idx : ndarray int64 [N] permutation im_sorted = im[sort_idx] """ lat = np.asarray(lat, dtype=np.float64).ravel() lon = np.asarray(lon, dtype=np.float64).ravel() conv = convention.lower().strip() if conv == 'colatitude_rad': theta = lat phi = lon elif conv == 'colatitude_deg': theta = np.radians(lat) phi = np.radians(lon) elif conv == 'geographic_rad': theta = np.pi / 2.0 - lat phi = lon % (2.0 * np.pi) elif conv == 'geographic_deg': theta = np.radians(90.0 - lat) phi = np.radians(lon) % (2.0 * np.pi) else: raise ValueError( f"Unknown convention: '{convention}'. " "Accepted values: 'colatitude_rad', 'colatitude_deg', " "'geographic_rad', 'geographic_deg'." 
) N = len(theta) # Sort by colatitude then by longitude order = np.lexsort((phi, theta)) lat_s = theta[order] lon_s = phi[order] # Find ring boundaries (colatitude jump > atol) breaks = np.where(np.diff(lat_s) > atol)[0] + 1 ring_starts = np.concatenate([[0], breaks]) ring_ends = np.concatenate([breaks, [N]]) ring_theta = np.array([lat_s[s] for s in ring_starts], dtype=np.float64) ring_phi_list = [lon_s[s:e] for s, e in zip(ring_starts, ring_ends)] ring_counts = np.array([e - s for s, e in zip(ring_starts, ring_ends)], dtype=np.int64) return ring_theta, ring_phi_list, ring_counts, order
# ================================================================== # # Quadrature weights # # ================================================================== #
[docs] @staticmethod def compute_weights(ring_theta, ring_phi_list, ring_counts, quadrature='trapeze'): """ Computes the quadrature weights per pixel (steradians). Parameters ---------- ring_theta : [R] colatitudes in radians ring_phi_list : list[ndarray] longitudes per ring ring_counts : [R] number of pixels per ring quadrature : str quadrature method in θ. 'trapeze' (default) Trapezoidal rule: w_θ = sin(θ_r) × Δθ_r Suitable for regular θ grids. 'gauss_legendre' Exact Gauss-Legendre weights. Required for Gaussian grids (ERA5, ECMWF, IFS, ARPEGE…) where the colatitudes are the zeros of P_R(cos θ). The integral ∫f dΩ = ∫f dφ dx (x = cos θ) is then exact up to ℓ ≈ 2R-1. 'equal_area' Equal weights: 4π / N_total. For equal-area grids (HEALPix). Returns ------- weights : ndarray float64 [N_total] """ ring_theta = np.asarray(ring_theta, dtype=np.float64) ring_counts = np.asarray(ring_counts, dtype=np.int64) R = len(ring_theta) N_total = int(ring_counts.sum()) all_w = [] # ---- θ weights according to chosen method ---- if quadrature == 'trapeze': w_theta = np.empty(R, dtype=np.float64) for r in range(R): if R == 1: dth = np.pi elif r == 0: dth = (ring_theta[1] - ring_theta[0]) / 2.0 elif r == R - 1: dth = (ring_theta[-1] - ring_theta[-2]) / 2.0 else: dth = (ring_theta[r + 1] - ring_theta[r - 1]) / 2.0 w_theta[r] = abs(np.sin(ring_theta[r]) * dth) elif quadrature == 'gauss_legendre': # GL nodes are x_r = cos(θ_r) ∈ [-1, 1]. # np.polynomial.legendre.leggauss(R) returns nodes and weights # for ∫₋₁¹ f(x) dx ≈ Σ w_r f(x_r). # Since dΩ = dφ dx (with x = cos θ), the θ weights are directly # the GL weights (sin θ is absorbed into dx = -sin θ dθ). x_provided = np.cos(ring_theta) # provided nodes gl_nodes, gl_weights = np.polynomial.legendre.leggauss(R) # GL nodes are sorted in ascending order; cos θ is decreasing # (θ increasing), so we align them by sorting. 
sort_gl = np.argsort(gl_nodes) # -1 → +1 (ascending) sort_prov = np.argsort(x_provided) # cos θ ascending gl_w_sorted = gl_weights[sort_gl] # weights aligned on ascending x # Reorder to original ring order (θ ascending ≡ x descending) # sort_prov[i] = ring index of the i-th smallest cos θ w_theta = np.empty(R, dtype=np.float64) w_theta[sort_prov] = gl_w_sorted # Verification: GL nodes must match cos(θ_r) max_err = np.max(np.abs(np.sort(x_provided) - np.sort(gl_nodes))) if max_err > 1e-6: import warnings warnings.warn( f"gauss_legendre: provided colatitudes do not match " f"GL nodes (max error = {max_err:.2e}). " "Check that the grid is a Gaussian grid with " f"{R} latitude points.", UserWarning ) elif quadrature == 'equal_area': total_area = 4.0 * np.pi w_theta = np.full(R, total_area / N_total) # will be weighted by N_r below # Pour equal_area, le poids par pixel est uniforme : 4π/N_total weights = np.full(N_total, total_area / N_total, dtype=np.float64) return weights else: raise ValueError( f"Unknown quadrature: '{quadrature}'. " "Accepted values: 'trapeze', 'gauss_legendre', 'equal_area'." ) # ---- φ weights (common to both non-equal_area θ methods) ---- for r in range(R): N_r = int(ring_counts[r]) phi_r = np.asarray(ring_phi_list[r], dtype=np.float64) if N_r == 1: w_phi = np.array([2.0 * np.pi]) else: sorted_phi = np.sort(phi_r) dphi = np.diff(sorted_phi) if np.ptp(dphi) < 1e-10 * (2 * np.pi / N_r): w_phi = np.full(N_r, 2.0 * np.pi / N_r) else: gap_wrap = (sorted_phi[0] + 2.0 * np.pi) - sorted_phi[-1] dp_ext = np.concatenate([[gap_wrap], dphi, [gap_wrap]]) w_sorted = (dp_ext[:-1] + dp_ext[1:]) / 2.0 back = np.argsort(np.argsort(phi_r)) w_phi = w_sorted[back] all_w.append(w_theta[r] * w_phi) return np.concatenate(all_w)
# ================================================================== # # Fourier transform per ring # # ================================================================== # @staticmethod def _check_uniform(phi, tol=1e-10): """ Return (True, phi0, N) if the φ values are uniformly spaced at dphi = 2π/N, otherwise (False, None, None). The φ values need not be sorted. """ N = len(phi) if N <= 1: return True, float(phi[0]) if N == 1 else 0.0, N sorted_phi = np.sort(phi) dphi = np.diff(sorted_phi) mean_dp = 2.0 * np.pi / N if np.ptp(dphi) < tol * mean_dp: return True, float(sorted_phi[0]), N return False, None, None
def comp_tf_latlon(self, im, ring_phi_list, ring_counts, pixel_weights, mmax):
    """
    Pixel-weighted DFT for each ring.

    For a ring reported as uniform by ``_check_uniform`` the transform is
    obtained from ``self.rfft2fft`` followed by a phase shift
    exp(-i m φ0); for any other ring a direct weighted DFT is evaluated.

    Parameters
    ----------
    im            : [..., N_total]           map (backend tensor or ndarray)
    ring_phi_list : list[ndarray]            longitudes per ring
    ring_counts   : ndarray int64            number of pixels per ring
    pixel_weights : ndarray float64 [N_total] quadrature weights
    mmax          : int                      maximum frequency

    Returns
    -------
    ft : tensor [..., R, mmax+1] complex
        ft[..., r, m] = Σ_j w_j · im_j · exp(-i m φ_j) over the pixels
        of ring r.

    Notes
    -----
    The weights are applied pixel by pixel in both branches, so
    user-supplied weights that vary inside a uniform ring are honoured
    (a single scalar multiply is used when they are all equal).
    """
    R = len(ring_counts)
    m_vec = np.arange(mmax + 1, dtype=np.float64)
    pixel_weights = np.asarray(pixel_weights, dtype=np.float64)
    im_bk = self.backend.bk_cast(im)

    out = []
    offset = 0
    for r in range(R):
        N_r = int(ring_counts[r])
        phi_r = np.asarray(ring_phi_list[r], dtype=np.float64)
        w_r = pixel_weights[offset:offset + N_r]
        v = im_bk[..., offset:offset + N_r]  # [..., N_r]
        offset += N_r

        is_unif, phi0, _ = self._check_uniform(phi_r)

        if is_unif:
            # ---- FFT + phase shift ----
            # Put samples (and their weights) in ascending φ order.
            sort_phi = np.argsort(phi_r)
            v_sorted = self.backend.bk_gather(v, sort_phi, axis=-1)
            w_sorted = w_r[sort_phi]

            if np.all(w_sorted == w_sorted[0]):
                # Constant weight on the ring: cheap scalar multiply.
                v_sorted = v_sorted * float(w_sorted[0])
            else:
                # Weights vary inside the ring (e.g. user-supplied):
                # apply them per pixel before transforming.
                v_sorted = v_sorted * self.backend.bk_cast(w_sorted)

            # Real FFT → full spectrum (as provided by self.rfft2fft)
            tmp = self.rfft2fft(v_sorted)  # [..., N_r]

            l_n = tmp.shape[-1]
            if l_n < mmax + 1:
                # The DFT is periodic in m with period N_r: tile it to
                # cover frequencies up to mmax.
                repeat_n = (mmax // l_n) + 1
                tmp = self.backend.bk_tile(tmp, repeat_n, axis=-1)
            tmp = tmp[..., :mmax + 1]  # [..., mmax+1]

            # Phase shift: exp(-i m phi0) for m = 0..mmax
            shift = np.exp(-1j * m_vec * phi0).astype(np.complex128)
            shift_bk = self.backend.bk_cast(shift)
            tmp = tmp * shift_bk
        else:
            # ---- Direct weighted DFT ----
            # kernel[j, m] = w_r[j] * exp(-i m phi_r[j])
            ang = np.outer(phi_r, m_vec)               # [N_r, M]
            ker = np.exp(-1j * ang) * w_r[:, None]     # [N_r, M]
            ker_bk = self.backend.bk_cast(ker.astype(np.complex128))
            # ft[..., m] = sum_j v[..., j] * ker[j, m]
            tmp = self.backend.bk_reduce_sum(
                self.backend.bk_expand_dims(v, axis=-1) * ker_bk,
                axis=-2
            )  # [..., mmax+1]

        out.append(self.backend.bk_expand_dims(tmp, axis=-2))  # [..., 1, mmax+1]

    return self.backend.bk_concat(out, axis=-2)  # [..., R, mmax+1]
# ================================================================== # # map → alm # # ================================================================== #
def map2alm_latlon(self, im, ring_theta, ring_phi_list, ring_counts,
                   lmax=None, weights=None, quadrature='trapeze'):
    """
    Compute the alm coefficients of a map defined on an arbitrary grid
    organised into rings.

    Parameters
    ----------
    im            : [..., N_total] map values, ordered ring by ring
                    (use sort_idx from build_rings_from_latlon if the
                    map is initially in arbitrary order)
    ring_theta    : [R] colatitude (radians) of each ring
    ring_phi_list : list[ndarray] or ndarray [N_total]
                    longitudes (radians) per ring (or flat array)
    ring_counts   : [R] number of pixels per ring
    lmax          : maximum multipole (default: self.lmax)
    weights       : [N_total] per-pixel weights (steradians).
                    If None, they are computed with ``compute_weights``
                    using ``quadrature``.
                    Pass weights='uniform' to use 1/N_total.
    quadrature    : θ quadrature forwarded to ``compute_weights`` when
                    ``weights`` is None ('trapeze', 'gauss_legendre' or
                    'equal_area'); ignored otherwise.

    Returns
    -------
    alm_out : [..., n_alm] complex, n_alm = Σ_{m=0}^{lmax} (lmax-m+1)
        Layout: [m=0: l=0..lmax | m=1: l=1..lmax | …]
        A 1-D input map yields a 1-D output.

    Raises
    ------
    ValueError
        If ``weights`` is a string other than 'uniform' or an array
        whose shape is not (N_total,).
    """
    if lmax is None:
        lmax = self.lmax
    lmax = int(lmax)

    ring_theta = np.asarray(ring_theta, dtype=np.float64)
    ring_counts = np.asarray(ring_counts, dtype=np.int64)
    N_total = int(ring_counts.sum())
    phi_list = self._parse_phi_list(ring_phi_list, ring_counts)

    # Accept plain sequences as well as arrays / tensors.
    if not hasattr(im, 'shape'):
        im = np.asarray(im)

    # Handle batch dimension: work on [..., N_total], undo at the end.
    _added_batch = len(im.shape) == 1
    if _added_batch:
        im = im[None, :]

    # Quadrature weights
    if weights is None:
        pixel_weights = self.compute_weights(ring_theta, phi_list, ring_counts,
                                             quadrature=quadrature)
    elif isinstance(weights, str):
        if weights != 'uniform':
            raise ValueError(
                f"Unknown weights option: '{weights}'. "
                "Pass None, 'uniform' or an array of shape [N_total]."
            )
        pixel_weights = np.ones(N_total, dtype=np.float64) / N_total
    else:
        pixel_weights = np.asarray(weights, dtype=np.float64)
        if pixel_weights.shape != (N_total,):
            # comp_tf_latlon slices the weights ring by ring; a wrong
            # length would silently misalign weights and pixels.
            raise ValueError(
                f"weights must have shape ({N_total},), "
                f"got {pixel_weights.shape}."
            )

    # DFT per ring: ft[..., R, mmax+1]
    ft = self.comp_tf_latlon(im, phi_list, ring_counts, pixel_weights, mmax=lmax)

    # cos(θ) per ring for the Legendre recurrence
    co_th = np.cos(ring_theta)  # [R]

    # Legendre projection, one block of coefficients per m.
    blocks = []
    for m in range(lmax + 1):
        # compute_legendre_m returns sqrt(4π) · P_lm^norm(cos θ), shape [L, R].
        # Normalised spherical harmonics are Y_lm = sqrt((2l+1)/4π) · P_lm^norm.
        # The missing factor is sqrt(2l+1)/(4π); applied here so that
        # the projection gives a_lm = ∫ f Y_lm* dΩ (healpy/standard convention).
        plm = self.compute_legendre_m(co_th, m, lmax, nside=1)        # [L, R]
        l_vals = np.arange(m, lmax + 1, dtype=np.float64)             # [L]
        ylm_factor = np.sqrt(2.0 * l_vals + 1.0) / (4.0 * np.pi)      # [L]
        plm = plm * ylm_factor[:, np.newaxis]                         # [L, R]
        plm_bk = self.backend.bk_cast(plm)                            # [L, R]

        ft_m = ft[..., :, m]  # [..., R]

        # alm[..., l-m] = sum_r plm[l-m, r] * ft[..., r, m]
        # ft_m [..., 1, R] × plm [L, R] → sum_R → [..., L]
        blocks.append(
            self.backend.bk_reduce_sum(
                self.backend.bk_expand_dims(ft_m, axis=-2) * plm_bk,
                axis=-1
            )
        )  # [..., L]

    # Concatenate once instead of growing the result inside the loop.
    if len(blocks) == 1:
        alm_out = blocks[0]
    else:
        alm_out = self.backend.bk_concat(blocks, axis=-1)

    if _added_batch:
        alm_out = alm_out[0]

    return alm_out
# ================================================================== # # alm → map (synthesis) # # ================================================================== #
def alm2map_latlon(self, alm, ring_theta, ring_phi_list, ring_counts, lmax=None):
    """
    Synthesis: reconstruct the map from alm coefficients.

    Parameters
    ----------
    alm           : [..., n_alm] harmonic coefficients, in the layout
                    produced by map2alm_latlon
                    ([m=0: l=0..lmax | m=1: l=1..lmax | …])
    ring_theta    : [R] colatitudes (radians)
    ring_phi_list : list[ndarray] or ndarray [N_total] longitudes
    ring_counts   : [R] number of pixels per ring
    lmax          : maximum multipole used when computing the alm
                    (default: self.lmax)

    Returns
    -------
    im_out : [..., N_total] reconstructed (real) map, ring by ring.
        A 1-D ``alm`` yields a 1-D map.

    Notes
    -----
    The synthesis mirrors the conventions of map2alm_latlon:

    * the Legendre functions are scaled by sqrt(2l+1)/(4π), the same
      Y_lm normalisation that map2alm_latlon applies;
    * only m ≥ 0 is stored, so for a real map the m > 0 terms are
      counted twice (the m < 0 terms are their complex conjugates):
      f = Re( Σ_l a_l0 Y_l0 + 2 Σ_{m>0} Σ_l a_lm Y_lm ).

    NOTE(review): the normalisation relies on compute_legendre_m
    returning sqrt(4π)·P_lm^norm, as stated in map2alm_latlon — confirm
    against the parent class.
    """
    if lmax is None:
        lmax = self.lmax
    lmax = int(lmax)

    ring_theta = np.asarray(ring_theta, dtype=np.float64)
    ring_counts = np.asarray(ring_counts, dtype=np.int64)
    phi_list = self._parse_phi_list(ring_phi_list, ring_counts)
    R = len(ring_theta)

    co_th = np.cos(ring_theta)  # [R]

    # Build the Fourier part per ring: ft[r, m] = sum_l alm[l,m] * Y_lm(θ_r)
    # then reconstruct the map via inverse DFT.
    _added_batch = False
    if hasattr(alm, 'ndim') and alm.ndim == 1:
        alm = alm[None, :]
        _added_batch = True

    # ft_synth[..., R, mmax+1]
    batch_shape = alm.shape[:-1]
    ft_synth = torch.zeros(
        batch_shape + (R, lmax + 1),
        dtype=torch.complex128,
        device=alm.device if hasattr(alm, 'device') else torch.device('cpu')
    )

    idx = 0
    for m in range(lmax + 1):
        L = lmax - m + 1
        alm_m = alm[..., idx:idx + L]  # [..., L]
        idx += L

        plm = self.compute_legendre_m(co_th, m, lmax, nside=1)        # [L, R]
        # Same Y_lm normalisation as in map2alm_latlon.
        l_vals = np.arange(m, lmax + 1, dtype=np.float64)             # [L]
        ylm_factor = np.sqrt(2.0 * l_vals + 1.0) / (4.0 * np.pi)      # [L]
        plm = plm * ylm_factor[:, np.newaxis]                         # [L, R]
        plm_bk = self.backend.bk_cast(plm)

        # ft[..., r, m] = sum_l alm_m[..., l-m] * plm[l-m, r]
        # alm_m [..., L, 1] × plm [L, R] → sum_L → [..., R]
        contrib = self.backend.bk_reduce_sum(
            self.backend.bk_expand_dims(alm_m, axis=-1) * plm_bk,
            axis=-2
        )  # [..., R]
        if m > 0:
            # Account for the (unstored) -m term of a real field.
            contrib = contrib * 2.0
        ft_synth[..., :, m] = contrib

    # Inverse DFT per ring: im[r, j] = Re( sum_m ft[r,m] * exp(i m phi_j) )
    m_vec = np.arange(lmax + 1, dtype=np.float64)
    out_per_ring = []
    for r in range(R):
        phi_r = np.asarray(phi_list[r], dtype=np.float64)

        # kernel[m, j] = exp(i m phi_j)   [mmax+1, N_r]
        ang = np.outer(m_vec, phi_r)
        ker = np.exp(1j * ang).astype(np.complex128)  # [M, N_r]
        ker_bk = self.backend.bk_cast(ker)

        ft_r = ft_synth[..., r, :]  # [..., M]
        # ft_r [..., M, 1] × ker [M, N_r] → sum_M → [..., N_r]
        pix = self.backend.bk_reduce_sum(
            self.backend.bk_expand_dims(ft_r, axis=-1) * ker_bk,
            axis=-2
        )  # [..., N_r]
        out_per_ring.append(self.backend.bk_real(pix))

    im_out = self.backend.bk_concat(out_per_ring, axis=-1)  # [..., N_total]

    if _added_batch:
        im_out = im_out[0]
    return im_out
# ================================================================== # # map → Cl # # ================================================================== #
def anafast_latlon(self, im, ring_theta, ring_phi_list, ring_counts,
                   lmax=None, weights=None, quadrature='trapeze'):
    """
    Estimate the power spectrum Cl of a map on an arbitrary grid.

    The map is first analysed with map2alm_latlon (all arguments are
    forwarded unchanged); the spectrum is then
    C_l = ( |a_l0|² + 2 Σ_{m=1..l} |a_lm|² ) / (2l + 1).

    Parameters
    ----------
    im, ring_theta, ring_phi_list, ring_counts, weights, quadrature
        See map2alm_latlon.
    lmax : maximum multipole (default: self.lmax)

    Returns
    -------
    cl : tensor [..., lmax+1]
        torch float64 tensor on the device of the alm (CPU when the alm
        object has no ``device`` attribute).
    """
    lmax = int(self.lmax if lmax is None else lmax)

    alm = self.map2alm_latlon(
        im, ring_theta, ring_phi_list, ring_counts,
        lmax=lmax, weights=weights, quadrature=quadrature
    )

    lead_shape = alm.shape[:-1] if alm.ndim > 1 else ()
    if hasattr(alm, 'device'):
        device = alm.device
    else:
        device = torch.device('cpu')

    cl = torch.zeros(lead_shape + (lmax + 1,), dtype=torch.float64, device=device)

    # Walk through the m-blocks of the alm layout
    # [m=0: l=0..lmax | m=1: l=1..lmax | …].
    start = 0
    for m in range(lmax + 1):
        n_l = lmax - m + 1
        coeffs = alm[..., start:start + n_l]
        start += n_l
        power = self.backend.bk_real(coeffs * self.backend.bk_conjugate(coeffs))
        # m > 0 also stands for the -m coefficient, hence the factor 2.
        multiplicity = 1.0 if m == 0 else 2.0
        cl[..., m:m + n_l] += multiplicity * power

    # Normalisation by (2l+1); broadcasts over the leading dimensions.
    ells = torch.arange(lmax + 1, dtype=torch.float64, device=device)
    two_l_plus_1 = (2.0 * ells + 1.0).reshape((1,) * len(lead_shape) + (lmax + 1,))
    return cl / two_l_plus_1
# ================================================================== # # Utilitaires internes # # ================================================================== # @staticmethod def _parse_phi_list(ring_phi_list, ring_counts): """ Accepte ring_phi_list comme : - liste de tableaux - flat array [N_total] (split according to ring_counts) Retourne toujours une liste de tableaux float64. """ ring_counts = np.asarray(ring_counts, dtype=np.int64) if isinstance(ring_phi_list, np.ndarray) and ring_phi_list.ndim == 1: splits = np.cumsum(ring_counts)[:-1] return [a.astype(np.float64) for a in np.split(ring_phi_list, splits)] return [np.asarray(p, dtype=np.float64) for p in ring_phi_list]
def grid_summary(self, ring_theta, ring_phi_list, ring_counts, lmax=None):
    """
    Print a summary of the grid and the estimated computation cost.

    Parameters
    ----------
    ring_theta    : [R] colatitudes (radians)
    ring_phi_list : list[ndarray] or ndarray [N_total] longitudes
                    (parsed with _parse_phi_list, like the transforms)
    ring_counts   : [R] number of pixels per ring
    lmax          : maximum multipole (default: self.lmax)

    Returns
    -------
    None.  The summary is written to standard output.

    Notes
    -----
    The cost figures are nominal counts, not timings:
    (lmax+1)·log2(max(2, N_r)) summed over uniform rings for the FFT
    path, and N_r·(lmax+1) summed over the other rings for the direct
    DFT path.
    """
    if lmax is None:
        lmax = self.lmax

    # Accept lists as well as arrays (``.min()``/``.max()`` are used below).
    ring_theta = np.asarray(ring_theta, dtype=np.float64)
    ring_counts = np.asarray(ring_counts, dtype=np.int64)
    phi_list = self._parse_phi_list(ring_phi_list, ring_counts)
    N_total = int(ring_counts.sum())
    R = len(ring_theta)

    # Classify every ring once instead of re-testing it for each figure.
    is_uniform = [
        self._check_uniform(np.asarray(phi_list[r]))[0] for r in range(R)
    ]
    n_unif = sum(1 for flag in is_uniform if flag)
    cost_fft = sum(
        (lmax + 1) * np.log2(max(2, int(ring_counts[r])))
        for r in range(R) if is_uniform[r]
    )
    cost_dft = sum(
        int(ring_counts[r]) * (lmax + 1)
        for r in range(R) if not is_uniform[r]
    )

    print(f"=== Grid summary (lmax={lmax}) ===")
    print(f" Total pixels : {N_total}")
    print(f" Rings : {R}")
    print(f" Uniform rings : {n_unif}/{R} (FFT acceleration)")
    print(f" θ range : [{np.degrees(ring_theta.min()):.2f}°, "
          f"{np.degrees(ring_theta.max()):.2f}°]")
    print(f" N_pix/ring : min={ring_counts.min()}, "
          f"max={ring_counts.max()}, mean={ring_counts.mean():.1f}")
    print(f" n_alm : {sum(lmax - m + 1 for m in range(lmax + 1))}")
    print(f" Cost estimate : FFT rings ~ {cost_fft:.3g}, "
          f"DFT rings ~ {cost_dft:.3g}")