# SPDX-License-Identifier: MIT
# Author: J.-M. Delouis
import numpy as np
import healpy as hp
import torch
[docs]
class SphericalStencil:
"""
GPU-accelerated spherical stencil operator for HEALPix convolutions.
This class implements three phases:
A) Geometry preparation: build local rotated stencil vectors for each target
pixel, compute HEALPix neighbor indices and interpolation weights.
B) Sparse binding: map neighbor indices/weights to available data samples
(sorted ids), and normalize weights.
C) Convolution: apply multi-channel kernels to sparse gathered data.
Once A+B are prepared, multiple convolutions (C) can be applied efficiently
on the GPU.
Parameters
----------
nside : int
HEALPix resolution parameter.
kernel_sz : int
Size of local stencil (must be odd, e.g. 3, 5, 7).
gauge_type : str
Type of gauge :
'cosmo' use the same definition than
https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
'phi' is define at the pole, could be better for earth observation not using intensivly the pole
n_gauge : float
Number of oriented gauges (Default 1).
blend : bool
Whether to blend smoothly between axisA and axisB (dual gauge).
power : float
Sharpness of blend transition (dual gauge).
nest : bool
Use nested ordering if True (default), else ring ordering.
cell_ids : np.ndarray | torch.Tensor | None
If given, initialize Step A immediately for these targets.
device : torch.device | str | None
Default device (if None, 'cuda' if available else 'cpu').
dtype : torch.dtype | None
Default dtype (float32 if None).
"""
def __init__(
self,
nside: int,
kernel_sz: int,
*,
nest: bool = True,
cell_ids=None,
device=None,
dtype=None,
n_gauges=1,
gauge_type='phi',
):
assert kernel_sz >= 1 and int(kernel_sz) == kernel_sz
assert kernel_sz % 2 == 1, "kernel_sz must be odd"
self.nside = int(nside)
self.KERNELSZ = int(kernel_sz)
self.P = self.KERNELSZ * self.KERNELSZ
self.G = n_gauges
self.gauge_type=gauge_type
self.nest = bool(nest)
# Torch defaults
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if dtype is None:
dtype = torch.float32
self.device = torch.device(device)
self.dtype = dtype
# Geometry cache
self.Kb = None
self.idx_t = None # (4, K*P) neighbor indices
self.w_t = None # (4, K*P) interpolation weights
self.ids_sorted_np = None
self.pos_safe_t = None
self.w_norm_t = None
self.present_t = None
# Optional: keep a copy of the default ids if provided
self.cell_ids_default = None
# ---- Optional immediate preparation (Step A+B at init) ----
if cell_ids is not None:
# Keep a copy of the default target grid (fast-path later)
cid = np.asarray(cell_ids, dtype=np.int64).reshape(-1)
self.cell_ids_default = cid.copy()
# Step A (Torch): build geometry for this grid with G gauges
th, ph = hp.pix2ang(self.nside, cid, nest=self.nest)
self.prepare_torch(th, ph, G=self.G) # fills idx_t/_multi and w_t/_multi
# Step B (Torch): bind sparse mapping on the class device/dtype
order = np.argsort(cid)
self.ids_sorted_np = cid[order] # cache for fast-path
if self.G > 1:
# Multi-gauge binding (produces pos_safe_t_multi, w_norm_t_multi)
self.bind_support_torch_multi(
self.ids_sorted_np,
device=self.device,
dtype=self.dtype,
)
else:
# Single-gauge binding (produces pos_safe_t, w_norm_t)
self.bind_support_torch(
self.ids_sorted_np,
device=self.device,
dtype=self.dtype,
)
# ------------------------------------------------------------------
# Rotation construction in Torch
# ------------------------------------------------------------------
@staticmethod
def _rotation_total_torch(th, ph, alpha=None, G: int = 1, gauge_cosmo=True,device=None, dtype=None):
"""
Build a batch of rotation matrices with *G gauges* per target.
Column-vector convention: v' = R @ v.
Parameters
----------
th : array-like (N,)
Colatitude.
ph : array-like (N,)
Longitude.
alpha : array-like (N,) or scalar or None
Base gauge rotation angle around the local normal.
If None -> 0. For each gauge g in [0..G-1], we add g*pi/G.
G : int
Number of gauges to generate per target (>=1).
device, dtype : torch device/dtype
Returns
-------
R_tot : torch.Tensor, shape (N, G, 3, 3)
For each target i and gauge g, the matrix:
R_tot[i,g] = R_gauge(alpha[i] + g*pi/G) @ Rz(ph[i]) @ Ry(th[i])
"""
assert G >= 1, "G must be >= 1"
# ---- to torch 1D
th = torch.as_tensor(th, device=device, dtype=dtype).view(-1)
ph = torch.as_tensor(ph, device=device, dtype=dtype).view(-1)
if alpha is None:
alpha = torch.zeros_like(th)
else:
alpha = torch.as_tensor(alpha, device=device, dtype=dtype).view(-1)
device = th.device
dtype = th.dtype
N = th.shape[0]
# ---- base rotation R_base = Rz(ph) @ Ry(th), shape (N,3,3)
ct, st = torch.cos(th), torch.sin(th)
cp, sp = torch.cos(ph), torch.sin(ph)
R_base = torch.zeros((N, 3, 3), device=device, dtype=dtype)
# row 0
R_base[:, 0, 0] = cp * ct
R_base[:, 0, 1] = -sp
R_base[:, 0, 2] = cp * st
# row 1
R_base[:, 1, 0] = sp * ct
R_base[:, 1, 1] = cp
R_base[:, 1, 2] = sp * st
# row 2
R_base[:, 2, 0] = -st
R_base[:, 2, 1] = 0.0
R_base[:, 2, 2] = ct
# local normal n = third column of R_base, shape (N,3)
n = R_base[:, :, 2]
n = n / torch.linalg.norm(n, dim=1, keepdim=True).clamp_min(1e-12) # safe normalize
# per-target sign: +1 if th <= pi/2 else -1
sign = torch.where(th <= (np.pi/2), torch.ones_like(th), -torch.ones_like(th)) # (N,)
# base gauge shifts (always positive)
g_shifts = torch.arange(G, device=device, dtype=dtype) * (np.pi / G) # (G,)
# broadcast with sign: (N,G)
if gauge_cosmo:
alpha_g = alpha[:, None] + sign[:, None] * g_shifts[None, :]
else:
alpha_g = alpha[:, None] + g_shifts[None, :]
ca = torch.cos(alpha_g) # (N,G)
sa = torch.sin(alpha_g) # (N,G)
# ---- expand normal to (N,G,3)
n_g = n[:, None, :].expand(N, G, 3) # (N,G,3)
nx, ny, nz = n_g[..., 0], n_g[..., 1], n_g[..., 2]
# skew-symmetric K(n_g), shape (N,G,3,3)
K = torch.zeros((N, G, 3, 3), device=device, dtype=dtype)
K[..., 0, 1] = -nz; K[..., 0, 2] = ny
K[..., 1, 0] = nz; K[..., 1, 2] = -nx
K[..., 2, 0] = -ny; K[..., 2, 1] = nx
# outer(n,n) and identity
outer = n_g.unsqueeze(-1) * n_g.unsqueeze(-2) # (N,G,3,3)
I = torch.eye(3, device=device, dtype=dtype).view(1,1,3,3).expand(N, G, 3, 3)
# ---- Rodrigues per gauge: R_gauge(N,G,3,3)
R_gauge = I * ca.view(N, G, 1, 1) + K * sa.view(N, G, 1, 1) + \
outer * (1.0 - ca).view(N, G, 1, 1)
# ---- broadcast multiply with base: R_base_g(N,G,3,3)
R_base_g = R_base.unsqueeze(1).expand(N, G, 3, 3)
R_tot = torch.matmul(R_gauge, R_base_g) # (N,G,3,3)
return R_tot
# ------------------------------------------------------------------
# Torch-based get_interp_weights wrapper
# ------------------------------------------------------------------
[docs]
@staticmethod
def get_interp_weights_from_vec_torch(
nside: int,
vec,
*,
nest: bool = True,
device=None,
dtype=None,
chunk_size=1_000_000,
):
"""
Torch wrapper for healpy.get_interp_weights using input vectors.
Parameters
----------
nside : int
HEALPix resolution.
vec : torch.Tensor (...,3)
Direction vectors (not necessarily normalized).
nest : bool
Nested ordering if True (default).
device, dtype : Torch device/dtype.
chunk_size : int
Number of points per healpy call on CPU.
Returns
-------
idx_t : LongTensor (4, *leading)
w_t : Tensor (4, *leading)
"""
if not isinstance(vec, torch.Tensor):
vec = torch.as_tensor(vec, device=device, dtype=dtype)
else:
device = vec.device if device is None else device
dtype = vec.dtype if dtype is None else dtype
vec = vec.to(device=device, dtype=dtype)
orig_shape = vec.shape[:-1]
M = int(np.prod(orig_shape)) if len(orig_shape) else 1
v = vec.reshape(M, 3)
eps = torch.finfo(vec.dtype).eps
r = torch.linalg.norm(v, dim=1, keepdim=True).clamp_min(eps)
v_unit = v / r
x, y, z = v_unit[:, 0], v_unit[:, 1], v_unit[:, 2]
theta = torch.acos(z.clamp(-1.0, 1.0))
phi = torch.atan2(y, x)
two_pi = torch.tensor(2*np.pi, device=device, dtype=dtype)
phi = (phi % two_pi)
theta_np = theta.detach().cpu().numpy()
phi_np = phi.detach().cpu().numpy()
idx_accum, w_accum = [], []
for start in range(0, M, chunk_size):
stop = min(start + chunk_size, M)
t_chunk, p_chunk = theta_np[start:stop], phi_np[start:stop]
idx_np, w_np = hp.get_interp_weights(nside, t_chunk, p_chunk, nest=nest)
idx_accum.append(idx_np)
w_accum.append(w_np)
idx_np_all = np.concatenate(idx_accum, axis=1) if len(idx_accum) > 1 else idx_accum[0]
w_np_all = np.concatenate(w_accum, axis=1) if len(w_accum) > 1 else w_accum[0]
idx_t = torch.as_tensor(idx_np_all, device=device, dtype=torch.long)
w_t = torch.as_tensor(w_np_all, device=device, dtype=dtype)
if len(orig_shape):
idx_t = idx_t.view(4, *orig_shape)
w_t = w_t.view(4, *orig_shape)
return idx_t, w_t
# ------------------------------------------------------------------
# Step A: geometry preparation fully in Torch
# ------------------------------------------------------------------
[docs]
def prepare_torch(self, th, ph, alpha=None, G: int = 1):
"""
Prepare rotated stencil and HEALPix neighbors/weights in Torch for *G gauges*.
Parameters
----------
th, ph : array-like, shape (K,)
Target colatitudes/longitudes.
alpha : array-like (K,) or scalar or None
Base gauge angle about the local normal at each target. If None -> 0.
For each gauge g in [0..G-1], the effective angle is alpha + g*pi/G.
G : int (>=1)
Number of gauges to generate per target.
Side effects
------------
Sets:
- self.Kb = K
- self.G = G
- self.idx_t_multi : (G, 4, K*P) LongTensor (neighbors per gauge)
- self.w_t_multi : (G, 4, K*P) Tensor (weights per gauge)
- For backward compat when G==1:
self.idx_t : (4, K*P)
self.w_t : (4, K*P)
Returns
-------
idx_t_multi : torch.LongTensor, shape (G, 4, K*P)
w_t_multi : torch.Tensor, shape (G, 4, K*P)
"""
# --- sanitize inputs on CPU (angles) then use class device/dtype
th = np.asarray(th, float).reshape(-1)
ph = np.asarray(ph, float).reshape(-1)
K = th.size
self.Kb = K
self.G = int(G)
assert self.G >= 1, "G must be >= 1"
# --- build the local (P,3) stencil once on device
P = self.P
vec_np = np.zeros((P, 3), dtype=float)
grid = (np.arange(self.KERNELSZ) - self.KERNELSZ // 2)
# NEW: angular offsets
xx,yy=np.meshgrid(grid,grid)
s=1.0 # could be modified
alpha_pix = hp.nside2resol(self.nside, arcmin=False) # ~ taille angulaire typique
dtheta = (np.sqrt(xx**2+yy**2) * alpha_pix * s).ravel()
dphi = (np.arctan2(yy,xx)).ravel()
# local spherical displacement
# convert to unit vectors
x = np.sin(dtheta) * np.cos(dphi)
y = np.sin(dtheta) * np.sin(dphi)
z = np.cos(dtheta)
#print(self.nside*x.reshape(self.KERNELSZ,self.KERNELSZ))
#print(self.nside*y.reshape(self.KERNELSZ,self.KERNELSZ))
#print(self.nside*z.reshape(self.KERNELSZ,self.KERNELSZ))
vec_np = np.stack([x, y, z], axis=-1)
#vec_np[:, 0] = np.tile(grid, self.KERNELSZ)
#vec_np[:, 1] = np.repeat(grid, self.KERNELSZ)
#vec_np[:, 2] = 1.0 - np.sqrt(vec_np[:, 0]**2 + vec_np[:, 1]**2)
vec_t = torch.as_tensor(vec_np, device=self.device, dtype=self.dtype) # (P,3)
# --- rotation matrices for all targets & gauges: (K,G,3,3)
if alpha is None:
if self.gauge_type=='cosmo':
alpha=2*((th>np.pi/2)-0.5)*ph
else:
alpha=0.0*th
R_t = self._rotation_total_torch(
th, ph, alpha, G=self.G, gauge_cosmo=(self.gauge_type=='cosmo'),
device=self.device, dtype=self.dtype
) # shape (K,G,3,3)
# --- rotate stencil for each (target, gauge): (K,G,P,3)
# einsum over local stencil (P,3) with rotation (K,G,3,3)
rotated = torch.einsum('kgij,pj->kgpi', R_t, vec_t) # (K,G,P,3)
# --- query HEALPix (neighbors+weights) in one call over (K*G*P)
rotated_flat = rotated.reshape(-1, 3) # (K*G*P, 3)
idx_t, w_t = self.get_interp_weights_from_vec_torch(
self.nside,
rotated_flat,
nest=self.nest,
device=self.device,
dtype=self.dtype,
) # each (4, K*G*P)
# --- reshape back to split gauges:
# current: (4, K*G*P) -> (4, K, G, P) -> (G, 4, K, P) -> (G, 4, K*P)
idx_t = idx_t.view(4, K, self.G, P).permute(2, 0, 1, 3).reshape(self.G, 4, K*P)
w_t = w_t.view(4, K, self.G, P).permute(2, 0, 1, 3).reshape(self.G, 4, K*P)
# --- cache multi-gauge versions
self.idx_t_multi = idx_t # (G, 4, K*P)
self.w_t_multi = w_t # (G, 4, K*P)
# --- backward compatibility: when G==1, also fill single-gauge fields
if self.G == 1:
self.idx_t = idx_t[0] # (4, K*P)
self.w_t = w_t[0] # (4, K*P)
else:
# when multi-gauge, you can pick a default (e.g., gauge 0) if legacy code asks
# but better to adapt bind/apply to consume the multi-gauge tensors.
self.idx_t = None
self.w_t = None
return self.idx_t_multi, self.w_t_multi
[docs]
def bind_support_torch_multi(self, ids_sorted_np, *, device=None, dtype=None):
"""
Multi-gauge sparse binding (Step B) WITH 'reduced domain' logic:
- weights of out-of-domain neighbours set to 0
- column renormalisation to 1
- si colonne vide: fallback sur le pixel cible (centre du stencil)
Produit:
self.pos_safe_t_multi : (G, 4, K*P)
self.w_norm_t_multi : (G, 4, K*P)
self.present_t_multi : (G, 4, K*P)
"""
assert hasattr(self, 'idx_t_multi') and self.idx_t_multi is not None, \
"Call prepare_torch(..., G>0) before bind_support_torch_multi(...)"
assert hasattr(self, 'w_t_multi') and self.w_t_multi is not None
if device is None: device = self.device
if dtype is None: dtype = self.dtype
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64).reshape(-1)
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
G, _, M = self.idx_t_multi.shape
K = self.Kb
P = self.P
assert M == K*P, "idx_t_multi second axis must have K*P columns"
# index du centre du stencil (en flatten P)
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1) # ex. 5 -> 12
pos_list, present_list, wnorm_list = [], [], []
for g in range(G):
idx = self.idx_t_multi[g].to(device=device, dtype=torch.long) # (4, M)
w = self.w_t_multi[g].to(device=device, dtype=dtype) # (4, M)
# positions dans ids_sorted
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
in_range = pos < ids_sorted.numel()
cmp_vals = torch.full_like(idx, -1)
cmp_vals[in_range] = ids_sorted[pos[in_range]]
present = (cmp_vals == idx) # (4, M) bool
# Columns with NO neighbour present
empty_cols = ~present.any(dim=0) # (M,)
if empty_cols.any():
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
ref_cols = k_id * P + p_ref
src = ref_cols[empty_cols]
# copie idx/w de la colonne 'centre'
idx[:, empty_cols] = idx[:, src]
w[:, empty_cols] = w[:, src]
# --- Recompute presence/pos safely on those columns
idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
valid_e = pos_e < ids_sorted.numel()
pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
cmp_e = ids_sorted[pos_e_clipped]
present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
present[:, empty_cols] = present_e.view(4, -1)
pos[:, empty_cols] = pos_e_clipped.view(4, -1)
# Zero out absent weights then renormalise to 1 per column
w = w * present
colsum = w.sum(dim=0, keepdim=True)
zero_cols = (colsum == 0)
if zero_cols.any():
w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
colsum = w.sum(dim=0, keepdim=True)
w_norm = w / colsum.clamp_min(1e-12)
pos_safe = torch.where(present, pos, torch.zeros_like(pos))
pos_list.append(pos_safe)
present_list.append(present)
wnorm_list.append(w_norm)
self.pos_safe_t_multi = torch.stack(pos_list, dim=0) # (G, 4, M)
self.present_t_multi = torch.stack(present_list, dim=0) # (G, 4, M)
self.w_norm_t_multi = torch.stack(wnorm_list, dim=0) # (G, 4, M)
# miroir device/dtype runtime
self.device = device
self.dtype = dtype
[docs]
def bind_support_torch(self, ids_sorted_np, *, device=None, dtype=None):
"""
Single-gauge sparse binding (Step B) WITH 'reduced domain' logic:
- weights of out-of-domain neighbours set to 0
- column renormalisation to 1
- si colonne vide: fallback sur le pixel cible (centre du stencil)
"""
if device is None:
device = self.device
if dtype is None:
dtype = self.dtype
self.ids_sorted_np = np.asarray(ids_sorted_np, dtype=np.int64)
ids_sorted = torch.as_tensor(self.ids_sorted_np, device=device, dtype=torch.long)
idx = self.idx_t.to(device=device, dtype=torch.long) # (4, K*P)
w = self.w_t.to(device=device, dtype=dtype) # (4, K*P)
K = self.Kb
P = self.P
M = K * P
# positions dans ids_sorted
pos = torch.searchsorted(ids_sorted, idx.reshape(-1)).view(4, M)
in_range = pos < ids_sorted.shape[0]
cmp_vals = torch.full_like(idx, -1)
cmp_vals[in_range] = ids_sorted[pos[in_range]]
present = (cmp_vals == idx) # (4, M)
# Fallback colonnes vides -> centre du stencil
p_ref = (self.KERNELSZ // 2) * (self.KERNELSZ + 1)
empty_cols = ~present.any(dim=0) # (M,)
if empty_cols.any():
k_id = torch.div(torch.arange(M, device=device), P, rounding_mode='floor') # (M,)
ref_cols = k_id * P + p_ref
src = ref_cols[empty_cols]
# copie idx/w de la colonne 'centre'
idx[:, empty_cols] = idx[:, src]
w[:, empty_cols] = w[:, src]
# --- Recompute presence/pos safely on those columns
idx_e = idx[:, empty_cols].reshape(-1) # (4*M_empty,)
pos_e = torch.searchsorted(ids_sorted, idx_e) # (4*M_empty,)
# valid positions strictly inside [0, len)
valid_e = pos_e < ids_sorted.numel()
pos_e_clipped = pos_e.clamp_max(max(ids_sorted.numel()-1, 0)).to(torch.long)
cmp_e = ids_sorted[pos_e_clipped]
present_e = valid_e & (cmp_e == idx_e) # (4*M_empty,)
# reshape back
present[:, empty_cols] = present_e.view(4, -1)
pos[:, empty_cols] = pos_e_clipped.view(4, -1)
# Zero absent weights + renormalise to 1
w = w * present
colsum = w.sum(dim=0, keepdim=True)
zero_cols = (colsum == 0)
if zero_cols.any():
# force 1 on the first available row (here row 0)
w[0, zero_cols[0]] = present[0, zero_cols[0]].to(w.dtype)
colsum = w.sum(dim=0, keepdim=True)
w_norm = w / colsum.clamp_min(1e-12)
self.pos_safe_t = torch.where(present, pos, torch.zeros_like(pos))
self.w_norm_t = w_norm
self.present_t = present
self.device = device
self.dtype = dtype
# ------------------------------------------------------------------
# Step C: apply convolution (already Torch in your code)
# ------------------------------------------------------------------
[docs]
def apply_multi(self, data_sorted_t: torch.Tensor, kernel_t: torch.Tensor):
"""
Apply multi-gauge convolution.
Inputs
------
data_sorted_t : (B, Ci, K) torch.Tensor on self.device/self.dtype
kernel_t : either
- (Ci, Co_g, P) : shared kernel for all gauges
- (G, Ci, Co_g, P) : per-gauge kernels
Returns
-------
out : (B, G*Co_g, K) torch.Tensor
"""
assert hasattr(self, 'pos_safe_t_multi') and self.pos_safe_t_multi is not None, \
"Call bind_support_torch_multi(...) before apply_multi(...)"
B, Ci, K = data_sorted_t.shape
G, _, M = self.pos_safe_t_multi.shape
assert M == K * self.P
# normalize kernel to per-gauge
if kernel_t.dim() == 3:
Ci_k, Co_g, P = kernel_t.shape
assert Ci_k == Ci and P == self.P
kernel_g = kernel_t[None, ...].expand(G, -1, -1, -1) # (G, Ci, Co_g, P)
elif kernel_t.dim() == 4:
Gk, Ci_k, Co_g, P = kernel_t.shape
assert Gk == G and Ci_k == Ci and P == self.P
kernel_g = kernel_t
else:
raise ValueError("kernel_t must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")
outs = []
for g in range(G):
pos_safe = self.pos_safe_t_multi[g] # (4, K*P)
w_norm = self.w_norm_t_multi[g] # (4, K*P)
# gather four neighbors then weight -> (B,Ci,K,P)
vals_g = []
for j in range(4):
vj = data_sorted_t.index_select(2, pos_safe[j].reshape(-1)) # (B,Ci,K*P)
vj = vj.view(B, Ci, K, self.P)
vals_g.append(vj * w_norm[j].view(1, 1, K, self.P))
tmp = sum(vals_g) # (B,Ci,K,P)
# spatial+channel mixing with kernel of this gauge -> (B,Co_g,K)
yg = torch.einsum('bckp,cop->bok', tmp, kernel_g[g])
outs.append(yg)
# concat the gauges along channel dimension: (B, G*Co_g, K)
return torch.cat(outs, dim=1)
[docs]
def apply(self, data_sorted_t, kernel_t):
"""
Apply the (Ci,Co,P) kernel to batched sparse data (B,Ci,K)
using precomputed pos_safe and w_norm. Runs fully on GPU.
Parameters
----------
data_sorted_t : torch.Tensor (B,Ci,K)
Input data aligned with ids_sorted.
kernel_t : torch.Tensor (Ci,Co,P)
Convolution kernel.
Returns
-------
out : torch.Tensor (B,Co,K)
"""
assert self.pos_safe_t is not None and self.w_norm_t is not None
B, Ci, K = data_sorted_t.shape
Ci_k, Co, P = kernel_t.shape
assert Ci_k == Ci and P == self.P
vals = []
for j in range(4):
vj = data_sorted_t.index_select(2, self.pos_safe_t[j].reshape(-1))
vj = vj.view(B, Ci, K, P)
vals.append(vj * self.w_norm_t[j].view(1, 1, K, P))
tmp = sum(vals) # (B,Ci,K,P)
out = torch.einsum('bckp,cop->bok', tmp, kernel_t)
return out
def _Convol_Torch(self, data: torch.Tensor, kernel: torch.Tensor, cell_ids=None) -> torch.Tensor:
"""
Convenience entry point with automatic single- or multi-gauge dispatch.
Behavior
--------
- If `cell_ids is None`: use cached geometry (prepare_torch) and sparse mapping
(bind_support_torch or bind_support_torch_multi) already stored in the class,
re-binding Step-B to `data`'s device/dtype when needed, then apply.
- If `cell_ids` is provided: compute geometry + sparse mapping for these cells
using the class' gauge setup (including the number of gauges G prepared by
`prepare_torch(..., G)`), reorder `data` to match the sorted ids, apply
(single or multi), and finally unsort to the original `cell_ids` order.
Parameters
----------
data : (B, Ci, K) torch.float
Sparse map values. Last axis K must equal the number of target pixels.
kernel : torch.Tensor
- Single-gauge path: (Ci, Co, P) where P = kernel_sz**2.
- Multi-gauge path: (Ci, Co_g, P) shared kernel for all gauges, OR
(G, Ci, Co_g, P) per-gauge kernels.
The output channels will be Co (single) or G*Co_g (multi).
cell_ids : Optional[np.ndarray | torch.Tensor], shape (K,)
Target HEALPix pixels. If None, re-use the class' cached targets.
Returns
-------
out : torch.Tensor, shape (B, Co, K)
Co = Co (single gauge) or Co = G*Co_g (multi-gauge).
"""
assert isinstance(data, torch.Tensor) and isinstance(kernel, torch.Tensor), \
"data and kernel must be torch.Tensors"
device = data.device
dtype = data.dtype
B, Ci, K_data = data.shape
P = self.P
P_k = kernel.shape[-1]
assert P_k == P, f"kernel P={P_k} must equal kernel_sz**2 = {P}"
def _to_np_1d(ids):
if isinstance(ids, torch.Tensor):
return ids.detach().cpu().numpy().astype(np.int64, copy=False)
return np.asarray(ids, dtype=np.int64).reshape(-1)
def _has_multi_bind():
return (getattr(self, 'G', 1) > 1 and
getattr(self, 'pos_safe_t_multi', None) is not None and
getattr(self, 'w_norm_t_multi', None) is not None)
# ----------------------------
# Case 1: new target ids given
# ----------------------------
if cell_ids is not None:
cell_ids_np = _to_np_1d(cell_ids)
# A) geometry with class' G (defaults to 1 if not set)
G = getattr(self, 'G', 1)
th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
self.prepare_torch(th, ph, alpha=None, G=G) # fills idx_t/_multi, w_t/_multi
# B) sort ids and reorder data accordingly
order = np.argsort(cell_ids_np)
ids_sorted_np = cell_ids_np[order]
assert K_data == ids_sorted_np.size, \
"data last dimension must equal number of provided cell_ids"
order_t = torch.as_tensor(order, device=device, dtype=torch.long)
data_sorted_t = data[..., order_t] # (B, Ci, K) aligned with ids_sorted_np
# C) bind sparse support
if G > 1:
self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=dtype)
out_sorted = self.apply_multi(data_sorted_t, kernel) # (B, G*Co_g, K)
else:
self.bind_support_torch(ids_sorted_np, device=device, dtype=dtype)
out_sorted = self.apply(data_sorted_t, kernel) # (B, Co, K)
# D) unsort back to original order
inv_order = np.empty_like(order)
inv_order[order] = np.arange(order.size)
inv_idx = torch.as_tensor(inv_order, device=device, dtype=torch.long)
return out_sorted[..., inv_idx]
# -----------------------------------------------
# Case 2: fast path on cached geometry + mapping
# -----------------------------------------------
if self.ids_sorted_np is None:
if getattr(self, 'cell_ids_default', None) is not None:
self.ids_sorted_np = np.sort(self.cell_ids_default)
else:
raise AssertionError(
"No cached targets. Either pass `cell_ids` once or initialize the class with `cell_ids=`."
)
if _has_multi_bind():
# rebind if device/dtype changed
if (self.device != device) or (self.dtype != dtype):
self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=dtype)
return self.apply_multi(data, kernel)
# single-gauge cached path
need_rebind = (
getattr(self, 'pos_safe_t', None) is None or
getattr(self, 'w_norm_t', None) is None or
self.device != device or
self.dtype != dtype
)
if need_rebind:
self.bind_support_torch(self.ids_sorted_np, device=device, dtype=dtype)
return self.apply(data, kernel)
[docs]
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
"""
Batched KERNELSZ x KERNELSZ aggregation (dispatcher).
Supports:
- im: Tensor (B, Ci, K) with
* cell_ids is None -> use cached targets (fast path)
* cell_ids is 1D (K,) -> one shared grid for whole batch
* cell_ids is 2D (B, K) -> per-sample grids, same length; returns (B, Co, K)
* cell_ids is list/tuple -> per-sample grids (var-length allowed)
- im: list/tuple of Tensors, each (Ci, K_b) with cell_ids list/tuple
Notes
-----
- Kernel shapes accepted:
* single/multi shared: (Ci, Co_g, P)
* per-gauge kernels: (G, Ci, Co_g, P)
The low-level _Convol_Torch will choose between apply/apply_multi
depending on the class state (G>1 and multi-bind present).
"""
import numpy as np
import torch
def _dev_dtype_like(x: torch.Tensor):
if not isinstance(x, torch.Tensor):
raise TypeError("Expected a torch.Tensor for device/dtype inference.")
return x.device, x.dtype
def _prepare_kernel(k: torch.Tensor, device, dtype):
if not isinstance(k, torch.Tensor):
raise TypeError("kernel (ww) must be a torch.Tensor")
return k.to(device=device, dtype=dtype)
def _to_np_ids(ids):
if isinstance(ids, torch.Tensor):
return ids.detach().cpu().numpy().astype(np.int64, copy=False)
return np.asarray(ids, dtype=np.int64)
class _NsideContext:
def __init__(self, obj, nside_new):
self.obj = obj
self.nside_old = obj.nside
self.nside_new = int(nside_new) if nside_new is not None else obj.nside
def __enter__(self):
self.obj.nside = self.nside_new
return self
def __exit__(self, exc_type, exc, tb):
self.obj.nside = self.nside_old
# ---------------- main dispatcher ----------------
if isinstance(im, torch.Tensor):
device, dtype = _dev_dtype_like(im)
kernel = _prepare_kernel(ww, device, dtype)
with _NsideContext(self, nside):
# (A) Fast path: no ids provided -> delegate fully to _Convol_Torch
if cell_ids is None:
return self._Convol_Torch(im, kernel, cell_ids=None)
# Normalise numpy/tensor ragged inputs
if isinstance(cell_ids, np.ndarray) and cell_ids.dtype == object:
cell_ids = list(cell_ids)
# (B) One shared grid for entire batch: 1-D ids
if isinstance(cell_ids, (np.ndarray, torch.Tensor)) and getattr(cell_ids, "ndim", 1) == 1:
return self._Convol_Torch(im, kernel, cell_ids=_to_np_ids(cell_ids))
# (C) Per-sample grids, same length: 2-D ids (B, K)
if isinstance(cell_ids, (np.ndarray, torch.Tensor)) and getattr(cell_ids, "ndim", 0) == 2:
B = im.shape[0]
if isinstance(cell_ids, torch.Tensor):
assert cell_ids.shape[0] == B, "cell_ids first dim must match batch size B"
ids2d = cell_ids.detach().cpu().numpy().astype(np.int64, copy=False)
else:
ids2d = np.asarray(cell_ids, dtype=np.int64)
assert ids2d.shape[0] == B, "cell_ids first dim must match batch size B"
outs = []
for b in range(B):
x_b = im[b:b+1] # (1, Ci, K_b)
ids_b = ids2d[b] # (K_b,)
y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b) # (1, Co, K_b)
outs.append(y_b)
return torch.cat(outs, dim=0) # (B, Co, K)
# (D) Per-sample grids, variable length: list/tuple
if isinstance(cell_ids, (list, tuple)):
B = im.shape[0]
assert len(cell_ids) == B, "cell_ids list length must match batch size B"
outs = []
lengths = []
for b in range(B):
ids_b_np = _to_np_ids(cell_ids[b])
lengths.append(ids_b_np.size)
x_b = im[b:b+1] # (1, Ci, K_b)
y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b_np) # (1, Co, K_b)
outs.append(y_b)
if len(set(lengths)) == 1:
return torch.cat(outs, dim=0) # (B, Co, K)
else:
return [y.squeeze(0) for y in outs] # list[(Co, K_b)]
raise TypeError("Unsupported type for cell_ids with tensor input.")
# Case: im is list/tuple of (Ci, K_b) tensors (var-length samples)
if isinstance(im, (list, tuple)):
assert isinstance(cell_ids, (list, tuple)) and len(cell_ids) == len(im), \
"When im is a list, cell_ids must be a list of same length."
assert len(im) > 0, "Empty list for `im`."
device, dtype = _dev_dtype_like(im[0])
kernel = _prepare_kernel(ww, device, dtype)
outs = []
with _NsideContext(self, nside):
lengths = []
tmp = []
for x_b, ids_b in zip(im, cell_ids):
assert isinstance(x_b, torch.Tensor), "Each sample in `im` must be a torch.Tensor"
assert x_b.device == device and x_b.dtype == dtype, "All samples must share device/dtype."
x_b = x_b.unsqueeze(0) # (1, Ci, K_b)
ids_b = _to_np_ids(ids_b)
y_b = self._Convol_Torch(x_b, kernel, cell_ids=ids_b) # (1, Co, K_b)
tmp.append(y_b)
lengths.append(y_b.shape[-1])
if len(set(lengths)) == 1:
return torch.cat(tmp, dim=0) # (B, Co, K)
else:
return [y.squeeze(0) for y in tmp]
raise TypeError("`im` must be either a torch.Tensor (B,Ci,K) or a list of (Ci,K_b) tensors.")
[docs]
def make_matrix(
self,
kernel: torch.Tensor,
cell_ids=None,
*,
return_sparse_tensor: bool = False,
chunk_k: int = 4096,
):
"""
Build the sparse COO matrix M such that applying M to vec(data) reproduces
the spherical convolution performed by Convol_torch/_Convol_Torch.
Supports single- and multi-gauge:
- kernel shape (Ci, Co_g, P) -> shared across G gauges, output Co = G*Co_g
- kernel shape (G, Ci, Co_g, P) -> per-gauge kernels, same output Co = G*Co_g
Parameters
----------
kernel : torch.Tensor
(Ci, Co_g, P) or (G, Ci, Co_g, P) with P = kernel_sz**2.
Must be on the device/dtype where you want the resulting matrix.
cell_ids : array-like of shape (K,) or torch.Tensor, optional
Target pixel IDs (NESTED if self.nest=True).
If None, uses the grid already cached in the class (fast path).
If provided, we prepare geometry & sparse binding for these ids.
return_sparse_tensor : bool, default False
If True, return a coalesced torch.sparse_coo_tensor of shape (Co*K, Ci*K).
Else, return (weights, indices, shape) where:
- indices is a LongTensor of shape (2, nnz) with [row; col]
- weights is a Tensor of shape (nnz,)
- shape is the (rows, cols) tuple
chunk_k : int, default 4096
Chunk size over target pixels to limit peak memory.
Returns
-------
If return_sparse_tensor:
M : torch.sparse_coo_tensor of shape (Co*K, Ci*K), coalesced
else:
weights : torch.Tensor (nnz,)
indices : torch.LongTensor (2, nnz) with [row; col]
shape : tuple[int, int] (Co*K, Ci*K)
Notes
-----
- The resulting matrix implements the same interpolation-and-mixing as the
GPU path (gather 4 neighbors -> normalize -> apply spatial+channel kernel),
and matches the output of Convol_torch for the same (kernel, cell_ids).
- For multi-gauge, rows are grouped as concatenated gauges: first all
Co_g channels for gauge 0 over all K, then gauge 1, etc.
"""
import numpy as np
import torch
import healpy as hp
device = kernel.device
k_dtype = kernel.dtype
# --- validate kernel & normalize shapes
if kernel.dim() == 3:
# shared across gauges
Ci, Co_g, P = kernel.shape
per_gauge = False
elif kernel.dim() == 4:
Gk, Ci, Co_g, P = kernel.shape
per_gauge = True
if hasattr(self, 'G'):
assert Gk == self.G, f"kernel first dim G={Gk} must match self.G={self.G}"
else:
self.G = int(Gk)
else:
raise ValueError("kernel must be (Ci,Co_g,P) or (G,Ci,Co_g,P)")
assert P == self.P, f"kernel P={P} must equal kernel_sz**2={self.P}"
# --- geometry + binding for these ids (or use cached)
def _to_np_ids(ids):
if ids is None:
return None
if isinstance(ids, torch.Tensor):
return ids.detach().cpu().numpy().astype(np.int64, copy=False).reshape(-1)
return np.asarray(ids, dtype=np.int64).reshape(-1)
cell_ids_np = _to_np_ids(cell_ids)
if cell_ids_np is not None:
# Step A: geometry (Torch) with the class' number of gauges
G = int(getattr(self, 'G', 1))
th, ph = hp.pix2ang(self.nside, cell_ids_np, nest=self.nest)
self.prepare_torch(th, ph, alpha=None, G=G)
# Step B: bind on sorted ids, and remember K
order = np.argsort(cell_ids_np)
ids_sorted_np = cell_ids_np[order]
K = ids_sorted_np.size
if G > 1:
self.bind_support_torch_multi(ids_sorted_np, device=device, dtype=k_dtype)
else:
self.bind_support_torch(ids_sorted_np, device=device, dtype=k_dtype)
else:
# use cached mapping
if getattr(self, 'ids_sorted_np', None) is None:
raise AssertionError("No cached targets; pass `cell_ids` or init the class with `cell_ids=`.")
K = self.ids_sorted_np.size
# rebind to the kernel device/dtype if needed
if getattr(self, 'G', 1) > 1:
if (self.device != device) or (self.dtype != k_dtype):
self.bind_support_torch_multi(self.ids_sorted_np, device=device, dtype=k_dtype)
else:
if (self.device != device) or (self.dtype != k_dtype):
self.bind_support_torch(self.ids_sorted_np, device=device, dtype=k_dtype)
G = int(getattr(self, 'G', 1))
Co_total = (G * Co_g) # output channels including gauges
shape = (Co_total * K, Ci * K)
# --- choose mapping tensors (multi vs single)
is_multi = (G > 1) and (getattr(self, 'pos_safe_t_multi', None) is not None)
if is_multi:
pos_all_g = self.pos_safe_t_multi.to(device=device) # (G,4,K*P)
w_all_g = self.w_norm_t_multi.to(device=device, dtype=k_dtype)
else:
pos_all = self.pos_safe_t.to(device=device) # (4,K*P)
w_all = self.w_norm_t.to(device=device, dtype=k_dtype)
# --- precompute channel row/col bases
# rows: for (co_total, k_out) -> co_total*K + k_out
# cols: for (ci, k_in) -> ci*K + k_in
row_base = (torch.arange(Co_total, device=device, dtype=torch.long) * K)[:, None] # (Co_total, 1)
col_base = (torch.arange(Ci, device=device, dtype=torch.long) * K)[:, None] # (Ci, 1)
rows_all, cols_all, vals_all = [], [], []
# --- helper to add one gauge block (gauge g -> Co_g*K rows)
def _accumulate_for_gauge(g, pos_g, w_g, ker_g):
"""
pos_g : (4, K*P) long
w_g : (4, K*P) float
ker_g : (Ci, Co_g, P)
"""
# process by chunks in k to control memory
for start in range(0, K, chunk_k):
stop = min(start + chunk_k, K)
Kb = stop - start
cols_span = torch.arange(start * self.P, stop * self.P, device=device, dtype=torch.long)
pos = pos_g[:, cols_span].view(4, Kb, self.P) # (4, Kb, P)
w = w_g[:, cols_span].view(4, Kb, self.P) # (4, Kb, P)
# rows_gauge: indices de lignes pour cette jauge g
# Chaque jauge occupe un bloc de Co_g canaux de sortie pour CHAQUE pixel (K)
# donc offset = g*Co_g
rows_gauge = (torch.arange(Co_g, device=device, dtype=torch.long) + g*Co_g)[:, None] * K \
+ (start + torch.arange(Kb, device=device, dtype=torch.long))[None, :]
# -> shape (Co_g, Kb)
rows = rows_gauge[:, :, None, None, None] # (Co_g, Kb,1,1,1)
rows = rows.expand(Co_g, Kb, Ci, 4, self.P) # (Co_g, Kb, Ci, 4, P)
# cols: indices colonnes = (ci*K + pix)
cols_pix = pos.permute(1, 0, 2) # (Kb, 4, P)
cols_pix = cols_pix[None, :, None, :, :] # (1, Kb, 1, 4, P)
cols = col_base + cols_pix # (Ci, Kb, 1, 4, P)
cols = cols.permute(2, 1, 0, 3, 4) # (1, Kb, Ci, 4, P)
cols = cols.expand(Co_g, Kb, Ci, 4, self.P)
# values = kernel(ci, co_g, p) * w(4,kb,p)
k_exp = ker_g.permute(1, 0, 2) # (Co_g, Ci, P)
k_exp = k_exp[:, None, :, None, :] # (Co_g, 1, Ci, 1, P)
# CORRECTION: remettre les axes de w en (Kb,4,P) avant broadcast
w_exp = w.permute(1, 0, 2)[None, :, None, :, :] # (1, Kb, 1, 4, P)
w_exp = w_exp.expand(Co_g, Kb, Ci, 4, self.P) # (Co_g, Kb, Ci, 4, P)
vals = k_exp * w_exp # (Co_g, Kb, Ci, 4, P)
rows_all.append(rows.reshape(-1))
cols_all.append(cols.reshape(-1))
vals_all.append(vals.reshape(-1))
# --- accumulate either single- or multi-gauge
if is_multi:
# (a) shared kernel (Ci, Co_g, P) -> repeat over gauges
if not per_gauge and kernel.dim() == 3:
for g in range(G):
_accumulate_for_gauge(g, pos_all_g[g], w_all_g[g], kernel.to(device=device, dtype=k_dtype))
# (b) per-gauge kernel (G, Ci, Co_g, P)
else:
for g in range(G):
_accumulate_for_gauge(g, pos_all_g[g], w_all_g[g], kernel[g].to(device=device, dtype=k_dtype))
else:
# G == 1 (single-gauge path)
g = 0
_accumulate_for_gauge(g, pos_all, w_all, kernel if kernel.dim() == 3 else kernel[0])
rows = torch.cat(rows_all, dim=0)
cols = torch.cat(cols_all, dim=0)
vals = torch.cat(vals_all, dim=0)
indices = torch.stack([cols, rows], dim=0)
if return_sparse_tensor:
M = torch.sparse_coo_tensor(indices, vals, size=shape, device=device, dtype=k_dtype).coalesce()
return M
else:
return vals, indices, shape
[docs]
def to_tensor(self,x):
return torch.tensor(x,device=self.device)
[docs]
def to_numpy(self,x):
if isinstance(x,np.ndarray):
return x
return x.cpu().numpy()