foscat.BkTorch

Classes

Module Contents

class foscat.BkTorch.BkTorch(*args, **kwargs)

Bases: foscat.BkBase.BackendBase

backend
device
float64
float32
int64
int32
complex64
complex128
gpulist
ngpu = 1
torch_device
downsample_mean_2x2(tim: torch.Tensor) → torch.Tensor

Average-pool tensor tim over non-overlapping 2x2 blocks of the spatial axes (N1, N2).

Parameters:

tim (torch.Tensor) – Tensor of shape [a, N1, N2, b].

Returns:

torch.Tensor – Downsampled tensor of shape [a, N1//2, N2//2, b], each element being the mean of a 2x2 block.
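To check the shapes involved, the following NumPy sketch reproduces the documented 2x2 mean pooling on an array laid out as [a, N1, N2, b]. It is an illustration, not the foscat code. The function name is hypothetical. Cropping an odd N1 or N2 to an even size is an assumption implied by the N1//2, N2//2 output shape.

```python
import numpy as np

def downsample_mean_2x2_sketch(tim):
    """Hypothetical NumPy sketch: mean over non-overlapping 2x2 spatial blocks.

    tim has shape [a, N1, N2, b]; the result has shape [a, N1//2, N2//2, b].
    """
    tim = np.asarray(tim)
    a, n1, n2, b = tim.shape
    # Assumption: drop a trailing row/column when N1 or N2 is odd.
    tim = tim[:, : 2 * (n1 // 2), : 2 * (n2 // 2), :]
    # Split each spatial axis into (block index, position inside the block).
    blocks = tim.reshape(a, n1 // 2, 2, n2 // 2, 2, b)
    # Average over the two within-block axes.
    return blocks.mean(axis=(2, 4))

x = np.arange(16, dtype=float).reshape(1, 4, 4, 1)
print(downsample_mean_2x2_sketch(x)[0, :, :, 0])  # → [[2.5, 4.5], [10.5, 12.5]]
```

The reshape works because, in row-major order, spatial index 2*i + p maps to (block i, offset p). No data is copied before the mean.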

downsample_median_2x2(tim: torch.Tensor) → torch.Tensor

2x2 block median downsampling on the spatial axes (N1, N2).

Input:

tim: [a, N1, N2, b] (real or complex)

Output:

out: [a, N1//2, N2//2, b], where each value is the median over the corresponding 2x2 block.

  • For complex inputs: the median is taken by sorting the 4 values by |.| and returning the complex sample at the lower median rank.
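The following NumPy sketch illustrates the documented complex rule. The four samples of each block are gathered on a trailing axis and ordered by magnitude. The sample at rank 1 (the lower of the two middle ranks, counting from 0) is kept. The function name is hypothetical and stable tie-breaking is an assumption. The real-valued branch is not reproduced here.

```python
import numpy as np

def downsample_median_2x2_complex_sketch(tim):
    """Hypothetical sketch of the documented complex rule: sort each 2x2
    block by |.| and keep the sample at the lower median rank."""
    tim = np.asarray(tim)
    a, n1, n2, b = tim.shape
    tim = tim[:, : 2 * (n1 // 2), : 2 * (n2 // 2), :]
    blocks = tim.reshape(a, n1 // 2, 2, n2 // 2, 2, b)
    # Gather the 4 samples of each block on one trailing axis: [a, N1//2, N2//2, b, 4].
    blocks = blocks.transpose(0, 1, 3, 5, 2, 4).reshape(a, n1 // 2, n2 // 2, b, 4)
    # Assumption: stable sort, so equal magnitudes keep their block order.
    order = np.argsort(np.abs(blocks), axis=-1, kind="stable")
    # With 4 sorted samples, the lower median sits at rank 1 (0-based).
    idx = order[..., 1:2]
    return np.take_along_axis(blocks, idx, axis=-1)[..., 0]
```

For a block holding 1, -3j, 2 and 0.5, the magnitudes sort as 0.5, 1, 2, 3, so the sketch returns the sample 1.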

downsample_mean_1d(tim: torch.Tensor) → torch.Tensor

Downsample tensor tim of shape [a, N1] by averaging non-overlapping pairs of elements. Output shape: [a, N1//2].

downsample_median_1d(tim: torch.Tensor) → torch.Tensor

Downsample tensor tim of shape [a, N1] by taking the median of each non-overlapping pair of values. Output shape: [a, N1//2].

  • For real inputs: median of the two values.

  • For complex inputs: pick the complex value with the smaller |.| of the two.
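The pair reductions can be sketched in NumPy by viewing tim as [a, N1//2, 2]. This is illustrative only, and the function name is hypothetical. The sketch reproduces only the documented complex rule: keep the smaller-magnitude value, with ties going to the first element. The documentation does not say which convention the real-valued median of two numbers uses, so that case is not reproduced.

```python
import numpy as np

def downsample_pairs_sketch(tim):
    """Hypothetical NumPy sketch of the 1D pair reductions on tim of shape [a, N1].

    Returns (pair means, smallest-|.| pick), both of shape [a, N1//2].
    """
    tim = np.asarray(tim)
    a, n1 = tim.shape
    # Assumption: a trailing odd element is dropped, matching the N1//2 output.
    pairs = tim[:, : 2 * (n1 // 2)].reshape(a, n1 // 2, 2)   # [a, N1//2, 2]
    mean = pairs.mean(axis=-1)
    # Index (0 or 1) of the smaller-magnitude value; argmin returns the first on ties.
    pick = np.argmin(np.abs(pairs), axis=-1)
    smallest = np.take_along_axis(pairs, pick[..., None], axis=-1)[..., 0]
    return mean, smallest
```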

binned_mean_old(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))

Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.

Parameters:
  • data (torch.Tensor | np.ndarray) – Shape […, N] or [B, …, N].

  • cell_ids (torch.LongTensor | np.ndarray) – Shape [N] or [B, N] (nested indexing at the child resolution).

  • reduce ({"mean","max"}, default "mean") – Aggregation to apply within each parent group of 4 children.

  • padded (bool, default False) – Only used when cell_ids has shape [B, N]. If False, returns ragged Python lists. If True, returns padded tensors + mask.

  • fill_value (float, default NaN) – Padding value when padded=True.

Returns:

As in the original function's documentation, except that each value is a mean (reduce="mean") or a maximum (reduce="max").

binned_mean(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))

Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.

Parameters:
  • data (torch.Tensor | np.ndarray) – Shape […, N] or [B, …, N].

  • cell_ids (torch.LongTensor | np.ndarray) – Shape [N] or [B, N] (nested indexing at the child resolution).

  • reduce ({"mean","max","median"}, default "mean") – Aggregation within each parent group of 4 children.

  • padded (bool, default False) – Only used when cell_ids has shape [B, N]. If False, returns ragged Python lists. If True, returns padded tensors + mask.

  • fill_value (float, default NaN) – Padding value when padded=True.

Returns:

As in the original function, with the aggregation selected by `reduce`.
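The parent lookup behind binned_mean rests on a nested-HEALPix property: child pixel p at nside has parent p // 4 at nside/2. The NumPy sketch below applies it in the simplest case, with cell_ids of shape [N] and reduce set to "mean" or "max". The function name and its (values, parent_ids) return pair are hypothetical. The padded/ragged batch modes and reduce="median" are not reproduced.

```python
import numpy as np

def binned_reduce_sketch(data, cell_ids, reduce="mean"):
    """Hypothetical NumPy sketch: aggregate nested HEALPix children into parents.

    data: array of shape [..., N]; cell_ids: int array of shape [N] at the child nside.
    Returns (values of shape [..., G], parent_ids of shape [G]).
    """
    data = np.asarray(data, dtype=float)
    cell_ids = np.asarray(cell_ids)
    # Nested scheme: the parent of child pixel p (at nside/2) is p // 4.
    parent_ids, inv = np.unique(cell_ids // 4, return_inverse=True)
    inv = inv.reshape(-1)
    g = parent_ids.size
    moved = np.moveaxis(data, -1, 0)                  # [N, ...]
    if reduce == "mean":
        acc = np.zeros((g,) + moved.shape[1:])
        np.add.at(acc, inv, moved)                    # per-parent sums
        counts = np.bincount(inv, minlength=g)
        acc = acc / counts.reshape((g,) + (1,) * (acc.ndim - 1))
    elif reduce == "max":
        acc = np.full((g,) + moved.shape[1:], -np.inf)
        np.maximum.at(acc, inv, moved)                # per-parent maxima
    else:
        raise ValueError("this sketch only handles 'mean' and 'max'")
    return np.moveaxis(acc, 0, -1), parent_ids
```

Parents with fewer than four children present, such as cell_ids [0, 1, 5], are averaged over the children that exist. The same p // 4 grouping gives G in average_by_cell_group.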

average_by_cell_group(cell_ids)

Parameters:
  • data – Tensor of shape […, N, …] (e.g. [B, N, C]).

  • cell_ids – Tensor of shape [N].

Returns:

mean_data of shape […, G, …], where G is the number of unique values of cell_ids//4.

bk_masked_median(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)

Masked geometric median over the last axis using Weiszfeld iteration (1D case).

Parameters:
  • x (torch.Tensor) – Shape [a, b, c, N]. Can be real or complex.

  • mask (torch.Tensor) – Binary mask of shape [a, b, 1, N]; broadcast across ‘c’.

  • max_iter (int) – Max number of Weiszfeld iterations.

  • tol (float) – Convergence tolerance on the max absolute update per voxel.

  • eps (float) – Small value to avoid division-by-zero in the weights.

Returns:

  • med (torch.Tensor, shape [a, b, c]) – Geometric median of x along the last axis where mask == 1.

      - For complex x: distances use the complex magnitude |x - y|, and the returned median is complex.

  • med2 (torch.Tensor, shape [a, b, c]) – Geometric median of squared values along the last axis where mask == 1.

      - If x is real: median of x**2 (real).

      - If x is complex: median of |x|**2 (real).
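The following NumPy sketch shows a masked Weiszfeld median along the last axis. It is not the foscat implementation, and the function name is hypothetical. Two choices are assumptions consistent with the parameter descriptions above: the iteration starts from the masked mean, and distances are clamped with eps.

```python
import numpy as np

def masked_weiszfeld_1d(x, mask, max_iter=100, tol=1e-6, eps=1e-12):
    """Hypothetical NumPy sketch of a masked Weiszfeld median over the last axis.

    x: [..., N] real or complex; mask: 0/1 array broadcastable to x (e.g. [a, b, 1, N]).
    Returns an array of shape x.shape[:-1]; NaN where no sample is valid.
    """
    x = np.asarray(x)
    m = np.broadcast_to(np.asarray(mask, dtype=float), x.shape)
    with np.errstate(all="ignore"):
        # Assumption: start from the masked mean (0/0 gives NaN for fully masked voxels).
        y = (m * x).sum(axis=-1) / m.sum(axis=-1)
        for _ in range(max_iter):
            # Weights m_i / max(|x_i - y_k|, eps); masked-out samples get weight 0.
            dist = np.abs(x - np.expand_dims(y, -1))
            w = m / np.maximum(dist, eps)
            y_new = (w * x).sum(axis=-1) / w.sum(axis=-1)
            # Stop once the largest per-voxel update is <= tol (NaN voxels never block this).
            done = not np.any(np.abs(y_new - y) > tol)
            y = y_new
            if done:
                break
    return y
```

Under the same assumptions, a med2-style value would be masked_weiszfeld_1d(np.abs(x) ** 2, mask). This covers both the real case (x**2) and the complex case (|x|**2).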

bk_masked_median_2d_weiszfeld(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)

Masked geometric median over 2D spatial axes using Weiszfeld iteration.

Parameters:
  • x (torch.Tensor) – Input of shape [a, b, c, N1, N2]. Can be real or complex.

  • mask (torch.Tensor) – Binary mask of shape [a, b, 1, N1, N2]; broadcast over ‘c’.

  • max_iter (int) – Maximum number of Weiszfeld iterations.

  • tol (float) – Stopping tolerance on the max absolute update per voxel.

  • eps (float) – Small positive value to avoid division by zero in weights.

Returns:

  • med (torch.Tensor, shape [a, b, c]) – Geometric median of x over (N1, N2) where mask == 1.

      - If x is complex, distances are magnitudes |x - y| in the complex plane, and the returned value is the complex sample estimate (not its magnitude).

  • med2 (torch.Tensor, shape [a, b, c]) – Geometric median of squared values over (N1, N2) where mask == 1.

      - If x is real: median of x**2 (via Weiszfeld in 1D).

      - If x is complex: median of |x|**2 (real, non-negative).

Notes

  • Voxels with zero valid samples return NaN (NaN+NaNj for complex med).

  • Weiszfeld update: y_{k+1} = sum_i w_i x_i / sum_i w_i with w_i = 1 / ||x_i - y_k||. Here ||.|| is |.| for real numbers and the complex magnitude for complex numbers.
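The update in the Notes can be read as a fixed-point iteration. Write m_i for the mask value of sample i. The geometric median minimizes the first objective below. Setting its gradient with respect to y to zero and solving for y gives the Weiszfeld step. Placing eps inside a max() clamp is an assumption about how the eps parameter enters.

```latex
\begin{aligned}
\hat{y} &= \arg\min_{y} \sum_i m_i \,\lvert x_i - y \rvert, \\
0 &= \sum_i m_i \,\frac{x_i - y}{\lvert x_i - y \rvert}
\;\Longrightarrow\;
y = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad w_i = \frac{m_i}{\lvert x_i - y \rvert}, \\
y_{k+1} &= \frac{\sum_i w_i^{(k)} x_i}{\sum_i w_i^{(k)}}, \qquad
w_i^{(k)} = \frac{m_i}{\max\!\left(\lvert x_i - y_k \rvert,\ \varepsilon\right)}.
\end{aligned}
```

For the 2D variant, the sum over i runs over all (N1, N2) positions, so the same update applies after flattening the spatial axes.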

bk_len(S)
bk_SparseTensor(indice, w, dense_shape=[])
bk_stack(list, axis=0)
bk_sparse_dense_matmul(smat, mat)
conv2d(x, w)

Perform 2D convolution using PyTorch format.

Args:
  • x – Input tensor of shape […, Nx, Ny].

  • w – Convolution weights of shape [O_c, wx, wy].

Returns:

Tensor of shape […, O_c, Nx, Ny]
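The following NumPy sketch reproduces the documented shapes under three labeled assumptions. First, the operation is cross-correlation (the PyTorch conv2d convention, with no kernel flip). Second, borders are zero-padded so the output keeps Nx and Ny. Third, kernel sizes are odd. The function name is hypothetical. Each of the O_c kernels is applied to the single input plane, matching the [O_c, wx, wy] weight shape.

```python
import numpy as np

def conv2d_same_sketch(x, w):
    """Hypothetical NumPy sketch: x [..., Nx, Ny], w [O_c, wx, wy] -> [..., O_c, Nx, Ny]."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    n_out, kx, ky = w.shape
    # Assumption: zero "same" padding (exactly centred for odd kernel sizes).
    px, py = kx // 2, ky // 2
    pad = [(0, 0)] * (x.ndim - 2) + [(px, kx - 1 - px), (py, ky - 1 - py)]
    xp = np.pad(x, pad)
    nx, ny = x.shape[-2:]
    out = np.zeros(x.shape[:-2] + (n_out, nx, ny))
    # Cross-correlation: out[..., o, m, n] = sum_ij w[o, i, j] * xp[..., m + i, n + j].
    for i in range(kx):
        for j in range(ky):
            patch = xp[..., i:i + nx, j:j + ny]                  # [..., Nx, Ny]
            out += w[:, i, j][:, None, None] * patch[..., None, :, :]
    return out
```

A kernel with a single 1 at its centre returns the input unchanged. A 1 one step to the right of centre shifts the image left by one column.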

conv1d(x, w, strides=[1, 1, 1], padding='SAME')

Performs 1D convolution along the last axis of a 2D tensor x[n, m] with kernel w[K].

Parameters:
  • x – torch.Tensor of shape [n, m].

  • w – torch.Tensor of shape [K].

  • strides – List of 3 ints; only strides[1] (along axis -1) is used.

  • padding – "SAME" or "VALID".

Returns:

torch.Tensor of shape [n, m] (if SAME) or smaller (if VALID).
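The following NumPy sketch covers the 1D case under stated assumptions. The operation is cross-correlation, with no kernel flip. "SAME" zero padding puts (K-1)//2 zeros on the left and the remainder on the right. The stride is applied by keeping every stride-th output. The function name is hypothetical, and a plain `stride` integer stands in for strides[1].

```python
import numpy as np

def conv1d_sketch(x, w, stride=1, padding="SAME"):
    """Hypothetical NumPy sketch: 1D cross-correlation of x[n, m] with w[K] along axis -1."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    k = w.size
    if padding == "SAME":
        # Assumption: zero-pad so the unstrided output keeps length m (extra zero on the right).
        left = (k - 1) // 2
        x = np.pad(x, [(0, 0), (left, k - 1 - left)])
    elif padding != "VALID":
        raise ValueError("padding must be 'SAME' or 'VALID'")
    length = x.shape[-1] - k + 1
    # out[:, t] = sum_j w[j] * x[:, t + j]  (no kernel flip)
    out = sum(w[j] * x[:, j:j + length] for j in range(k))
    return out[:, ::stride]
```

With unit stride, "VALID" returns m - K + 1 columns, and "SAME" returns m columns.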

bk_threshold(x, threshold, greater=True)
bk_maximum(x1, x2)
bk_device(device_name)
bk_ones(shape, dtype=None)
bk_conv1d(x, w)
bk_flattenR(x)
bk_flatten(x)
bk_resize_image(x, shape)
bk_L1(x)
bk_square_comp(x)
bk_reduce_sum(data, axis=None)
bk_size(data)
constant(data)
bk_reduce_mean(data, axis=None)
bk_reduce_median(data, axis=None)
bk_reduce_min(data, axis=None)
bk_random_seed(value)
bk_random_uniform(shape)
bk_reduce_std(data, axis=None)
bk_sqrt(data)
bk_abs(data)
bk_is_complex(data)
bk_distcomp(data)
bk_norm(data)
bk_square(data)
bk_log(data)
bk_matmul(a, b)
bk_tensor(data)
bk_shape_tensor(shape)
bk_complex(real, imag)
bk_exp(data)
bk_min(data)
bk_argmin(data)
bk_tanh(data)
bk_max(data)
bk_argmax(data)
bk_reshape(data, shape)
bk_repeat(data, nn, axis=0)
bk_tile(data, nn, axis=0)
bk_roll(data, nn, axis=0)
bk_expand_dims(data, axis=0)
bk_transpose(data, thelist)
bk_concat(data, axis=None)
bk_zeros(shape, dtype=None)
bk_gather(data, idx, axis=0)
bk_reverse(data, axis=0)
bk_fft(data)
bk_fftn(data, dim=None)
bk_ifftn(data, dim=None, norm=None)
bk_rfft(data)
bk_irfft(data)
bk_conjugate(data)
bk_real(data)
bk_imag(data)
bk_relu(x)
bk_clip_by_value(x, xmin, xmax)
bk_cast(x)
bk_variable(x)
bk_assign(x, y)
bk_constant(x)
bk_cos(x)
bk_sin(x)
bk_arctan2(c, s)
bk_empty(list)
to_numpy(x)