foscat.BkTorch
Classes
Module Contents
- class foscat.BkTorch.BkTorch(*args, **kwargs)
Bases: foscat.BkBase.BackendBase
- backend
- device
- float64 = Ellipsis
- float32 = Ellipsis
- int64 = Ellipsis
- int32 = Ellipsis
- complex64 = Ellipsis
- complex128 = Ellipsis
- gpulist
- ngpu = 1
- torch_device
- downsample_mean_2x2(tim: torch.Tensor) -> torch.Tensor
Average-pool tensor tim over non-overlapping 2x2 spatial blocks.
- Parameters:
tim (torch.Tensor) – Tensor of shape [a, N1, N2, b].
- Returns:
torch.Tensor – Downsampled tensor of shape [a, N1//2, N2//2, b], each element being the mean of a 2x2 block.
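A minimal PyTorch sketch of this pooling, assuming odd spatial sizes are cropped to even values (how the method treats odd N1 or N2 is not documented here). Splitting each spatial axis into (block, offset) with a reshape lets a single mean over the two offset axes perform the pooling.

```python
import torch


def downsample_mean_2x2(tim: torch.Tensor) -> torch.Tensor:
    """Mean of non-overlapping 2x2 blocks over axes 1 and 2 of [a, N1, N2, b]."""
    a, n1, n2, b = tim.shape
    h, w = n1 // 2, n2 // 2
    # Assumption: drop a trailing odd row/column so both spatial sizes are even.
    tim = tim[:, : 2 * h, : 2 * w, :]
    # [a, N1, N2, b] -> [a, h, 2, w, 2, b]: (block index, offset in block) per axis.
    blocks = tim.reshape(a, h, 2, w, 2, b)
    # Average over the two in-block offset axes.
    return blocks.mean(dim=(2, 4))


x = torch.arange(16.0).reshape(1, 4, 4, 1)
y = downsample_mean_2x2(x)  # block means: [[2.5, 4.5], [10.5, 12.5]]
```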
- downsample_median_2x2(tim: torch.Tensor) -> torch.Tensor
2x2 block median downsampling on the spatial axes (N1, N2).
- Input:
tim: [a, N1, N2, b] (real or complex)
- Output:
out: [a, N1//2, N2//2, b], where each value is the median over the corresponding 2x2 block.
For complex inputs, the median is taken by sorting the 4 values by magnitude |z| and returning the complex sample at the lower median rank.
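The sketch below, in PyTorch, gathers the four samples of each block into a trailing axis and then applies the rule above; it is an illustration, not foscat's code. For real inputs it assumes the lower of the two middle values (what torch.median returns for an even count); the actual method may average the two middle values instead.

```python
import torch


def downsample_median_2x2(tim: torch.Tensor) -> torch.Tensor:
    """Lower median of each non-overlapping 2x2 block of [a, N1, N2, b]."""
    a, n1, n2, b = tim.shape
    h, w = n1 // 2, n2 // 2
    tim = tim[:, : 2 * h, : 2 * w, :]
    # [a, h, 2, w, 2, b] -> [a, h, w, b, 2, 2] -> [a, h, w, b, 4]
    blocks = tim.reshape(a, h, 2, w, 2, b).permute(0, 1, 3, 5, 2, 4).reshape(a, h, w, b, 4)
    if not torch.is_complex(blocks):
        # For an even count, torch.median returns the lower of the two middle values.
        return blocks.median(dim=-1).values
    # Complex: rank the 4 samples by magnitude and keep rank 1 (0-based lower median).
    pick = torch.argsort(blocks.abs(), dim=-1)[..., 1:2]
    re = torch.gather(blocks.real, -1, pick)
    im = torch.gather(blocks.imag, -1, pick)
    return torch.complex(re, im).squeeze(-1)
```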
- downsample_mean_1d(tim: torch.Tensor) -> torch.Tensor
Downsample tensor tim of shape [a, N1] by averaging non-overlapping 2-element blocks. Output shape: [a, N1//2].
- downsample_median_1d(tim: torch.Tensor) -> torch.Tensor
Downsample tensor tim of shape [a, N1] by taking the median of non-overlapping pairs (2 values). Output shape: [a, N1//2].
For real inputs: the median of the two values.
For complex inputs: whichever of the two values has the smaller magnitude |z|.
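A minimal PyTorch sketch of both 1D reductions, assuming an odd trailing sample is dropped. For real inputs it takes the median of two numbers as their midpoint, which makes it coincide with downsample_mean_1d; an implementation built on torch.median would instead return the smaller of the two, so treat that branch as an assumption.

```python
import torch


def downsample_mean_1d(tim: torch.Tensor) -> torch.Tensor:
    a, n = tim.shape
    # Pair neighbouring samples into [a, n//2, 2], then average each pair.
    return tim[:, : 2 * (n // 2)].reshape(a, n // 2, 2).mean(dim=-1)


def downsample_median_1d(tim: torch.Tensor) -> torch.Tensor:
    a, n = tim.shape
    pairs = tim[:, : 2 * (n // 2)].reshape(a, n // 2, 2)
    x0, x1 = pairs[..., 0], pairs[..., 1]
    if torch.is_complex(pairs):
        # Keep the sample with the smaller magnitude (ties go to the first one).
        return torch.where(x0.abs() <= x1.abs(), x0, x1)
    # Assumption: real median of two values taken as their midpoint.
    return 0.5 * (x0 + x1)
```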
- binned_mean_old(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))
Reduce values over parent HEALPix pixels (nested) when downgrading nside → nside/2.
- Parameters:
data (torch.Tensor | np.ndarray) – Shape […, N] or [B, …, N].
cell_ids (torch.LongTensor | np.ndarray) – Shape [N] or [B, N] (nested indexing at the child resolution).
reduce ({"mean", "max"}, default "mean") – Aggregation to apply within each parent group of 4 children.
padded (bool, default False) – Only used when cell_ids is [B, N]. If False, returns ragged Python lists. If True, returns padded tensors + mask.
fill_value (float, default NaN) – Padding value when padded=True.
- Returns:
Same as the existing documentation, except that the value is a mean (reduce="mean") or a maximum (reduce="max").
- binned_mean(data, cell_ids, *, reduce: str = 'mean', padded: bool = False, fill_value: float = float('nan'))
Reduce values over parent HEALPix pixels (nested) when downgrading nside → nside/2.
- Parameters:
data (torch.Tensor | np.ndarray) – Shape […, N] or [B, …, N].
cell_ids (torch.LongTensor | np.ndarray) – Shape [N] or [B, N] (nested indexing at the child resolution).
reduce ({"mean", "max", "median"}, default "mean") – Aggregation within each parent group of 4 children.
padded (bool, default False) – Only used when cell_ids is [B, N]. If False, returns ragged Python lists. If True, returns padded tensors + mask.
fill_value (float, default NaN) – Padding value when padded=True.
- Returns:
As in the original function, with the aggregation set by `reduce`.
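A minimal PyTorch sketch of the grouping step for the single-map case (cell_ids of shape [N], tensor inputs) with reduce="mean" or "max". In the nested scheme the parent of child pixel p is p // 4. The helper name binned_reduce_sketch and its (parent_ids, values) return value are hypothetical, since the method's return format is not spelled out here; the batched [B, N] case, padding and reduce="median" are not covered. It relies on Tensor.scatter_reduce with reduce="amax", available in recent PyTorch releases.

```python
import torch


def binned_reduce_sketch(data: torch.Tensor, cell_ids: torch.Tensor, reduce: str = "mean"):
    """Hypothetical helper: reduce data[..., N] over nested-HEALPix parent pixels.

    Returns (parent_ids, values) with values of shape [..., G].
    """
    # Parent pixel at nside/2 is child // 4; inv maps each child to its group.
    parents, inv = torch.unique(cell_ids // 4, sorted=True, return_inverse=True)
    g = parents.numel()
    out_shape = tuple(data.shape[:-1]) + (g,)
    if reduce == "mean":
        sums = torch.zeros(out_shape, dtype=data.dtype).index_add_(-1, inv, data)
        counts = torch.bincount(inv, minlength=g).to(data.dtype)
        return parents, sums / counts
    if reduce == "max":
        out = torch.full(out_shape, float("-inf"), dtype=data.dtype)
        # scatter_reduce needs an index with the same shape as src.
        idx = inv.expand_as(data)
        return parents, out.scatter_reduce(-1, idx, data, reduce="amax", include_self=False)
    raise ValueError(f"unsupported reduce={reduce!r}")


cell_ids = torch.tensor([0, 1, 2, 3, 8, 9, 11])  # parents 0 and 2
data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0]])
parents, means = binned_reduce_sketch(data, cell_ids)  # means: [[2.5, 20.0]]
```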
- average_by_cell_group(cell_ids)
data: tensor of shape […, N, …] (e.g. [B, N, C])
cell_ids: tensor of shape [N]
Returns: mean_data of shape […, G, …], where G is the number of unique values of cell_ids // 4.
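The signature lists only cell_ids while the description also refers to data, so the hypothetical helper below takes both explicitly, plus an axis argument (also an assumption) giving where N sits. It is a sketch of averaging along that axis over the groups cell_ids // 4, not foscat's implementation.

```python
import torch


def average_by_cell_group_sketch(data: torch.Tensor, cell_ids: torch.Tensor, axis: int = 1) -> torch.Tensor:
    """Hypothetical helper: mean of data over groups of cell_ids // 4 along `axis`."""
    groups, inv = torch.unique(cell_ids // 4, return_inverse=True)
    g = groups.numel()
    out_shape = list(data.shape)
    out_shape[axis] = g
    # Sum the members of each group along `axis`...
    sums = torch.zeros(out_shape, dtype=data.dtype).index_add_(axis, inv, data)
    # ...then divide by the group sizes, broadcast along `axis` only.
    counts = torch.bincount(inv, minlength=g).to(data.dtype)
    view = [1] * data.dim()
    view[axis] = g
    return sums / counts.reshape(view)


data = torch.arange(16.0).reshape(1, 8, 2)  # [B, N, C]
out = average_by_cell_group_sketch(data, torch.arange(8))  # shape [1, 2, 2]
```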
- bk_masked_median(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)
Masked geometric median over the last axis using Weiszfeld iteration (1D case).
- Parameters:
x (torch.Tensor) – Shape [a, b, c, N]. Can be real or complex.
mask (torch.Tensor) – Binary mask of shape [a, b, 1, N]; broadcast across c.
max_iter (int) – Maximum number of Weiszfeld iterations.
tol (float) – Convergence tolerance on the maximum absolute update per voxel.
eps (float) – Small value to avoid division by zero in the weights.
- Returns:
med (torch.Tensor, shape [a, b, c]) – Geometric median of x along the last axis where mask == 1. For complex x, distances use the complex magnitude |x - y| and the returned median is complex.
med2 (torch.Tensor, shape [a, b, c]) – Geometric median of squared values along the last axis where mask == 1: the median of x**2 (real) if x is real, or of |x|**2 (real) if x is complex.
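A vectorized PyTorch sketch of the masked 1D Weiszfeld iteration, assuming the masked mean as the starting point and that eps acts as a floor on the distances; both choices, and the helper names, are illustrative rather than foscat's implementation. Masked-out samples get weight 0, and a voxel with no valid sample yields 0/0, hence NaN. In 1D the geometric median coincides with the ordinary median, which makes the result easy to check.

```python
import torch


def _weiszfeld_lastdim(x, m, max_iter, tol, eps):
    """Weiszfeld iteration over the last axis; m holds real 0/1 weights."""
    # Start from the masked mean; voxels with no valid sample give 0/0 = NaN.
    y = (m * x).sum(dim=-1) / m.sum(dim=-1)
    for _ in range(max_iter):
        # |x_i - y_k| (complex magnitude for complex x), floored at eps (assumption).
        dist = (x - y.unsqueeze(-1)).abs().clamp_min(eps)
        w = m / dist  # masked-out samples get weight 0
        y_new = (w * x).sum(dim=-1) / w.sum(dim=-1)
        # Ignore NaN (empty) voxels when testing convergence.
        step = torch.nan_to_num((y_new - y).abs(), nan=0.0)
        y = y_new
        if step.max() < tol:
            break
    return y


def masked_median_sketch(x, mask, max_iter=100, tol=1e-6, eps=1e-12):
    """x: [a, b, c, N] real or complex; mask: [a, b, 1, N] of 0/1 values."""
    m = mask.to(x.abs().dtype).expand(x.shape)  # broadcast over c
    med = _weiszfeld_lastdim(x, m, max_iter, tol, eps)
    sq = x.abs() ** 2 if torch.is_complex(x) else x ** 2
    med2 = _weiszfeld_lastdim(sq, m, max_iter, tol, eps)
    return med, med2


x = torch.tensor([[1.0, 2.0, 100.0, -50.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float64).reshape(2, 1, 1, 4)
mask = torch.tensor([[1, 1, 1, 0], [0, 0, 0, 0]]).reshape(2, 1, 1, 4)
med, med2 = masked_median_sketch(x, mask)  # med[0] close to 2, med[1] is NaN
```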
- bk_masked_median_2d_weiszfeld(x: torch.Tensor, mask: torch.Tensor, max_iter: int = 100, tol: float = 1e-06, eps: float = 1e-12)
Masked geometric median over the 2D spatial axes using Weiszfeld iteration.
- Parameters:
x (torch.Tensor) – Input of shape [a, b, c, N1, N2]. Can be real or complex.
mask (torch.Tensor) – Binary mask of shape [a, b, 1, N1, N2]; broadcast over c.
max_iter (int) – Maximum number of Weiszfeld iterations.
tol (float) – Stopping tolerance on the maximum absolute update per voxel.
eps (float) – Small positive value to avoid division by zero in the weights.
- Returns:
med (torch.Tensor, shape [a, b, c]) – Geometric median of x over (N1, N2) where mask == 1. If x is complex, distances are magnitudes |x - y| in the complex plane, and the returned value is the complex estimate (not its magnitude).
med2 (torch.Tensor, shape [a, b, c]) – Geometric median of squared values over (N1, N2) where mask == 1: the median of x**2 (via Weiszfeld in 1D) if x is real, or of |x|**2 (real, non-negative) if x is complex.
Notes
Voxels with zero valid samples return NaN (NaN+NaNj for a complex med).
Weiszfeld update, with sums over the valid samples (mask == 1):
$$y_{k+1} = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad w_i = \frac{1}{\lVert x_i - y_k \rVert}$$
Here $\lVert \cdot \rVert$ is $|\cdot|$ for real numbers and the complex magnitude for complex numbers.
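To make the update rule concrete, here is a minimal unmasked sketch in plain Python that iterates it on a list of real or complex numbers; the function name weiszfeld and the y0 starting point are illustrative, and flooring the distance at eps is an assumption about how eps is used. Because abs() is the absolute value for reals and the modulus for complex numbers, the same code covers both cases, with complex samples behaving as points in the plane. A median does not depend on sample order, so applying the rule to all valid (N1, N2) samples of a voxel is equivalent to applying it to those samples flattened into one axis.

```python
def weiszfeld(samples, y0, max_iter=100, tol=1e-12, eps=1e-12):
    """Unmasked Weiszfeld iteration on Python real or complex numbers."""
    y = y0
    for _ in range(max_iter):
        # w_i = 1 / ||x_i - y_k||, with the distance floored at eps.
        w = [1.0 / max(abs(x - y), eps) for x in samples]
        # y_{k+1} = sum_i w_i x_i / sum_i w_i
        y_new = sum(wi * xi for wi, xi in zip(w, samples)) / sum(w)
        if abs(y_new - y) < tol:
            return y_new
        y = y_new
    return y


# Real samples: the 1D geometric median is the ordinary median (close to 2.0).
m_real = weiszfeld([1.0, 2.0, 100.0], y0=34.0)
# Complex samples as points in the plane: symmetric set, median close to 0.
m_cplx = weiszfeld([1, -1, 1j, -1j], y0=0.3 + 0.2j)
```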
- conv2d(x, w)
Perform a 2D convolution using the PyTorch (channels-first) layout.
- Args:
x: Tensor of shape […, Nx, Ny] – input
w: Tensor of shape [O_c, wx, wy] – convolution weights
- Returns:
Tensor of shape […, O_c, Nx, Ny]
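One way to obtain this shape contract with torch.nn.functional.conv2d, sketched under assumptions: zero padding with padding="same" (stride 1), real tensors of matching dtype, and PyTorch's cross-correlation convention (no kernel flip). foscat's actual padding mode may differ. The leading axes are folded into the batch axis and restored afterwards.

```python
import torch
import torch.nn.functional as F


def conv2d_sketch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """x: [..., Nx, Ny], w: [O_c, wx, wy] -> [..., O_c, Nx, Ny] (zero 'same' padding)."""
    lead = x.shape[:-2]
    nx, ny = x.shape[-2:]
    # Fold all leading axes into the batch axis and add one input channel.
    xb = x.reshape(-1, 1, nx, ny)
    # PyTorch weight layout is [out_channels, in_channels, kH, kW].
    wb = w.unsqueeze(1)
    # padding="same" keeps (Nx, Ny) for stride 1; F.conv2d is a cross-correlation.
    y = F.conv2d(xb, wb, padding="same")
    return y.reshape(*lead, w.shape[0], nx, ny)


y = conv2d_sketch(torch.ones(2, 5, 5), torch.ones(3, 3, 3))  # shape [2, 3, 5, 5]
```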
- conv1d(x, w, strides=[1, 1, 1], padding='SAME')
Performs a 1D convolution along the last axis of a 2D tensor x[n, m] with kernel w[K].
Parameters:
- x: torch.Tensor of shape [n, m]
- w: torch.Tensor of shape [K]
- strides: list of 3 ints; only strides[1] (along axis -1) is used
- padding: "SAME" or "VALID"
Returns:
- torch.Tensor of shape [n, m] (if SAME) or smaller (if VALID)
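A minimal sketch with torch.nn.functional.conv1d, assuming TensorFlow-style "SAME" padding (output length ceil(m / s), any odd extra padding on the right), zero padding values and no kernel flip. These conventions are inferred from the strides/padding arguments rather than confirmed by the source.

```python
import math

import torch
import torch.nn.functional as F


def conv1d_sketch(x: torch.Tensor, w: torch.Tensor, strides=(1, 1, 1), padding="SAME") -> torch.Tensor:
    """x: [n, m], w: [K]; convolve every row along the last axis."""
    n, m = x.shape
    k = w.shape[0]
    s = strides[1]  # only the stride along axis -1 is used
    xb = x.reshape(n, 1, m)  # [batch, in_channels, length]
    wb = w.reshape(1, 1, k)  # [out_channels, in_channels, K]
    if padding == "SAME":
        # TensorFlow-style SAME: output length ceil(m / s), extra pad on the right.
        out_len = math.ceil(m / s)
        total = max((out_len - 1) * s + k - m, 0)
        xb = F.pad(xb, (total // 2, total - total // 2))
    elif padding != "VALID":
        raise ValueError(f"unsupported padding={padding!r}")
    return F.conv1d(xb, wb, stride=s).reshape(n, -1)


x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
same = conv1d_sketch(x, torch.ones(3))                    # length 5
valid = conv1d_sketch(x, torch.ones(3), padding="VALID")  # length 3
```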