foscat.BkTorch
==============

.. py:module:: foscat.BkTorch


Classes
-------

.. autoapisummary::

   foscat.BkTorch.BkTorch


Module Contents
---------------

.. py:class:: BkTorch(*args, **kwargs)

   Bases: :py:obj:`foscat.BkBase.BackendBase`


   .. py:attribute:: backend


   .. py:attribute:: device


   .. py:attribute:: float64
      :value: Ellipsis


   .. py:attribute:: float32
      :value: Ellipsis


   .. py:attribute:: int64
      :value: Ellipsis


   .. py:attribute:: int32
      :value: Ellipsis


   .. py:attribute:: complex64
      :value: Ellipsis


   .. py:attribute:: complex128
      :value: Ellipsis


   .. py:attribute:: gpulist


   .. py:attribute:: ngpu
      :value: 1


   .. py:attribute:: torch_device


   .. py:method:: downsample_mean_2x2(tim: torch.Tensor) -> torch.Tensor

      Average-pool the tensor ``tim`` over non-overlapping 2x2 spatial blocks.

      :Parameters: **tim** (:py:class:`torch.Tensor`) -- Tensor of shape [a, N1, N2, b].

      :returns: :py:class:`torch.Tensor` -- Downsampled tensor of shape [a, N1//2, N2//2, b]; each element is the mean of one 2x2 block.


   .. py:method:: downsample_median_2x2(tim: torch.Tensor) -> torch.Tensor

      Downsample by taking the median of each 2x2 block on the spatial axes (N1, N2).

      :Parameters: **tim** (:py:class:`torch.Tensor`) -- Tensor of shape [a, N1, N2, b], real or complex.

      :returns: :py:class:`torch.Tensor` -- Tensor of shape [a, N1//2, N2//2, b]; each value is the median over the corresponding 2x2 block. For complex inputs, the four values are sorted by magnitude ``|.|`` and the complex sample at the lower median rank is returned.


   .. py:method:: downsample_mean_1d(tim: torch.Tensor) -> torch.Tensor

      Downsample the tensor ``tim`` of shape [a, N1] by averaging non-overlapping 2-element blocks.

      Output shape: [a, N1//2].


   .. py:method:: downsample_median_1d(tim: torch.Tensor) -> torch.Tensor

      Downsample the tensor ``tim`` of shape [a, N1] by taking the median of non-overlapping pairs of values.

      Output shape: [a, N1//2].

      - For real inputs: the median of the two values.
      - For complex inputs: the complex value with the smaller magnitude ``|.|`` of the two.
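
   The block reductions documented for ``downsample_mean_2x2`` and ``downsample_median_2x2`` can be illustrated with plain PyTorch tensor operations. The following is a minimal sketch, not the BkTorch source. The ``*_sketch`` function names are hypothetical, and odd trailing rows and columns are cropped, which is an assumption. The median returns the lower median (rank 1 of the 4 sorted samples) for real inputs as well as complex ones, whereas the documentation states the rank rule only for complex inputs.

   .. code-block:: python

      import torch


      def downsample_mean_2x2_sketch(tim: torch.Tensor) -> torch.Tensor:
          """Mean over non-overlapping 2x2 blocks of a [a, N1, N2, b] tensor (hypothetical helper)."""
          a, n1, n2, b = tim.shape
          h, w = n1 // 2, n2 // 2
          # Crop an odd trailing row/column, then split each spatial axis into (blocks, 2).
          blocks = tim[:, : 2 * h, : 2 * w, :].reshape(a, h, 2, w, 2, b)
          return blocks.mean(dim=(2, 4))


      def downsample_median_2x2_sketch(tim: torch.Tensor) -> torch.Tensor:
          """Lower median over non-overlapping 2x2 blocks of a [a, N1, N2, b] tensor (hypothetical helper)."""
          a, n1, n2, b = tim.shape
          h, w = n1 // 2, n2 // 2
          blocks = tim[:, : 2 * h, : 2 * w, :].reshape(a, h, 2, w, 2, b)
          # Move the 4 samples of each block to a trailing axis: [a, h, w, b, 4].
          samples = blocks.permute(0, 1, 3, 5, 2, 4).reshape(a, h, w, b, 4)
          # Order by value for real data, by magnitude |.| for complex data.
          key = samples.abs() if samples.is_complex() else samples
          order = torch.argsort(key, dim=-1)
          # Rank 1 (0-based) of the 4 sorted samples is the lower median.
          return torch.gather(samples, -1, order[..., 1:2]).squeeze(-1)

   Splitting each spatial axis into ``(blocks, 2)`` with ``reshape`` avoids Python loops, and sorting on ``abs()`` supplies the magnitude ordering needed for complex samples.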

   .. py:method:: binned_mean_old(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))

      Reduce values over parent HEALPix pixels (nested scheme) when downgrading from nside to nside/2.

      :Parameters: * **data** (:py:class:`torch.Tensor | np.ndarray`) -- Shape [..., N] or [B, ..., N].
                   * **cell_ids** (:py:class:`torch.LongTensor | np.ndarray`) -- Shape [N] or [B, N] (nested indexing at the child resolution).
                   * **reduce** (``{"mean","max"}``, *default* ``"mean"``) -- Aggregation applied within each parent group of 4 children.
                   * **padded** (:py:class:`bool`, *default* :py:obj:`False`) -- Only used when ``cell_ids`` is [B, N]. If False, returns ragged Python lists; if True, returns padded tensors plus a mask.
                   * **fill_value** (:py:class:`float`, *default* ``NaN``) -- Padding value when ``padded=True``.

      :returns: As in the existing documentation, except that each value is a mean (``reduce="mean"``) or a maximum (``reduce="max"``).


   .. py:method:: binned_mean(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))

      Reduce values over parent HEALPix pixels (nested scheme) when downgrading from nside to nside/2.

      :Parameters: * **data** (:py:class:`torch.Tensor | np.ndarray`) -- Shape [..., N] or [B, ..., N].
                   * **cell_ids** (:py:class:`torch.LongTensor | np.ndarray`) -- Shape [N] or [B, N] (nested indexing at the child resolution).
                   * **reduce** (``{"mean","max","median"}``, *default* ``"mean"``) -- Aggregation within each parent group of 4 children.
                   * **padded** (:py:class:`bool`, *default* :py:obj:`False`) -- Only used when ``cell_ids`` is [B, N]. If False, returns ragged Python lists; if True, returns padded tensors plus a mask.
                   * **fill_value** (:py:class:`float`, *default* ``NaN``) -- Padding value when ``padded=True``.

      :returns: As in the original implementation, with the aggregation selected by ``reduce``. When ``cell_ids`` is [B, N], the result is a set of ragged Python lists, or padded tensors plus a mask if ``padded=True``.


   .. py:method:: average_by_cell_group(cell_ids)

      Average data over groups of cells that share the same parent index ``cell_ids // 4``.

      - ``data``: tensor of shape [..., N, ...] (e.g. [B, N, C]).
      - ``cell_ids``: tensor of shape [N].

      :returns: ``mean_data`` of shape [..., G, ...], where G is the number of unique values of ``cell_ids // 4``.
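
   In the nested HEALPix scheme, the four children of parent pixel ``p`` at the finer resolution are ``4p`` to ``4p + 3``, so grouping by ``cell_ids // 4`` is the reduction that ``binned_mean`` and ``average_by_cell_group`` describe. A minimal sketch of the ``"mean"`` and ``"max"`` reductions for one-dimensional ``cell_ids`` follows. The function name and its ``(values, parent_ids)`` return value are hypothetical, and the ``"median"`` option, batched [B, N] inputs and padding are not covered.

   .. code-block:: python

      import torch


      def binned_reduce_sketch(data: torch.Tensor, cell_ids: torch.Tensor, reduce: str = "mean"):
          """Reduce data [..., N] over nested-HEALPix parents; cell_ids is [N] (hypothetical helper)."""
          # Nested scheme: children 4p, 4p+1, 4p+2, 4p+3 share parent p.
          parent = cell_ids // 4
          # Sorted unique parents, and for every child the position of its parent.
          parent_ids, inverse = torch.unique(parent, return_inverse=True)
          g = parent_ids.numel()
          last = data.dim() - 1
          out_shape = tuple(data.shape[:-1]) + (g,)
          if reduce == "mean":
              sums = torch.zeros(out_shape, dtype=data.dtype, device=data.device)
              sums.index_add_(last, inverse, data)
              counts = torch.bincount(inverse, minlength=g).to(data.dtype)
              return sums / counts, parent_ids
          if reduce == "max":
              out = torch.full(out_shape, float("-inf"), dtype=data.dtype, device=data.device)
              # Broadcast the group index over the leading axes of data.
              index = inverse.expand(data.shape)
              out.scatter_reduce_(last, index, data, reduce="amax")
              return out, parent_ids
          raise ValueError("reduce must be 'mean' or 'max'")

   ``torch.unique(..., return_inverse=True)`` maps each child to its parent group, after which ``index_add_`` and ``scatter_reduce_`` compute the grouped sum and maximum without Python loops (``scatter_reduce_`` needs a recent PyTorch release).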

   .. py:method:: bk_masked_median(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)

      Masked geometric median over the last axis using Weiszfeld iteration (1D case).

      :Parameters: * **x** (:py:class:`torch.Tensor`) -- Shape [a, b, c, N]. Can be real or complex.
                   * **mask** (:py:class:`torch.Tensor`) -- Binary mask of shape [a, b, 1, N]; broadcast over ``c``.
                   * **max_iter** (:py:class:`int`) -- Maximum number of Weiszfeld iterations.
                   * **tol** (:py:class:`float`) -- Convergence tolerance on the maximum absolute update per voxel.
                   * **eps** (:py:class:`float`) -- Small value that avoids division by zero in the weights.

      :returns: * **med** (:py:class:`torch.Tensor`, shape [a, b, c]) -- Geometric median of x along the last axis where mask == 1. For complex x, distances use the complex magnitude ``|x - y|`` and the returned median is complex.
                * **med2** (:py:class:`torch.Tensor`, shape [a, b, c]) -- Geometric median of the squared values along the last axis where mask == 1: the median of ``x**2`` (real) if x is real, or of ``|x|**2`` (real) if x is complex.


   .. py:method:: bk_masked_median_2d_weiszfeld(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)

      Masked geometric median over the 2D spatial axes using Weiszfeld iteration.

      :Parameters: * **x** (:py:class:`torch.Tensor`) -- Input of shape [a, b, c, N1, N2]. Can be real or complex.
                   * **mask** (:py:class:`torch.Tensor`) -- Binary mask of shape [a, b, 1, N1, N2]; broadcast over ``c``.
                   * **max_iter** (:py:class:`int`) -- Maximum number of Weiszfeld iterations.
                   * **tol** (:py:class:`float`) -- Stopping tolerance on the maximum absolute update per voxel.
                   * **eps** (:py:class:`float`) -- Small positive value that avoids division by zero in the weights.

      :returns: * **med** (:py:class:`torch.Tensor`, shape [a, b, c]) -- Geometric median of x over (N1, N2) where mask == 1. If x is complex, distances are magnitudes ``|x - y|`` in the complex plane, and the returned value is the complex estimate (not its magnitude).
                * **med2** (:py:class:`torch.Tensor`, shape [a, b, c]) -- Geometric median of the squared values over (N1, N2) where mask == 1: the median of ``x**2`` (via Weiszfeld in 1D) if x is real, or of ``|x|**2`` (real, non-negative) if x is complex.

      .. admonition:: Notes

         - Voxels with no valid samples return NaN (NaN+NaNj for a complex ``med``).
         - Weiszfeld update: :math:`y_{k+1} = \frac{\sum_i w_i x_i}{\sum_i w_i}` with :math:`w_i = 1 / \lVert x_i - y_k \rVert`, where the sums run over the samples with mask == 1 and :math:`\lVert \cdot \rVert` is the absolute value for real numbers and the complex magnitude for complex numbers.


   .. py:method:: bk_len(S)


   .. py:method:: bk_SparseTensor(indice, w, dense_shape=[])


   .. py:method:: bk_stack(list, axis=0)


   .. py:method:: bk_sparse_dense_matmul(smat, mat)


   .. py:method:: conv2d(x, w)

      Perform a 2D convolution using the PyTorch tensor layout.

      :Parameters: * **x** -- Input tensor of shape [..., Nx, Ny].
                   * **w** -- Convolution weights of shape [O_c, wx, wy].

      :returns: Tensor of shape [..., O_c, Nx, Ny].


   .. py:method:: conv1d(x, w, strides=[1, 1, 1], padding='SAME')

      Perform a 1D convolution along the last axis of a 2D tensor x[n, m] with a kernel w[K].

      :Parameters: * **x** (:py:class:`torch.Tensor`) -- Input of shape [n, m].
                   * **w** (:py:class:`torch.Tensor`) -- Kernel of shape [K].
                   * **strides** (:py:class:`list`) -- List of 3 ints; only ``strides[1]`` (along axis -1) is used.
                   * **padding** (``{"SAME","VALID"}``) -- Padding mode.

      :returns: :py:class:`torch.Tensor` -- Shape [n, m] with ``"SAME"`` padding, or shorter along the last axis with ``"VALID"``.


   .. py:method:: bk_threshold(x, threshold, greater=True)


   .. py:method:: bk_maximum(x1, x2)


   .. py:method:: bk_device(device_name)


   .. py:method:: bk_ones(shape, dtype=None)


   .. py:method:: bk_conv1d(x, w)


   .. py:method:: bk_flattenR(x)


   .. py:method:: bk_flatten(x)


   .. py:method:: bk_resize_image(x, shape)


   .. py:method:: bk_L1(x)


   .. py:method:: bk_square_comp(x)


   .. py:method:: bk_reduce_sum(data, axis=None)


   .. py:method:: bk_size(data)


   .. py:method:: constant(data)


   .. py:method:: bk_reduce_mean(data, axis=None)
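
   The masked Weiszfeld iteration documented for ``bk_masked_median`` (and, over two axes, for ``bk_masked_median_2d_weiszfeld``) can be sketched for the 1D case as follows. This is an illustration under stated assumptions, not the BkTorch code. The function names are hypothetical, the iteration starts from the masked mean (the documented starting point is not given), and ``eps`` clamps the distances from below rather than being added to them.

   .. code-block:: python

      import torch


      def masked_weiszfeld_1d(x, mask, max_iter=100, tol=1e-6, eps=1e-12):
          """Masked geometric median of x [a, b, c, N] over its last axis (hypothetical helper).

          mask has shape [a, b, 1, N] (1 = valid sample) and is broadcast over c.
          """
          real_dtype = x.abs().dtype  # float dtype of x, or of its real part if x is complex
          m = mask.to(real_dtype).expand(x.shape)
          count = m.sum(dim=-1)
          # Assumed starting point: the masked mean of the valid samples.
          y = (m * x).sum(dim=-1) / count.clamp(min=1)
          for _ in range(max_iter):
              # w_i = 1 / |x_i - y_k| on valid samples, 0 on masked ones; eps avoids 1/0.
              dist = (x - y.unsqueeze(-1)).abs()
              w = m / dist.clamp(min=eps)
              y_new = (w * x).sum(dim=-1) / w.sum(dim=-1).clamp(min=eps)
              step = (y_new - y).abs().max()
              y = y_new
              if step < tol:
                  break
          # Voxels without valid samples become NaN (NaN+NaNj for complex x).
          return torch.where(count > 0, y, y * float("nan"))


      def masked_median_pair(x, mask, **kwargs):
          """Return (med, med2); med2 is the median of x**2 (real x) or |x|**2 (complex x)."""
          squared = x.abs() ** 2
          return masked_weiszfeld_1d(x, mask, **kwargs), masked_weiszfeld_1d(squared, mask, **kwargs)

   Samples close to the current estimate receive large weights, so outliers have little influence on the result, while masked samples get weight 0 and do not contribute at all.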

   .. py:method:: bk_reduce_median(data, axis=None)


   .. py:method:: bk_reduce_min(data, axis=None)


   .. py:method:: bk_random_seed(value)


   .. py:method:: bk_random_uniform(shape)


   .. py:method:: bk_reduce_std(data, axis=None)


   .. py:method:: bk_sqrt(data)


   .. py:method:: bk_abs(data)


   .. py:method:: bk_is_complex(data)


   .. py:method:: bk_distcomp(data)


   .. py:method:: bk_norm(data)


   .. py:method:: bk_square(data)


   .. py:method:: bk_log(data)


   .. py:method:: bk_matmul(a, b)


   .. py:method:: bk_tensor(data)


   .. py:method:: bk_shape_tensor(shape)


   .. py:method:: bk_complex(real, imag)


   .. py:method:: bk_exp(data)


   .. py:method:: bk_min(data)


   .. py:method:: bk_argmin(data)


   .. py:method:: bk_tanh(data)


   .. py:method:: bk_max(data)


   .. py:method:: bk_argmax(data)


   .. py:method:: bk_reshape(data, shape)


   .. py:method:: bk_repeat(data, nn, axis=0)


   .. py:method:: bk_tile(data, nn, axis=0)


   .. py:method:: bk_roll(data, nn, axis=0)


   .. py:method:: bk_expand_dims(data, axis=0)


   .. py:method:: bk_transpose(data, thelist)


   .. py:method:: bk_concat(data, axis=None)


   .. py:method:: bk_zeros(shape, dtype=None)


   .. py:method:: bk_gather(data, idx, axis=0)


   .. py:method:: bk_reverse(data, axis=0)


   .. py:method:: bk_fft(data)


   .. py:method:: bk_fftn(data, dim=None)


   .. py:method:: bk_ifftn(data, dim=None, norm=None)


   .. py:method:: bk_rfft(data)


   .. py:method:: bk_irfft(data)


   .. py:method:: bk_conjugate(data)


   .. py:method:: bk_real(data)


   .. py:method:: bk_imag(data)


   .. py:method:: bk_relu(x)


   .. py:method:: bk_clip_by_value(x, xmin, xmax)


   .. py:method:: bk_cast(x)


   .. py:method:: bk_variable(x)


   .. py:method:: bk_assign(x, y)


   .. py:method:: bk_constant(x)


   .. py:method:: bk_cos(x)


   .. py:method:: bk_sin(x)


   .. py:method:: bk_arctan2(c, s)


   .. py:method:: bk_empty(list)


   .. py:method:: to_numpy(x)
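
   The ``conv1d`` contract documented earlier in this class can be sketched with ``torch.nn.functional.conv1d``. That contract applies the kernel along the last axis, takes the stride from ``strides[1]``, and supports ``"SAME"`` or ``"VALID"`` padding. The function name is hypothetical, and two conventions are assumptions because the documentation does not state them. The first is TensorFlow-style ``"SAME"`` padding, with any odd extra zero placed on the right. The second is no kernel flip, that is, cross-correlation as in ``torch.nn.functional.conv1d``.

   .. code-block:: python

      import torch
      import torch.nn.functional as F


      def conv1d_sketch(x: torch.Tensor, w: torch.Tensor, strides=(1, 1, 1), padding="SAME") -> torch.Tensor:
          """Correlate each row of x [n, m] with the kernel w [K] along the last axis (hypothetical helper)."""
          stride = strides[1]  # only the stride along axis -1 is used
          m = x.shape[-1]
          k = w.shape[0]
          inp = x.unsqueeze(1)  # [n, 1, m]: one input channel per row
          kernel = w.to(x.dtype).reshape(1, 1, k)  # [out_channels=1, in_channels=1, K]
          if padding == "SAME":
              # TensorFlow-style SAME: output length is ceil(m / stride).
              out_len = -(-m // stride)
              pad_total = max((out_len - 1) * stride + k - m, 0)
              left = pad_total // 2
              inp = F.pad(inp, (left, pad_total - left))
          elif padding != "VALID":
              raise ValueError("padding must be 'SAME' or 'VALID'")
          return F.conv1d(inp, kernel, stride=stride).squeeze(1)

   With a symmetric kernel the flip convention does not matter. With an asymmetric kernel, reverse ``w`` if a true (flipped) convolution is required.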