foscat.SphericalStencil
=======================

.. py:module:: foscat.SphericalStencil


Classes
-------

.. autoapisummary::

   foscat.SphericalStencil.SphericalStencil


Module Contents
---------------

.. py:class:: SphericalStencil(nside: int, kernel_sz: int, *, nest: bool = True, cell_ids=None, device=None, dtype=None, n_gauges=1, gauge_type='phi')

   GPU-accelerated spherical stencil operator for HEALPix convolutions.

   This class implements three phases:

   A) Geometry preparation: build the local rotated stencil vectors for each target pixel, then compute HEALPix neighbor indices and interpolation weights.
   B) Sparse binding: map the neighbor indices/weights to the available data samples (sorted ids) and normalize the weights.
   C) Convolution: apply multi-channel kernels to the sparsely gathered data.

   Once A and B are prepared, multiple convolutions (C) can be applied efficiently on the GPU.

   :Parameters:
      * **nside** (:py:class:`int`) -- HEALPix resolution parameter.
      * **kernel_sz** (:py:class:`int`) -- Size of the local stencil (must be odd, e.g. 3, 5, 7).
      * **gauge_type** (:py:class:`str`) -- Type of gauge. ``'cosmo'`` uses the same definition as
        https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html;
        ``'phi'`` is defined at the pole and may be better suited to Earth-observation
        applications that make little use of the poles.
      * **n_gauges** (:py:class:`int`) -- Number of oriented gauges (default 1).
      * **blend** (:py:class:`bool`) -- Whether to blend smoothly between axisA and axisB (dual gauge).
      * **power** (:py:class:`float`) -- Sharpness of the blend transition (dual gauge).
      * **nest** (:py:class:`bool`) -- Use nested ordering if True (default), otherwise ring ordering.
      * **cell_ids** (:py:class:`np.ndarray | torch.Tensor | None`) -- If given, run Step A immediately for these targets.
      * **device** (:py:class:`torch.device | str | None`) -- Default device (if None: ``'cuda'`` if available, else ``'cpu'``).
      * **dtype** (:py:class:`torch.dtype | None`) -- Default dtype (``float32`` if None).

   .. py:attribute:: nside

   .. py:attribute:: KERNELSZ

   .. py:attribute:: P

   .. py:attribute:: G
      :value: 1

   .. py:attribute:: gauge_type
      :value: 'phi'

   .. py:attribute:: nest
      :value: True

   .. py:attribute:: device

   .. py:attribute:: dtype
      :value: None

   .. py:attribute:: Kb
      :value: None

   .. py:attribute:: idx_t
      :value: None

   .. py:attribute:: w_t
      :value: None

   .. py:attribute:: ids_sorted_np
      :value: None

   .. py:attribute:: pos_safe_t
      :value: None

   .. py:attribute:: w_norm_t
      :value: None

   .. py:attribute:: present_t
      :value: None

   .. py:attribute:: cell_ids_default
      :value: None


   .. py:method:: get_interp_weights_from_vec_torch(nside: int, vec, *, nest: bool = True, device=None, dtype=None, chunk_size=1000000)
      :staticmethod:

      Torch wrapper for ``healpy.get_interp_weights`` using input vectors.

      :Parameters:
         * **nside** (:py:class:`int`) -- HEALPix resolution.
         * **vec** (:py:class:`torch.Tensor (...,3)`) -- Direction vectors (not necessarily normalized).
         * **nest** (:py:class:`bool`) -- Nested ordering if True (default).
         * **device, dtype** -- Torch device and dtype.
         * **chunk_size** (:py:class:`int`) -- Number of points per healpy call on the CPU.

      :returns:
         * **idx_t** (:py:class:`torch.LongTensor`, shape ``(4, *leading)``) -- HEALPix indices of the 4 interpolation neighbors of each point.
         * **w_t** (:py:class:`torch.Tensor`, shape ``(4, *leading)``) -- Corresponding interpolation weights.


   .. py:method:: prepare_torch(th, ph, alpha=None, G: int = 1)

      Prepare the rotated stencil and the HEALPix neighbors/weights in Torch for *G* gauges.

      :Parameters:
         * **th, ph** (*array-like*, shape ``(K,)``) -- Target colatitudes/longitudes.
         * **alpha** (*array-like* ``(K,)``, *scalar* or :py:obj:`None`) -- Base gauge angle about the local normal at each target. If None, 0 is used. For each gauge ``g`` in ``[0, G-1]``, the effective angle is ``alpha + g*pi/G``.
         * **G** (:py:class:`int`, ``>= 1``) -- Number of gauges to generate per target.
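      The effective-angle rule for ``alpha`` can be illustrated with a small NumPy sketch. The helper ``effective_gauge_angles`` is hypothetical (it is not part of foscat); it only reproduces the documented formula ``alpha + g*pi/G`` and the documented handling of ``alpha`` as None, a scalar, or a ``(K,)`` array.

      .. code-block:: python

         import numpy as np


         def effective_gauge_angles(alpha, K, G=1):
             """Hypothetical helper: (G, K) array of angles alpha + g*pi/G."""
             if G < 1:
                 raise ValueError("G must be >= 1")
             # alpha may be None (-> 0), a scalar, or an array of shape (K,)
             if alpha is None:
                 base = np.zeros(K)
             else:
                 base = np.broadcast_to(np.asarray(alpha, dtype=float), (K,))
             g = np.arange(G, dtype=float)[:, None]   # gauge index, shape (G, 1)
             return base[None, :] + g * np.pi / G     # broadcast to (G, K)


         angles = effective_gauge_angles(0.3, K=4, G=3)
         print(angles.shape)  # (3, 4)

      Because consecutive gauges are spaced by ``pi/G``, the ``G`` angles of a target cover the half-turn ``[alpha, alpha + pi)``.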
      .. admonition:: Side effects

         Sets:

         - ``self.Kb = K``
         - ``self.G = G``
         - ``self.idx_t_multi``: ``(G, 4, K*P)`` LongTensor (neighbors per gauge)
         - ``self.w_t_multi``: ``(G, 4, K*P)`` Tensor (weights per gauge)
         - for backward compatibility when ``G == 1``: ``self.idx_t`` and ``self.w_t``, both of shape ``(4, K*P)``

      :returns:
         * **idx_t_multi** (:py:class:`torch.LongTensor`, shape ``(G, 4, K*P)``)
         * **w_t_multi** (:py:class:`torch.Tensor`, shape ``(G, 4, K*P)``)


   .. py:method:: bind_support_torch_multi(ids_sorted_np, *, device=None, dtype=None)

      Multi-gauge sparse binding (Step B) with "reduced domain" logic:

      - weights of out-of-domain neighbors are set to 0;
      - each column is renormalized to sum to 1;
      - if a column is empty, it falls back to the target pixel (stencil center).

      Produces:

      - ``self.pos_safe_t_multi``: ``(G, 4, K*P)``
      - ``self.w_norm_t_multi``: ``(G, 4, K*P)``
      - ``self.present_t_multi``: ``(G, 4, K*P)``


   .. py:method:: bind_support_torch(ids_sorted_np, *, device=None, dtype=None)

      Single-gauge sparse binding (Step B) with "reduced domain" logic:

      - weights of out-of-domain neighbors are set to 0;
      - each column is renormalized to sum to 1;
      - if a column is empty, it falls back to the target pixel (stencil center).


   .. py:method:: apply_multi(data_sorted_t: torch.Tensor, kernel_t: torch.Tensor)

      Apply a multi-gauge convolution.

      :Parameters:
         * **data_sorted_t** (:py:class:`torch.Tensor`, shape ``(B, Ci, K)``) -- Input data on ``self.device``/``self.dtype``.
         * **kernel_t** (:py:class:`torch.Tensor`) -- Either ``(Ci, Co_g, P)``, a kernel shared by all gauges, or ``(G, Ci, Co_g, P)``, one kernel per gauge.

      :returns: **out** (:py:class:`torch.Tensor`, shape ``(B, G*Co_g, K)``)


   .. py:method:: apply(data_sorted_t, kernel_t)

      Apply the ``(Ci, Co, P)`` kernel to batched sparse data ``(B, Ci, K)`` using the precomputed ``pos_safe`` and ``w_norm``. Runs entirely on ``self.device`` (the GPU when one is available).

      :Parameters:
         * **data_sorted_t** (:py:class:`torch.Tensor`, shape ``(B, Ci, K)``) -- Input data aligned with ``ids_sorted``.
         * **kernel_t** (:py:class:`torch.Tensor`, shape ``(Ci, Co, P)``) -- Convolution kernel.

      :returns: **out** (:py:class:`torch.Tensor`, shape ``(B, Co, K)``)


   .. py:method:: Convol_torch(im, ww, cell_ids=None, nside=None)

      Batched ``KERNELSZ x KERNELSZ`` aggregation (dispatcher).

      Supports:

      - ``im``: Tensor ``(B, Ci, K)``, with

        * ``cell_ids`` None -> use the cached targets (fast path);
        * ``cell_ids`` 1D ``(K,)`` -> one grid shared by the whole batch;
        * ``cell_ids`` 2D ``(B, K)`` -> per-sample grids of the same length; returns ``(B, Co, K)``;
        * ``cell_ids`` a list/tuple -> per-sample grids (variable lengths allowed).

      - ``im``: list/tuple of Tensors, each ``(Ci, K_b)``, with ``cell_ids`` a list/tuple.

      .. admonition:: Notes

         - Accepted kernel shapes:

           * single gauge, or shared across gauges: ``(Ci, Co_g, P)``
           * per-gauge kernels: ``(G, Ci, Co_g, P)``

         - The low-level ``_Convol_Torch`` chooses between ``apply`` and ``apply_multi`` depending on the class state (``G > 1`` and a multi-gauge binding present).


   .. py:method:: make_matrix(kernel: torch.Tensor, cell_ids=None, *, return_sparse_tensor: bool = False, chunk_k: int = 4096)

      Build the sparse COO matrix ``M`` such that applying ``M`` to ``vec(data)`` reproduces the spherical convolution performed by ``Convol_torch``/``_Convol_Torch``.

      Supports single and multiple gauges:

      - kernel shape ``(Ci, Co_g, P)`` -> shared across the ``G`` gauges, output ``Co = G*Co_g``;
      - kernel shape ``(G, Ci, Co_g, P)`` -> per-gauge kernels, same output ``Co = G*Co_g``.

      :Parameters:
         * **kernel** (:py:class:`torch.Tensor`) -- ``(Ci, Co_g, P)`` or ``(G, Ci, Co_g, P)`` with ``P = kernel_sz**2``. Must be on the device/dtype on which you want the resulting matrix.
         * **cell_ids** (*array-like* of shape ``(K,)`` or :py:class:`torch.Tensor`, *optional*) -- Target pixel IDs (NESTED if ``self.nest=True``). If None, the grid already cached in the class is used (fast path). If provided, geometry and sparse binding are prepared for these IDs.
         * **return_sparse_tensor** (:py:class:`bool`, *default* :py:obj:`False`) -- If True, return a coalesced
           ``torch.sparse_coo_tensor`` of shape ``(Co*K, Ci*K)``. Otherwise, return ``(weights, indices, shape)``, where:

           - ``indices`` is a LongTensor of shape ``(2, nnz)`` holding ``[row; col]``;
           - ``weights`` is a Tensor of shape ``(nnz,)``;
           - ``shape`` is the ``(rows, cols)`` tuple.

         * **chunk_k** (:py:class:`int`, *default* ``4096``) -- Chunk size over target pixels, used to limit peak memory.

      :returns:
         * If ``return_sparse_tensor`` is True -- ``M``: coalesced ``torch.sparse_coo_tensor`` of shape ``(Co*K, Ci*K)``.
         * Otherwise -- ``weights``: ``torch.Tensor`` ``(nnz,)``; ``indices``: ``torch.LongTensor`` ``(2, nnz)`` holding ``[row; col]``; ``shape``: ``tuple[int, int]`` equal to ``(Co*K, Ci*K)``.

      .. admonition:: Notes

         - The resulting matrix implements the same interpolation and mixing as the GPU path (gather 4 neighbors -> normalize -> apply the spatial and channel kernel), and matches the output of ``Convol_torch`` for the same ``(kernel, cell_ids)``.
         - For multiple gauges, rows are grouped by gauge: first all ``Co_g`` channels of gauge 0 over all ``K`` targets, then gauge 1, and so on.


   .. py:method:: to_tensor(x)


   .. py:method:: to_numpy(x)
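The following NumPy sketch illustrates, on toy data, how Step B (sparse binding), Step C (``apply``) and the matrix form built by ``make_matrix`` relate to one another. It is a hypothetical re-implementation, not foscat code: the helpers ``bind_support_sketch``, ``apply_sketch`` and ``make_coo_sketch`` are illustrative names, NumPy stands in for Torch, and the sketch assumes that column ``n`` of the ``(4, K*P)`` arrays belongs to target ``k = n // P`` and stencil position ``p = n % P``, that ``vec(data)`` is the row-major flattening of ``(Ci, K)``, that rows of ``M`` are laid out as ``co*K + k``, and that an "empty" column is one whose in-domain weights sum to zero.

.. code-block:: python

   import numpy as np


   def bind_support_sketch(idx, w, ids_sorted, center_ids):
       """Hypothetical Step B: map neighbour ids (4, N) onto positions in ids_sorted.

       Out-of-domain weights are zeroed, each column is renormalized to 1, and a
       column with no in-domain weight falls back to its target pixel (assumed
       to be present in ids_sorted).
       """
       pos = np.clip(np.searchsorted(ids_sorted, idx), 0, len(ids_sorted) - 1)
       present = ids_sorted[pos] == idx                      # neighbour inside the domain?
       w_red = np.where(present, w, 0.0)                     # zero out-of-domain weights
       col_sum = w_red.sum(axis=0)                           # (N,)
       empty = col_sum <= 0.0
       center_pos = np.searchsorted(ids_sorted, center_ids)
       pos_safe = np.where(present, pos, center_pos[None, :])  # always a valid index
       w_norm = w_red / np.where(empty, 1.0, col_sum)[None, :] # columns now sum to 1
       pos_safe[0, empty] = center_pos[empty]                # fallback: stencil centre ...
       w_norm[0, empty] = 1.0                                # ... with weight 1
       return pos_safe, w_norm


   def apply_sketch(data, kernel, pos_safe, w_norm, K, P):
       """Hypothetical Step C: data (B, Ci, K), kernel (Ci, Co, P) -> (B, Co, K)."""
       B, Ci, _ = data.shape
       gathered = data[:, :, pos_safe]                       # (B, Ci, 4, K*P)
       interp = (gathered * w_norm).sum(axis=2)              # interpolate the 4 neighbours
       interp = interp.reshape(B, Ci, K, P)                  # assumes column n = k*P + p
       return np.einsum("bikp,iop->bok", interp, kernel)     # spatial + channel mixing


   def make_coo_sketch(kernel, pos_safe, w_norm, K, P):
       """Hypothetical make_matrix: COO entries of M, shape (Co*K, Ci*K)."""
       Ci, Co, _ = kernel.shape
       n = np.arange(K * P)
       k, p = n // P, n % P                                  # target and stencil position
       rows, cols, vals = [], [], []
       for ci in range(Ci):
           for co in range(Co):
               for j in range(4):                            # one entry per neighbour
                   rows.append(co * K + k)                   # assumed row layout co*K + k
                   cols.append(ci * K + pos_safe[j])         # assumed col layout ci*K + pos
                   vals.append(kernel[ci, co, p] * w_norm[j])
       rows, cols, vals = (np.concatenate(a) for a in (rows, cols, vals))
       return vals, np.stack([rows, cols]), (Co * K, Ci * K)


   # Toy check on random data (kernel_sz = 3, so P = 9).
   rng = np.random.default_rng(0)
   K, P, Ci, Co = 3, 9, 2, 4
   ids_sorted = np.array([2, 5, 7])                          # available samples = targets
   idx = rng.integers(0, 10, size=(4, K * P))                # made-up neighbour ids
   w = rng.random((4, K * P))
   w /= w.sum(axis=0)                                        # made-up interpolation weights
   pos_safe, w_norm = bind_support_sketch(idx, w, ids_sorted, np.repeat(ids_sorted, P))

   data = rng.standard_normal((1, Ci, K))
   kernel = rng.standard_normal((Ci, Co, P))
   out = apply_sketch(data, kernel, pos_safe, w_norm, K, P)  # (1, Co, K)

   vals, ij, shape = make_coo_sketch(kernel, pos_safe, w_norm, K, P)
   M = np.zeros(shape)
   np.add.at(M, (ij[0], ij[1]), vals)                        # duplicates summed, like coalesce()
   assert np.allclose(M @ data[0].reshape(-1), out[0].reshape(-1))

Summing duplicate ``(row, col)`` pairs with ``np.add.at`` mirrors what coalescing a ``torch.sparse_coo_tensor`` does. The final assertion checks the property stated in the ``make_matrix`` notes: under these layout assumptions, ``M @ vec(data)`` reproduces the direct gather, normalize and mix path.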