import numpy as np
import matplotlib.pyplot as plt
import healpy as hp
from scipy.sparse import csr_array
import torch
import foscat.scat_cov as sc
from scipy.spatial import cKDTree
[docs]
class HOrientedConvol:
def __init__(self,
nside,
KERNELSZ,
cell_ids=None,
nest=True,
device=None,
dtype='float64',
polar=False,
gamma=1.0,
allow_extrapolation=True,
no_cell_ids=False,
):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if dtype=='float64':
self.dtype=torch.float64
else:
self.dtype=torch.float32
if KERNELSZ % 2 == 0:
raise ValueError(f"N must be odd so that coordinates are integers from -K..K; got N={KERNELSZ}.")
self.local_test=False
if no_cell_ids==True:
cell_ids=np.arange(10)
if cell_ids is None:
self.cell_ids=np.arange(12*nside**2)
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
KERNELSZ*KERNELSZ,
nside,
nest=nest,
)
else:
try:
self.cell_ids=cell_ids.cpu().numpy()
except:
self.cell_ids=cell_ids
self.local_test=True
if self.cell_ids.ndim==1:
idx_nn = self.knn_healpix_ckdtree(self.cell_ids,
KERNELSZ*KERNELSZ,
nside,
nest=nest,
)
else:
idx_nn = []
for k in range(self.cell_ids.shape[0]):
idx_nn.append(self.knn_healpix_ckdtree(self.cell_ids[k],
KERNELSZ*KERNELSZ,
nside,
nest=nest,
))
idx_nn=np.stack(idx_nn,0)
if self.cell_ids.ndim==1:
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids,nest=nest)
if self.local_test:
t,p = hp.pix2ang(nside,self.cell_ids[idx_nn],nest=True)
else:
t,p = hp.pix2ang(nside,idx_nn,nest=True)
self.t=t[:,0]
self.p=p[:,0]
vec_orig=hp.ang2vec(t,p)
self.vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
'''
if self.local_test:
idx_nn=self.remap_by_first_column(idx_nn)
'''
del mat_pt
del vec_orig
else:
t,p,vec_rot = [],[],[]
for k in range(self.cell_ids.shape[0]):
mat_pt=self.rotation_matrices_from_healpix(nside,self.cell_ids[k],nest=nest)
lt,lp = hp.pix2ang(nside,self.cell_ids[k,idx_nn[k]],nest=True)
vec_orig=hp.ang2vec(lt,lp)
l_vec_rot=np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
vec_rot.append(l_vec_rot)
del vec_orig
del mat_pt
t.append(lt[:,0])
p.append(lp[:,0])
self.t=np.stack(t,0)
self.p=np.stack(p,0)
self.vec_rot=np.stack(vec_rot,0)
del t
del p
del vec_rot
self.polar=polar
self.gamma=gamma
self.device=device
self.allow_extrapolation=allow_extrapolation
self.w_idx=None
self.idx_nn=idx_nn
self.nside=nside
self.KERNELSZ=KERNELSZ
self.nest=nest
self.f=None
[docs]
def remap_by_first_column(self,idx: np.ndarray) -> np.ndarray:
"""
Remap the values in `idx` so that:
- The first column becomes [0, 1, ..., N-1]
- All other columns are updated accordingly using the same mapping.
Parameters
----------
idx : np.ndarray
Integer array of shape (N, m).
Assumes all values in idx are present in the first column (otherwise they get -1).
Returns
-------
np.ndarray
New array with remapped indices.
"""
if idx.ndim != 2:
raise ValueError("idx must be a 2D array of shape (N, m)")
N, m = idx.shape
# Create a mapping: original_value_in_first_column -> row_index
# Example: if idx[:,0] = [101, 505, 303], then mapping = {101:0, 505:1, 303:2}
keys = idx[:, 0]
mapping = {v: i for i, v in enumerate(keys)}
# Optional check: ensure all values are in the mapping keys
# If not, you can raise an error or handle it differently
# if not np.isin(idx, keys).all():
# missing = np.unique(idx[~np.isin(idx, keys)])
# raise ValueError(f"Some values are not in idx[:,0]: {missing}")
# Function to get mapped value, or -1 if value is not found
get = mapping.get
# Apply mapping to all elements (vectorized via np.vectorize)
out = np.vectorize(lambda v: get(int(v), -1), otypes=[int])(idx)
return out
[docs]
def rotation_matrices_from_healpix(self,nside, hpix_idx, nest=True):
"""
Compute rotation matrices that move each Healpix pixel center to the North pole.
equivalent to rotation matrices R_z(phi) * R_y(-thi) for N points.
Parameters
----------
nside : int
Healpix Nside resolution.
hpix_idx : array_like, shape (N,)
Healpix pixel indices.
nest : bool, optional
True if indices are in NESTED ordering, False for RING ordering.
Returns
-------
R : ndarray, shape (3, 3, N)
Rotation matrices for each pixel index.
"""
try:
hpix_idx = np.asarray(hpix_idx)
except:
hpix_idx = hpix_idx.cpu().numpy()
N = hpix_idx.shape[0]
# Get angular coordinates of each pixel center
theta, phi = hp.pix2ang(nside, hpix_idx, nest=nest) # theta: colatitude (0=north pole)
# Precompute sines/cosines
cphi = np.cos(phi)
sphi = np.sin(phi)
cthi = np.cos(-theta)
sthi = np.sin(-theta)
# Rotation around Z (by phi)
Rz = np.zeros((3, 3, N))
Rz[0, 0, :] = cphi
Rz[0, 1, :] = -sphi
Rz[1, 0, :] = sphi
Rz[1, 1, :] = cphi
Rz[2, 2, :] = 1.0
# Rotation around Y (by -theta)
Ry = np.zeros((3, 3, N))
Ry[0, 0, :] = cthi
Ry[0, 2, :] = -sthi
Ry[1, 1, :] = 1.0
Ry[2, 0, :] = sthi
Ry[2, 2, :] = cthi
# Multiply Rz * Ry for each pixel
R = np.einsum('ijk,jlk->ilk', Rz, Ry)
return R
def _choose_depth_for_candidates(self, N, overshoot=2, max_depth=12):
"""
Pick hierarchy depth d so that ~ 9 * 4**d >= overshoot * N.
Depth 0 => 9 candidates; 1 => 36; 2 => 144; 3 => 576; 4 => 2304; etc.
"""
d = 0
while 9 * (4 ** d) < overshoot * N and d < max_depth:
d += 1
return d
[docs]
def knn_healpix_ckdtree(self,
hidx, N, nside, *, nest=True,
include_self=True,
vec_dtype=np.float32,
out_dtype=np.int64
):
"""
k-NN using a cKDTree on unit vectors (exact in Euclidean space).
Returns LOCAL indices (0..M-1) of the N nearest neighbours per row.
"""
try:
hidx = np.asarray(hidx, dtype=np.int64)
except:
hidx = hidx.cpu().numpy()
if hidx.ndim != 1:
raise ValueError("hidx must be 1D")
M = hidx.size
if M == 0:
return np.empty((0, 0), dtype=out_dtype)
if N <= 0:
raise ValueError("N must be >= 1")
# Effective N
N_eff = min(N, M if include_self else max(M-1, 1))
# Build unit vectors
hidx_n = hidx if nest else hp.ring2nest(nside, hidx)
x, y, z = hp.pix2vec(nside, hidx_n, nest=True)
V = np.stack([x, y, z], axis=1).astype(vec_dtype, copy=False) # (M,3)
tree = cKDTree(V)
if include_self:
# Self appears with distance 0 as the first neighbour
d, idx = tree.query(V, k=N_eff, workers=-1) # idx shape (M,N)
return idx.astype(out_dtype, copy=False)
else:
# Ask for one extra and drop self
k = min(N_eff + 1, M)
d, idx = tree.query(V, k=k, workers=-1)
# idx can be (M,) if k==1; normalize shapes
if idx.ndim == 1:
idx = idx[:, None]
# Remove self if present (distance 0)
out = np.empty((M, N_eff), dtype=out_dtype)
for i in range(M):
row = idx[i]
# filter out self (i); keep first N_eff
row = row[row != i][:N_eff]
# if M==N and no self, row already size N_eff
out[i, :row.size] = row
if row.size < N_eff:
# extremely rare (degenerate duplicates); fallback by scores
cand = np.setdiff1d(np.arange(M), np.r_[i, row], assume_unique=False)
# pick nearest remaining
di, ci = tree.query(V[i], k=N_eff - row.size)
out[i, row.size:] = np.atleast_1d(ci).astype(out_dtype, copy=False)
return out
[docs]
def make_wavelet_matrix(self,
orientations,
polar=True,
norm_mean=True,
norm_std=True,
return_index=False,
return_smooth=False,
):
sigma_gauss = 0.5
sigma_cosine = 0.5
if self.KERNELSZ == 3:
sigma_gauss = 1.0 / np.sqrt(2)
sigma_cosine = 1.0
orientations=np.asarray(orientations)
NORIENT = orientations.shape[0]
rotate=2*((self.t<np.pi/2)-0.5)[None,:,None]
if polar:
xx=np.cos(self.p[None,:]+np.pi/2-orientations[:,None])[:,:,None]*self.vec_rot[None,:,:,0]-rotate*np.sin(self.p[None,:]+np.pi/2-orientations[:,None])[:,:,None]*self.vec_rot[None,:,:,1]
else:
xx=np.cos(np.pi/2-orientations[:,None,None])*self.vec_rot[None,:,:,0]-np.sin(np.pi/2-orientations[:,None,None])*self.vec_rot[None,:,:,1]
r=(self.vec_rot[None,:,:,0]**2+self.vec_rot[None,:,:,1]**2+(self.vec_rot[None,:,:,2]-1.0)**2)
if return_smooth:
wsmooth=np.exp(-sigma_gauss*r*self.nside**2)
if norm_std:
ww=np.sum(wsmooth,2)
wsmooth = wsmooth/ww[:,:,None]
#for consistency with previous definition
w=np.exp(-sigma_gauss*r*self.nside**2)*(np.cos(xx*self.nside*sigma_cosine*np.pi)-1J*np.sin(xx*self.nside*sigma_cosine*np.pi))
if norm_std:
ww=1/np.sum(abs(w),2)[:,:,None]
else:
ww=1.0
if norm_mean:
w = (w.real-np.mean(w.real,2)[:,:,None]+1J*(w.imag-np.mean(w.imag,2)[:,:,None]))*ww
NK=self.idx_nn.shape[1]
indice_1_0 = np.tile(self.idx_nn.flatten(),NORIENT)
indice_1_1 = np.tile(np.repeat(self.idx_nn[:,0],NK),NORIENT)+ \
np.repeat(np.arange(NORIENT),self.idx_nn.shape[0]*self.idx_nn.shape[1])*self.idx_nn.shape[0]
w = w.flatten()
if return_smooth:
indice_2_0 = self.idx_nn.flatten()
indice_2_1 = np.repeat(self.idx_nn[:,0],NK)
wsmooth = wsmooth.flatten()
if return_index:
if return_smooth:
return w,np.concatenate([indice_1_0[:,None],indice_1_1[:,None]],1),wsmooth,np.concatenate([indice_2_0[:,None],indice_2_1[:,None]],1)
return w,np.concatenate([indice_1_0[:,None],indice_1_1[:,None]],1)
return csr_array((w, (indice_1_0, indice_1_1)), shape=(12*self.nside**2, 12*self.nside**2*NORIENT))
[docs]
def make_idx_weights_from_cell_ids(self,
i_cell_ids,
polar=False,
gamma=1.0,
device=None,
allow_extrapolation=True):
"""
Accept 1D (Npix,) or 2D (B, Npix) cell_ids and return
tensors batched on the first dim (B, ...).
"""
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# → cast numpy
if torch.is_tensor(i_cell_ids):
cid = i_cell_ids.detach().cpu().numpy()
else:
cid = np.asarray(i_cell_ids)
# --- 1D: pas de boucle, on calcule une fois, puis on ajoute l'axe batch
if cid.ndim == 1:
l_idx_nn, l_w_idx, l_w_w = self.make_idx_weights_from_one_cell_ids(
cid, polar=polar, gamma=gamma, device=device,
allow_extrapolation=allow_extrapolation
)
idx_nn = torch.as_tensor(l_idx_nn, device=device, dtype=torch.long)[None, ...] # (1, Npix, P)
w_idx = torch.as_tensor(l_w_idx, device=device, dtype=torch.long)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
w_w = torch.as_tensor(l_w_w, device=device, dtype=self.dtype)[None, ...] # (1, Npix, S, P) ou (1, Npix, P)
return idx_nn, w_idx, w_w
# --- 2D: boucle sur b, empilement en (B, ...)
elif cid.ndim == 2:
outs = [ self.make_idx_weights_from_one_cell_ids(
cid[k], polar=polar, gamma=gamma, device=device,
allow_extrapolation=allow_extrapolation)
for k in range(cid.shape[0]) ]
idx_nn = torch.as_tensor(np.stack([o[0] for o in outs], axis=0), device=device, dtype=torch.long)
w_idx = torch.as_tensor(np.stack([o[1] for o in outs], axis=0), device=device, dtype=torch.long)
w_w = torch.as_tensor(np.stack([o[2] for o in outs], axis=0), device=device, dtype=self.dtype)
return idx_nn, w_idx, w_w
else:
raise ValueError(f"Unsupported cell_ids ndim={cid.ndim}; expected 1 or 2.")
'''
def make_idx_weights_from_cell_ids(self,i_cell_ids,
polar=False,
gamma=1.0,
device=None,
allow_extrapolation=True):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if len(i_cell_ids.shape)<2:
cell_ids=i_cell_ids
n_cids=1
else:
cell_ids=i_cell_ids[0]
n_cids=i_cell_ids.shape[0]
idx_nn,w_idx,w_w = [],[],[]
for k in range(n_cids):
cell_ids=i_cell_ids[k]
l_idx_nn,l_w_idx,l_w_w = self.make_idx_weights_from_one_cell_ids(cell_ids,
polar=polar,
gamma=gamma,
device=device,
allow_extrapolation=allow_extrapolation)
idx_nn.append(l_idx_nn)
w_idx.append(l_w_idx)
w_w.append(l_w_w)
idx_nn = torch.Tensor(np.stack(idx_nn,0)).to(device=device, dtype=torch.long)
w_idx = torch.Tensor(np.stack(w_idx,0)).to(device=device, dtype=torch.long)
w_w = torch.Tensor(np.stack(w_w,0)).to(device=device, dtype=self.dtype)
return idx_nn,w_idx,w_w
'''
[docs]
def make_idx_weights_from_one_cell_ids(self,
cell_ids,
polar=False,
gamma=1.0,
device=None,
allow_extrapolation=True):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
idx_nn = self.knn_healpix_ckdtree(cell_ids,
self.KERNELSZ*self.KERNELSZ,
self.nside,
nest=self.nest,
)
mat_pt=self.rotation_matrices_from_healpix(self.nside,cell_ids,nest=self.nest)
t,p = hp.pix2ang(self.nside,cell_ids[idx_nn],nest=self.nest)
vec_orig=hp.ang2vec(t,p)
vec_rot = np.einsum('mki,ijk->kmj', vec_orig,mat_pt)
del vec_orig
del mat_pt
rotate=2*((t<np.pi/2)-0.5)[:,None]
if polar:
xx=np.cos(p)[:,None]*vec_rot[:,:,0]-rotate*np.sin(p)[:,None]*vec_rot[:,:,1]
yy=-np.sin(p)[:,None]*vec_rot[:,:,0]-rotate*np.cos(p)[:,None]*vec_rot[:,:,1]
else:
xx=vec_rot[:,:,0]
yy=vec_rot[:,:,1]
del vec_rot
del rotate
del t
del p
w_idx,w_w = self.bilinear_weights_NxN(xx*self.nside*gamma,
yy*self.nside*gamma,
allow_extrapolation=allow_extrapolation)
'''
# calib : [Npix, K]
calib = np.zeros((w_idx.shape[0], w_idx.shape[2]))
# Assumptions:
# w_idx.shape == (Npix, M, K) et w_w.shape == (Npix, M, K)
Npix, M, K = w_idx.shape
nb_cols = K
# 1) Accumulation via "bincount" with row offset
row_ids = np.arange(Npix, dtype=np.int64)[:, None, None] * nb_cols
flat_idx = (row_ids + w_idx).ravel() # indices dans [0, Npix*9)
weights = w_w.ravel().astype(np.float64) # ou dtype de ton choix
calib = np.bincount(flat_idx, weights, minlength=Npix*nb_cols)\
.reshape(Npix, nb_cols)
# 2) Write back into norm_a according to w_idx
norm_a = calib[np.arange(Npix)[:, None, None], w_idx]
w_w /= norm_a
w_w = np.clip(w_w,0.0,1.0)
w_w[np.isnan(w_w)]=0.0
'''
#del xx
#del yy
return idx_nn,w_idx,w_w,xx,yy
[docs]
def make_idx_weights(self,polar=False,gamma=1.0,device=None,allow_extrapolation=True,return_index=False):
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
idx_nn,w_idx,w_w = self.make_idx_weights_from_one_cell_ids(self.cell_ids,
polar=polar,
gamma=gamma,
device=device,
allow_extrapolation=allow_extrapolation)
# Ensure types/devices
self.idx_nn = torch.Tensor(idx_nn).to(device=device, dtype=torch.long)
self.w_idx = torch.Tensor(w_idx).to(device=device, dtype=torch.long)
self.w_w = torch.Tensor(w_w).to(device=device, dtype=self.dtype)
def _grid_index(self, xi, yi):
"""
Map integer grid coords (xi, yi) in {-1,0,1} to flat index in [0..8]
following the given order (row-major from y=-1 to y=1).
"""
return (yi + self.KERNELSZ//2) * self.KERNELSZ + (xi + self.KERNELSZ//2)
[docs]
def bilinear_weights_NxN(self,x, y, allow_extrapolation=True):
"""
Compute bilinear weights on an N×N integer grid with node coordinates
(xi, yi) in {-K, ..., +K} × {-K, ..., +K}, where K = N//2 (N must be odd).
N is attached to the class `N = self.KERNELSZ`
The query point (x, y) is continuous in the same coordinate system.
For each query, we pick the unit cell [x0, x0+1] × [y0, y0+1] with
integer corners (x0,y0), (x0+1,y0), (x0,y0+1), (x0+1,y0+1), and compute
standard bilinear weights relative to (x0, y0).
Parameters
----------
x, y : float or array-like of shape (M,)
Query coordinates in the integer grid coordinate system.
N : int
Grid size (must be odd). Grid nodes are at integer coords
xi, yi ∈ {-K, ..., +K}, where K = N//2.
allow_extrapolation : bool, default True
- If False: clamp (x, y) to [-K, +K] so that tx, ty ∈ [0, 1] and
weights are non-negative and sum to 1.
- If True : do not clamp (x, y); we still select the nearest boundary
cell inside the grid for the indices, but tx, ty may fall outside
[0, 1], yielding extrapolation (weights can be negative).
Returns
-------
idx : ndarray of shape (M, 4), dtype=int64
Flat indices (0 .. N*N-1) of the four cell-corner nodes in row-major
order (y from -K to +K, x from -K to +K):
order = [(x0,y0), (x0+1,y0), (x0,y0+1), (x0+1,y0+1)].
w : ndarray of shape (M, 4), dtype=float64
Corresponding bilinear weights for each query point. If
allow_extrapolation=False and the point is inside the grid, each row
sums to 1 and all weights are in [0,1].
Notes
-----
- This matches your previous 3×3 case when N=3, with the same row-major
flattening convention.
- For extrapolation=True, indices are kept in-bounds (clamped to boundary
cells), while tx, ty > 1 or < 0 are allowed.
"""
# --- checks & shapes ---
N=self.KERNELSZ
K = N // 2
x = np.atleast_1d(np.asarray(x, dtype=float))
y = np.atleast_1d(np.asarray(y, dtype=float))
if x.shape != y.shape:
raise ValueError("x and y must have the same shape")
M = x.shape[0]
# --- optionally clamp queries (for pure interpolation) ---
if not allow_extrapolation:
x = np.clip(x, -K, K)
y = np.clip(y, -K, K)
# --- choose the cell: x0=floor(x), y0=floor(y), but keep indices in-bounds
# cell must be inside [-K..K-1] × [-K..K-1] so that +1 is valid
x0 = np.floor(x)
y0 = np.floor(y)
x0 = np.clip(x0, -K, K - 1).astype(int)
y0 = np.clip(y0, -K, K - 1).astype(int)
x1 = x0 + 1
y1 = y0 + 1
# --- local coords within the cell (unit spacing) ---
tx = x - x0
ty = y - y0
# --- bilinear weights ---
# (x0,y0) w00, (x1,y0) w10, (x0,y1) w01, (x1,y1) w11
w00 = (1.0 - tx) * (1.0 - ty)
w10 = tx * (1.0 - ty)
w01 = (1.0 - tx) * ty
w11 = tx * ty
w = np.stack([w00, w10, w01, w11], axis=1)
# --- flat indices in row-major order (y changes slowest) ---
# index = (yi + K) * N + (xi + K)
def flat_idx(xi, yi):
return (yi + K) * N + (xi + K)
i00 = flat_idx(x0, y0)
i10 = flat_idx(x1, y0)
i01 = flat_idx(x0, y1)
i11 = flat_idx(x1, y1)
idx = np.stack([i00, i10, i01, i11], axis=1).astype(np.int64)
return idx, w
# --- Add inside class HOrientedConvol, just above Convol_torch ---
def _convol_single(self, im1: torch.Tensor, ww: torch.Tensor, cell_ids=None, nside=None):
"""
Single-sample path. im1: (1, C_i, Npix_1). Returns (1, C_o, Npix_1).
"""
if not isinstance(im1, torch.Tensor):
im1 = torch.as_tensor(im1, device=self.device, dtype=self.dtype)
if not isinstance(ww, torch.Tensor):
ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
assert im1.ndim == 3 and im1.shape[0] == 1, f"expected (1, C_i, Npix), got {tuple(im1.shape)}"
# Reuse the existing Convol_torch core by faking B=1 shapes.
# We call the existing (batched) implementation with B=1.
return self.Convol_torch(im1, ww, cell_ids=cell_ids, nside=nside) # returns (1, C_o, Npix_1)
# --- Replace the first lines of Convol_torch with a dispatcher ---
[docs]
def Convol_torch(self, im, ww, cell_ids=None, nside=None):
"""
Batched KERNELSZxKERNELSZ aggregation.
Accepts either:
- im: Tensor (B, C_i, Npix) with one shared or per-batch (B,Npix) cell_ids
- im: list/tuple of Tensors, each (C_i, Npix_b), with cell_ids a list of arrays
"""
import torch
# (A) Variable-length per-sample path: im is a list/tuple OR cell_ids is a list/tuple
if isinstance(im, (list, tuple)) or isinstance(cell_ids, (list, tuple)):
# Normalize to lists
im_list = im if isinstance(im, (list, tuple)) else [im]
cid_list = cell_ids if isinstance(cell_ids, (list, tuple)) else [cell_ids] * len(im_list)
assert len(im_list) == len(cid_list), "im list and cell_ids list must have same length"
outs = []
for xb, cb in zip(im_list, cid_list):
# xb: (C_i, Npix_b) -> (1, C_i, Npix_b)
if not torch.is_tensor(xb):
xb = torch.as_tensor(xb, device=self.device, dtype=self.dtype)
if xb.dim() == 2:
xb = xb.unsqueeze(0)
elif xb.dim() != 3 or xb.shape[0] != 1:
raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
yb = self._convol_single(xb, ww, cell_ids=cb, nside=nside) # (1, C_o, Npix_b)
outs.append(yb.squeeze(0)) # -> (C_o, Npix_b)
return outs # List[Tensor], each (C_o, Npix_b)
# (B) Standard fixed-length batched path (your current implementation)
# ... keep your existing Convol_torch body from here unchanged ...
# (paste your current function body starting from the type casting and assertions)
# ---- Basic checks / casting ----
if not isinstance(im, torch.Tensor):
im = torch.as_tensor(im, device=self.device, dtype=self.dtype)
if not isinstance(ww, torch.Tensor):
ww = torch.as_tensor(ww, device=self.device, dtype=self.dtype)
assert im.ndim == 3, f"`im` must be (B, C_i, Npix), got {tuple(im.shape)}"
B, C_i, Npix = im.shape
device = im.device
dtype = im.dtype
# ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
# target shapes:
# idx_nn_eff : (B, Npix, P)
# w_idx_eff : (B, Npix, S, P)
# w_w_eff : (B, Npix, S, P)
if cell_ids is not None:
# ---- Recompute (idx_nn, w_idx, w_w) depending on cell_ids shape ----
# Normalise: accept Tensor, ndarray, or list/tuple of 1 element (var-length case, B=1)
if isinstance(cell_ids, (list, tuple)):
# liste d'ids (souvent longueur 1 en var-length)
if len(cell_ids) == 1:
cid = np.asarray(cell_ids[0])[None, :] # -> (1, Npix)
else:
# if >1, try to stack (must have same Npix per element)
cid = np.stack([np.asarray(c) for c in cell_ids], axis=0)
elif torch.is_tensor(cell_ids):
c = cell_ids.detach().cpu().numpy()
cid = c if c.ndim != 1 else c[None, :] # uniformiser en 2D quand B=1
else:
c = np.asarray(cell_ids)
cid = c if c.ndim != 1 else c[None, :]
# cid est maintenant (B, Npix)
idx_nn_eff, w_idx_eff, w_w_eff = self.make_idx_weights_from_cell_ids(
cid, nside, device=device
)
# shapes: (B, Npix, P), (B, Npix, S, P|P), (B, Npix, S, P|P)
P = idx_nn_eff.shape[-1]
S = w_idx_eff.shape[-2] if w_idx_eff.ndim == 4 else 1
# s’assurer des dtypes/devices
idx_nn_eff = torch.as_tensor(idx_nn_eff, device=device, dtype=torch.long)
w_idx_eff = torch.as_tensor(w_idx_eff, device=device, dtype=torch.long)
w_w_eff = torch.as_tensor(w_w_eff, device=device, dtype=dtype)
else:
# Use precomputed (shared for batch)
if self.w_idx is None:
if self.cell_ids.ndim==1:
l_cell=self.cell_ids[None,:]
else:
l_cell=self.cell_ids
idx_nn,w_idx,w_w = self.make_idx_weights_from_cell_ids(
l_cell,
polar=self.polar,
gamma=self.gamma,
device=self.device,
allow_extrapolation=self.allow_extrapolation)
self.idx_nn = idx_nn
self.w_idx = w_idx
self.w_w = w_w
else:
idx_nn = self.idx_nn # (Npix,P)
w_idx = self.w_idx # (Npix,P) or (Npix,S,P)
w_w = self.w_w # (Npix,P) or (Npix,S,P)
#assert idx_nn.ndim == 3 and idx_nn.size(1) == Npix, \
# f"`idx_nn` must be (B,Npix,P) with Npix={Npix}, got {tuple(idx_nn.shape)}"
P = idx_nn.size(-1)
if w_idx.ndim == 3:
S = 1
w_idx_eff = w_idx[:, :, None, :] # (B,Npix,1,P)
w_w_eff = w_w[:, :, None, :] # (B,Npix,1,P)
elif w_idx.ndim == 4:
S = w_idx.size(2)
w_idx_eff = w_idx # (B,Npix,S,P)
w_w_eff = w_w # (B,Npix,S,P)
else:
raise ValueError(f"Unsupported `w_idx` shape {tuple(w_idx.shape)}; expected (Npix,P) or (Npix,S,P)")
idx_nn_eff = idx_nn # (B,Npix,P)
# ---- 1) Gather neighbor values from im along Npix -> (B, C_i, Npix, P)
rim = torch.take_along_dim(
im.unsqueeze(-1), # (B, C_i, Npix, 1)
idx_nn_eff[:, None, :, :], # (B, 1, Npix, P)
dim=2
)
# ---- 2) Normalize ww to (B, C_i, C_o, M, S)
if ww.ndim == 3:
C_i_w, C_o, M = ww.shape
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
ww_eff = ww.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, -1, S)
elif ww.ndim == 4:
if ww.shape[0] == C_i and ww.shape[1] != C_i:
# (C_i, C_o, M, S)
C_i_w, C_o, M, S_w = ww.shape
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
ww_eff = ww.unsqueeze(0).expand(B, -1, -1, -1, -1)
elif ww.shape[0] == B:
# (B, C_i, C_o, M)
_, C_i_w, C_o, M = ww.shape
assert C_i_w == C_i, f"ww C_i mismatch: {C_i_w} vs im {C_i}"
ww_eff = ww.unsqueeze(-1).expand(-1, -1, -1, -1, S)
else:
raise ValueError(f"Ambiguous 4D ww shape {tuple(ww.shape)}; expected (C_i,C_o,M,S) or (B,C_i,C_o,M)")
elif ww.ndim == 5:
# (B, C_i, C_o, M, S)
assert ww.shape[0] == B and ww.shape[1] == C_i, "ww batch/C_i mismatch"
_, _, _, M, S_w = ww.shape
assert S_w == S, f"ww S mismatch: {S_w} vs w_idx S {S}"
ww_eff = ww
else:
raise ValueError(f"Unsupported ww shape {tuple(ww.shape)}")
# --- Sanitize shapes: ensure w_idx_eff / w_w_eff == (B, Npix, S, P)
# ---- 3) Gather along M using w_idx_eff -> (B, C_i, C_o, Npix, S, P)
idx_exp = w_idx_eff[:, None, None, :, :, :] # (B,1,1,Npix,S,P)
rw = torch.take_along_dim(
ww_eff.unsqueeze(-1), # (B,C_i,C_o,M,S,1)
idx_exp, # (B,1,1,Npix,S,P)
dim=3 # gather along M
) # -> (B, C_i, C_o, Npix, S, P)
# ---- 4) Apply extra neighbor weights ----
rw = rw * w_w_eff[:, None, None, :, :, :] # (B, C_i, C_o, Npix, S, P)
# ---- 5) Combine neighbor values and weights ----
rim_exp = rim[:, :, None, :, None, :] # (B, C_i, 1, Npix, 1, P)
out_ci = (rim_exp * rw).sum(dim=-1) # sum over P -> (B, C_i, C_o, Npix, S)
out_ci = out_ci.sum(dim=-1) # sum over S -> (B, C_i, C_o, Npix)
out = out_ci.sum(dim=1) # sum over C_i -> (B, C_o, Npix)
return out
def _to_numpy_1d(self, ids):
"""Return a 1D numpy array of int64 for a single set of cell ids."""
import numpy as np, torch
if isinstance(ids, np.ndarray):
return ids.reshape(-1).astype(np.int64, copy=False)
if torch.is_tensor(ids):
return ids.detach().cpu().to(torch.long).view(-1).numpy()
# python list/tuple of ints
return np.asarray(ids, dtype=np.int64).reshape(-1)
def _is_varlength_batch(self, ids):
"""
True if ids is a list/tuple of per-sample id arrays (var-length batch).
False if ids is a single array/tensor of ids (shared for whole batch).
"""
import numpy as np, torch
if isinstance(ids, (list, tuple)):
return True
if isinstance(ids, np.ndarray) and ids.ndim == 2:
# This would be a dense (B, Npix) matrix -> NOT var-length list
return False
if torch.is_tensor(ids) and ids.dim() == 2:
return False
return False
[docs]
def Down(self, im, cell_ids=None, nside=None,max_poll=False):
"""
If `cell_ids` is a single set of ids -> return a single (Tensor, Tensor).
If `cell_ids` is a list (var-length) -> return (list[Tensor], list[Tensor]).
"""
if self.f is None:
if self.dtype==torch.float64:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
else:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
if cell_ids is None:
dim,cdim = self.f.ud_grade_2(im,cell_ids=self.cell_ids,nside=self.nside,max_poll=False)
return dim,cdim
if nside is None:
nside = self.nside
# var-length mode: list/tuple of ids, one per sample
if self._is_varlength_batch(cell_ids):
outs, outs_ids = [], []
B = len(cell_ids)
for b in range(B):
cid_b = self._to_numpy_1d(cell_ids[b])
# extract the right sample from `im`
if torch.is_tensor(im):
xb = im[b:b+1] # (1, C, N_b)
yb, ids_b = self.f.ud_grade_2(xb, cell_ids=cid_b, nside=nside,max_poll=max_poll)
outs.append(yb.squeeze(0)) # (C, N_b')
else:
# if im is already a list of (C, N_b)
xb = im[b]
yb, ids_b = self.f.ud_grade_2(xb[None, ...], cell_ids=cid_b, nside=nside,max_poll=max_poll)
outs.append(yb.squeeze(0))
outs_ids.append(torch.as_tensor(ids_b, device=outs[-1].device, dtype=torch.long))
return outs, outs_ids
# grille commune (un seul vecteur d'ids)
cid = self._to_numpy_1d(cell_ids)
return self.f.ud_grade_2(im, cell_ids=cid, nside=nside,max_poll=False)
[docs]
def Up(self, im, cell_ids=None, nside=None, o_cell_ids=None):
"""
If `cell_ids` / `o_cell_ids` are single arrays -> return Tensor.
If they are lists (var-length per sample) -> return list[Tensor].
"""
if self.f is None:
if self.dtype==torch.float64:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
else:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
if cell_ids is None:
dim = self.f.up_grade(im,self.nside*2,cell_ids=self.cell_ids,nside=self.nside)
return dim
if nside is None:
nside = self.nside
# var-length: parallel lists
if self._is_varlength_batch(cell_ids):
assert isinstance(o_cell_ids, (list, tuple)) and len(o_cell_ids) == len(cell_ids), \
"In var-length mode, `o_cell_ids` must be a list with same length as `cell_ids`."
outs = []
B = len(cell_ids)
for b in range(B):
cid_b = self._to_numpy_1d(cell_ids[b]) # coarse ids
ocid_b = self._to_numpy_1d(o_cell_ids[b]) # fine ids
if torch.is_tensor(im):
xb = im[b:b+1] # (1, C, N_b_coarse)
yb = self.f.up_grade(xb, nside*2, cell_ids=cid_b, nside=nside,
o_cell_ids=ocid_b, force_init_index=True)
outs.append(yb.squeeze(0)) # (C, N_b_fine)
else:
xb = im[b] # (C, N_b_coarse)
yb = self.f.up_grade(xb[None, ...], nside*2, cell_ids=cid_b, nside=nside,
o_cell_ids=ocid_b, force_init_index=True)
outs.append(yb.squeeze(0))
return outs
# grille commune
cid = self._to_numpy_1d(cell_ids)
ocid = self._to_numpy_1d(o_cell_ids) if o_cell_ids is not None else None
return self.f.up_grade(im, nside*2, cell_ids=cid, nside=nside,
o_cell_ids=ocid, force_init_index=True)
[docs]
def to_tensor(self,x):
if self.f is None:
if self.dtype==torch.float64:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float64')
else:
self.f=sc.funct(KERNELSZ=self.KERNELSZ,all_type='float32')
return self.f.backend.bk_cast(x)
[docs]
def to_numpy(self,x):
if isinstance(x,np.ndarray):
return x
return x.cpu().numpy()