import sys
import numpy as np
[docs]
class foscat_backend:
def __init__(self, name, mpi_rank=0, all_type="float64", gpupos=0, silent=False):
    """Select and configure the numerical backend.

    Args:
        name: "tensorflow", "torch" or "numpy".
        mpi_rank: only rank 0 prints the number of available GPUs.
        all_type: "float32" or "float64"; selects ``all_type``, ``all_bk_type``
            (real dtype) and ``all_cbk_type`` (complex dtype).
        gpupos: kept for backward compatibility; currently unused.
        silent: if True, do not print the number of available GPUs.

    On an unknown ``name`` or ``all_type`` an error is printed and the
    constructor returns early, leaving the object only partially initialised.
    """
    # Integer identifiers stored in self.BACKEND once the backend is chosen.
    self.TENSORFLOW = 1
    self.TORCH = 2
    self.NUMPY = 3
    # Lazily filled caches of orientation matrices (see calc_iso_orient and
    # calc_fft_orient).
    self._iso_orient = {}
    self._iso_orient_T = {}
    self._iso_orient_C = {}
    self._iso_orient_C_T = {}
    self._fft_1_orient = {}
    self._fft_1_orient_C = {}
    self._fft_2_orient = {}
    self._fft_2_orient_C = {}
    self._fft_3_orient = {}
    self._fft_3_orient_C = {}
    self.BACKEND = name
    if name not in ["tensorflow", "torch", "numpy"]:
        print('Backend "%s" not yet implemented' % (name))
        print(" Choose inside the next 3 available backends :")
        print(" - tensorflow")
        print(" - torch")
        print(" - numpy (Impossible to do synthesis using numpy)")
        return None
    if name == "tensorflow":
        import tensorflow as tf

        self.backend = tf
        self.BACKEND = self.TENSORFLOW
        # tf.config.threading.set_inter_op_parallelism_threads(1)
        # tf.config.threading.set_intra_op_parallelism_threads(1)
        self.tf_function = tf.function
    elif name == "torch":
        import torch

        self.torch_device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.BACKEND = self.TORCH
        self.backend = torch
        self.tf_function = self.tf_loc_function
    else:
        import scipy as scipy

        self.BACKEND = self.NUMPY
        self.backend = np
        self.scipy = scipy
        self.tf_function = self.tf_loc_function
    self.float64 = self.backend.float64
    self.float32 = self.backend.float32
    self.int64 = self.backend.int64
    self.int32 = self.backend.int32
    # These two aliases used to be swapped (complex64 -> complex128 and back).
    self.complex64 = self.backend.complex64
    self.complex128 = self.backend.complex128
    if all_type == "float32":
        # all_type used to be left unset here, breaking bk_ones in float32 mode.
        self.all_type = "float32"
        self.all_bk_type = self.backend.float32
        self.all_cbk_type = self.backend.complex64
    elif all_type == "float64":
        self.all_type = "float64"
        self.all_bk_type = self.backend.float64
        self.all_cbk_type = self.backend.complex128
    else:
        print("ERROR INIT FOCUS ", all_type, " should be float32 or float64")
        return None
    # ===========================================================================
    # INIT
    if mpi_rank == 0:
        if self.BACKEND == self.TENSORFLOW and not silent:
            print(
                "Num GPUs Available: ",
                len(self.backend.config.experimental.list_physical_devices("GPU")),
            )
        sys.stdout.flush()
    if self.BACKEND == self.TENSORFLOW:
        self.backend.debugging.set_log_device_placement(False)
        self.backend.config.set_soft_device_placement(True)
        gpus = self.backend.config.experimental.list_physical_devices("GPU")
    elif self.BACKEND == self.TORCH:
        gpus = torch.cuda.is_available()
    else:
        gpus = []
    # Default: a single CPU "device".
    self.gpulist = {0: "CPU:0"}
    self.ngpu = 1
    if gpus:
        try:
            if self.BACKEND == self.TENSORFLOW:
                # Currently, memory growth needs to be the same across GPUs
                for gpu in gpus:
                    self.backend.config.experimental.set_memory_growth(gpu, True)
                logical_gpus = self.backend.config.experimental.list_logical_devices(
                    "GPU"
                )
                print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
                sys.stdout.flush()
                self.ngpu = len(logical_gpus)
                self.gpulist = {i: logical_gpus[i].name for i in range(self.ngpu)}
            if self.BACKEND == self.TORCH:
                self.ngpu = torch.cuda.device_count()
                # Record each device's own name (device 0 was used for all before).
                self.gpulist = {
                    k: torch.cuda.get_device_name(k) for k in range(self.ngpu)
                }
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)
[docs]
def tf_loc_function(self, func):
    """Identity decorator used as ``tf_function`` by the torch and numpy backends.

    Args:
        func: any callable.

    Returns:
        ``func`` itself, unchanged.
    """
    decorated = func
    return decorated
[docs]
def calc_iso_orient(self, norient):
    """Build and cache the matrices used by ``iso_mean`` for ``norient`` orientations.

    ``_iso_orient[norient]`` has shape (norient*norient, norient). Right-multiplying
    a flattened (orientation, orientation) pair axis by it averages, for each
    shift i, the ``norient`` entries (j, (j + i) % norient).
    ``_iso_orient_T[norient]`` is the transposed 0/1 pattern. The ``_C``
    variants hold the same matrices as complex values with zero imaginary part.
    """
    shift = np.arange(norient)[None, :]
    first = np.arange(norient)[:, None]
    pattern = np.zeros([norient * norient, norient])
    pattern[first * norient + (first + shift) % norient, shift] = 1.0
    # Average over the norient contributing pairs. This weight used to be a
    # hard-coded 0.25, which is only a mean when norient == 4.
    tmp = pattern / norient
    self._iso_orient[norient] = self.constant(self.bk_cast(tmp))
    self._iso_orient_T[norient] = self.constant(
        self.bk_cast(np.ascontiguousarray(pattern.T))
    )
    self._iso_orient_C[norient] = self.bk_complex(
        self._iso_orient[norient], 0 * self._iso_orient[norient]
    )
    self._iso_orient_C_T[norient] = self.bk_complex(
        self._iso_orient_T[norient], 0 * self._iso_orient_T[norient]
    )
[docs]
def calc_fft_orient(self, norient, nharm, imaginary):
    """Build and cache the angular Fourier matrices used by ``fft_ang``.

    With x_a = 2*pi*a/norient (a = 0..norient-1), a 1-D basis B[a, p] is built:

    * ``imaginary`` False: p = 0..nharm and B[a, p] = cos(p * x_a);
    * ``imaginary`` True: p = 0..2*nharm, B[a, 0] = 1, and columns 2h-1 / 2h
      (h = 1..nharm) hold cos(h * x_a) / sin(h * x_a).

    The one-, two- and three-axis matrices are B[a, p], B[a, p] B[c, q] and
    B[a, p] B[b, q] B[c, r], reshaped to (norient**d, nbasis**d). They are
    stored under the key ``(norient, nharm, imaginary)``:

    * in ``_fft_{1,2,3}_orient`` after ``constant`` and ``bk_cast``;
    * in ``_fft_{1,2,3}_orient_C`` as complex values with zero imaginary part.
    """
    key = (norient, nharm, imaginary)
    angles = np.arange(norient) / norient * 2 * np.pi

    if imaginary:
        nbasis = 1 + nharm * 2
        basis = np.zeros([norient, nbasis])
        basis[:, 0] = 1.0
        for h in range(1, nharm + 1):
            basis[:, 2 * h - 1] = np.cos(angles * h)
            basis[:, 2 * h] = np.sin(angles * h)
    else:
        nbasis = 1 + nharm
        basis = np.zeros([norient, nbasis])
        for h in range(nbasis):
            basis[:, h] = np.cos(angles * h)

    # Outer products of the 1-D basis; axis order is (a, b, c, p, q, r) and the
    # products are evaluated left to right, i.e. (A * B) * C.
    pair = basis[:, None, :, None] * basis[None, :, None, :]
    triple = (
        basis[:, None, None, :, None, None] * basis[None, :, None, None, :, None]
    ) * basis[None, None, :, None, None, :]

    def _store(real_cache, complex_cache, table):
        # Cache the real table and its complex twin under the shared key.
        real_cache[key] = self.bk_cast(self.constant(table))
        complex_cache[key] = self.bk_complex(real_cache[key], 0 * real_cache[key])

    _store(self._fft_1_orient, self._fft_1_orient_C, basis)
    _store(
        self._fft_2_orient,
        self._fft_2_orient_C,
        pair.reshape(norient * norient, nbasis * nbasis),
    )
    _store(
        self._fft_3_orient,
        self._fft_3_orient_C,
        triple.reshape(norient * norient * norient, nbasis * nbasis * nbasis),
    )
# -------------------------------------------------------
# -- BACKEND DEFINITION --
# -------------------------------------------------------
[docs]
def bk_SparseTensor(self, indice, w, dense_shape=[]):
    """Build a sparse matrix from index pairs and values.

    Args:
        indice: (row, column) positions; indexed as ``indice[:, 0]`` /
            ``indice[:, 1]`` and transposed for torch, so presumably an
            integer array of shape (nnz, 2).
        w: one value per index pair.
        dense_shape: shape of the equivalent dense matrix.

    Returns:
        A ``tf.SparseTensor`` (TensorFlow), a CSR tensor on ``torch_device``
        (torch) or a ``scipy.sparse.coo_matrix`` (numpy).
    """
    kind = self.BACKEND
    if kind == self.TENSORFLOW:
        return self.backend.SparseTensor(indice, w, dense_shape=dense_shape)
    if kind == self.TORCH:
        coo = self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
        return coo.to_sparse_csr().to(self.torch_device)
    if kind == self.NUMPY:
        rows, cols = indice[:, 0], indice[:, 1]
        return self.scipy.sparse.coo_matrix((w, (rows, cols)), shape=dense_shape)
[docs]
def bk_stack(self, list, axis=0):
    """Stack a sequence of tensors along a new ``axis``.

    The torch result is moved to ``self.torch_device``.
    """
    if self.BACKEND in (self.TENSORFLOW, self.NUMPY):
        return self.backend.stack(list, axis=axis)
    if self.BACKEND == self.TORCH:
        stacked = self.backend.stack(list, axis=axis)
        return stacked.to(self.torch_device)
[docs]
def bk_sparse_dense_matmul(self, smat, mat):
    """Multiply a sparse matrix (from ``bk_SparseTensor``) by a dense one.

    Returns None for an unknown backend.
    """
    dispatch = {
        self.TENSORFLOW: lambda: self.backend.sparse.sparse_dense_matmul(smat, mat),
        self.TORCH: lambda: smat.matmul(mat),
        self.NUMPY: lambda: smat.dot(mat),
    }
    operation = dispatch.get(self.BACKEND)
    return None if operation is None else operation()
# for tensorflow wrapping only
[docs]
def periodic_pad(self, x, pad_height, pad_width):
    """Apply periodic ("wrap") padding to a 4-D tensor (N, H, W, C).

    Args:
        x: tensor of shape (batch_size, height, width, channels).
        pad_height: int, number of rows wrapped onto both the top and the bottom.
        pad_width: int, number of columns wrapped onto both the left and the right.

    Returns:
        Tensor of shape (batch_size, height + 2*pad_height,
        width + 2*pad_width, channels). A zero pad leaves that axis unchanged.
    """
    # Zero pads must be skipped: x[:, -0:] is the WHOLE axis, not an empty
    # slice, so the old code duplicated the tensor for 1-pixel kernels.
    if pad_height > 0:
        top_pad = x[:, -pad_height:, :, :]  # last rows go on top
        bottom_pad = x[:, :pad_height, :, :]  # first rows go at the bottom
        x = self.backend.concat([top_pad, x, bottom_pad], axis=1)
    if pad_width > 0:
        # Done after vertical padding so the corners wrap as well.
        left_pad = x[:, :, -pad_width:, :]
        right_pad = x[:, :, :pad_width, :]
        x = self.backend.concat([left_pad, x, right_pad], axis=2)
    return x
[docs]
def conv2d(self, x, w, strides=[1, 1, 1, 1], padding="SAME"):
    """2-D convolution with periodic (wrap-around) boundaries.

    Args:
        x: input of shape (batch, height, width, in_channels).
        w: kernel of shape (kx, ky, in_channels, out_channels). The torch
            branch converts it with ``from_numpy``, so it must be a numpy array there.
        strides: forwarded to ``tf.nn.conv2d`` (TensorFlow only).
        padding: ignored; kx//2 / ky//2 periodic padding is always applied.

    Returns:
        Tensor of shape (batch, height, width, out_channels) for odd kernel
        sizes. As with ``tf.nn.conv2d`` and ``torch.nn.functional.conv2d``, the
        kernel is applied without flipping (cross-correlation).
    """
    if self.BACKEND == self.TENSORFLOW:
        kx = w.shape[0]
        ky = w.shape[1]
        x_padded = self.periodic_pad(x, kx // 2, ky // 2)
        return self.backend.nn.conv2d(x_padded, w, strides=strides, padding="VALID")
    if self.BACKEND == self.TORCH:
        import torch.nn.functional as F

        lx = x.permute(0, 3, 1, 2)
        wx = (
            self.backend.from_numpy(w).to(self.torch_device).permute(3, 2, 0, 1)
        )  # from (5, 5, 1, 4) to (4, 1, 5, 5)
        # Compute symmetric padding
        kx, ky = w.shape[0], w.shape[1]
        # Apply the (circular) padding
        x_padded = F.pad(lx, (ky // 2, ky // 2, kx // 2, kx // 2), mode="circular")
        # Apply the convolution
        return F.conv2d(x_padded, wx, stride=1, padding=0).permute(0, 2, 3, 1)
    if self.BACKEND == self.NUMPY:
        from scipy import signal

        res = np.zeros(
            [x.shape[0], x.shape[1], x.shape[2], w.shape[3]], dtype=x.dtype
        )
        for k in range(w.shape[2]):
            for l_orient in range(w.shape[3]):
                # This branch used boundary="symm" and an unflipped kernel,
                # disagreeing with the circular cross-correlation above.
                # Convolving with the flipped kernel is a cross-correlation
                # that, unlike correlate2d, does not conjugate complex kernels.
                kernel = w[::-1, ::-1, k, l_orient]
                for j in range(res.shape[0]):
                    res[j, :, :, l_orient] += signal.convolve2d(
                        x[j, :, :, k], kernel, mode="same", boundary="wrap"
                    )
        return res
[docs]
def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
    """1-D convolution with symmetric (mirror) boundaries.

    Args:
        x: input of shape (batch, length, in_channels).
        w: kernel of shape (kx, in_channels, out_channels).
        strides: forwarded as ``stride`` to ``tf.nn.conv1d`` (TensorFlow only).
        padding: ignored; kx//2 symmetric padding is always applied.

    Returns:
        Tensor of shape (batch, length, out_channels) for odd kx; the kernel
        is not flipped (cross-correlation, like ``tf.nn.conv1d``).
        The torch branch is not implemented yet and returns ``x`` unchanged.
    """
    if self.BACKEND == self.TENSORFLOW:
        kx = w.shape[0]
        paddings = self.backend.constant([[0, 0], [kx // 2, kx // 2], [0, 0]])
        tmp = self.backend.pad(x, paddings, "SYMMETRIC")
        return self.backend.nn.conv1d(tmp, w, stride=strides, padding="VALID")
    # TODO: torch implementation still to be written.
    if self.BACKEND == self.TORCH:
        return x
    if self.BACKEND == self.NUMPY:
        # The previous code called the nonexistent scipy.signal.convolve1d and
        # mixed up the channel axes.
        from scipy import ndimage

        res = np.zeros([x.shape[0], x.shape[1], w.shape[2]], dtype=x.dtype)
        for k in range(w.shape[1]):
            for l_orient in range(w.shape[2]):
                # ndimage "reflect" (d c b a | a b c d | d c b a) is TensorFlow's
                # "SYMMETRIC" padding; correlate1d does not flip the kernel.
                res[:, :, l_orient] += ndimage.correlate1d(
                    x[:, :, k], w[:, k, l_orient], axis=1, mode="reflect"
                )
        return res
[docs]
def bk_threshold(self, x, threshold, greater=True):
    """Zero every entry of ``x`` that is not strictly greater than ``threshold``.

    ``greater`` is accepted for API compatibility but is currently ignored.
    """
    if self.BACKEND == self.TENSORFLOW:
        keep = self.backend.cast(x > threshold, x.dtype)
        return keep * x
    if self.BACKEND in (self.TORCH, self.NUMPY):
        keep = x > threshold
        return keep * x
[docs]
def bk_maximum(self, x1, x2):
    """Element-wise maximum of ``x1`` and ``x2``.

    The numpy branch used ``x1*(x1>x2) + x2*(x2>x1)``, which returned 0
    wherever ``x1 == x2``; ``np.maximum`` handles ties correctly.
    """
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        return self.backend.maximum(x1, x2)
    if self.BACKEND == self.NUMPY:
        return np.maximum(x1, x2)
[docs]
def bk_device(self, device_name):
    """Return ``backend.device(device_name)`` for the active backend."""
    make_device = self.backend.device
    return make_device(device_name)
[docs]
def bk_ones(self, shape, dtype=None):
    """Tensor of ones of the given ``shape``.

    Args:
        shape: output shape.
        dtype: element type; defaults to the configured real type
            ``all_bk_type``. The torch branch ignores it and casts with ``bk_cast``.
    """
    if self.BACKEND == self.TORCH:
        return self.bk_cast(np.ones(shape))
    if dtype is None:
        # all_bk_type is set for both precisions, whereas all_type used to be
        # missing in float32 mode (AttributeError).
        dtype = self.all_bk_type
    return self.backend.ones(shape, dtype=dtype)
[docs]
def bk_conv1d(self, x, w):
    """1-D convolution with zero ("SAME") padding.

    Args:
        x: input of shape (batch, length, in_channels).
        w: kernel of shape (kx, in_channels, out_channels).

    Returns:
        Tensor of shape (batch, length, out_channels) for odd kx. The kernel
        is not flipped (cross-correlation), matching ``tf.nn.conv1d``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.nn.conv1d(x, w, stride=[1, 1, 1], padding="SAME")
    if self.BACKEND == self.TORCH:
        # torch.nn.conv1d does not exist. Use the functional op, which works in
        # (batch, channels, length) with weights (out, in, kx).
        import torch.nn.functional as F

        wt = self.backend.as_tensor(w, dtype=x.dtype, device=x.device)
        out = F.conv1d(x.permute(0, 2, 1), wt.permute(2, 1, 0), padding="same")
        return out.permute(0, 2, 1)
    if self.BACKEND == self.NUMPY:
        from scipy import ndimage

        # The output must have out_channels = w.shape[2] (was w.shape[1]).
        res = np.zeros([x.shape[0], x.shape[1], w.shape[2]], dtype=x.dtype)
        for k in range(w.shape[1]):
            for l_orient in range(w.shape[2]):
                res[:, :, l_orient] += ndimage.correlate1d(
                    x[:, :, k], w[:, k, l_orient], axis=1, mode="constant", cval=0.0
                )
        return res
[docs]
def bk_flattenR(self, x):
    """Flatten ``x`` to 1-D.

    Complex input becomes ``[real parts..., imaginary parts...]``, so the
    result is twice as long as the number of elements.
    """
    if self.BACKEND == self.NUMPY:
        if not self.bk_is_complex(x):
            return x.flatten()
        return np.concatenate([x.real.flatten(), x.imag.flatten()], 0)
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        flat_shape = [np.prod(np.array(list(x.shape)))]
        if not self.bk_is_complex(x):
            return self.backend.reshape(x, flat_shape)
        real_flat = self.backend.reshape(self.bk_real(x), flat_shape)
        imag_flat = self.backend.reshape(self.bk_imag(x), flat_shape)
        return self.bk_concat([real_flat, imag_flat], axis=0)
[docs]
def bk_flatten(self, x):
    """Reshape ``x`` into a 1-D tensor holding all of its elements."""
    if self.BACKEND == self.NUMPY:
        return x.flatten()
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        n_elements = np.prod(np.array(list(x.shape)))
        return self.backend.reshape(x, [n_elements])
[docs]
def bk_resize_image(self, x, shape):
    """Bilinearly resize an image batch to ``shape`` and cast to the default dtype.

    ``x`` is presumably (batch, height, width, channels): the torch branch
    permutes it to channels-first and back. TODO confirm for TensorFlow.

    NOTE(review): the numpy branch calls ``np.image.resize``, which numpy
    does not provide, so it currently raises AttributeError.
    """
    if self.BACKEND == self.TORCH:
        channels_first = x.permute(0, 3, 1, 2)
        resized = self.backend.nn.functional.interpolate(
            channels_first, size=shape, mode="bilinear", align_corners=False
        )
        return self.bk_cast(resized.permute(0, 2, 3, 1))
    if self.BACKEND in (self.TENSORFLOW, self.NUMPY):
        resized = self.backend.image.resize(x, shape, method="bilinear")
        return self.bk_cast(resized)
[docs]
def bk_L1(self, x):
    """Signed square root ``sign(v) * sqrt(|v|)``.

    For input of the configured complex type (``all_cbk_type``) it is applied
    separately to the real and imaginary parts.

    NOTE(review): with the torch backend only the real-part result is
    returned for complex input; confirm this is intended.
    """

    def signed_sqrt(v):
        # sign(v) * v == |v|, so the square root is always of a non-negative value.
        sgn = self.backend.sign(v)
        return sgn * self.backend.sqrt(sgn * v)

    if x.dtype != self.all_cbk_type:
        return signed_sqrt(x)
    real_part = signed_sqrt(self.bk_real(x))
    imag_part = signed_sqrt(self.bk_imag(x))
    if self.BACKEND == self.TORCH:
        return real_part
    return self.bk_complex(real_part, imag_part)
[docs]
def bk_square_comp(self, x):
    """Square real and imaginary parts separately (``re**2 + i*im**2`` for complex input).

    Real input (not of ``all_cbk_type``) is simply squared.
    """
    if x.dtype != self.all_cbk_type:
        return x * x
    re_part, im_part = self.bk_real(x), self.bk_imag(x)
    return self.bk_complex(re_part * re_part, im_part * im_part)
[docs]
def bk_reduce_sum(self, data, axis=None):
    """Sum of ``data`` over all elements (``axis`` None) or along ``axis``."""
    whole = axis is None
    if self.BACKEND == self.TENSORFLOW:
        if whole:
            return self.backend.reduce_sum(data)
        return self.backend.reduce_sum(data, axis=axis)
    if self.BACKEND == self.TORCH:
        return self.backend.sum(data) if whole else self.backend.sum(data, axis)
    if self.BACKEND == self.NUMPY:
        return np.sum(data) if whole else np.sum(data, axis)
# -------------------------------------------------------
# Return the number of elements of a tensor.
[docs]
def bk_size(self, data):
    """Total number of elements in ``data``."""
    if self.BACKEND == self.TORCH:
        return data.numel()
    if self.BACKEND == self.NUMPY:
        return data.size
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.size(data)
# -------------------------------------------------------
[docs]
def iso_mean(self, x, use_2D=False):
    """Average orientation pairs that share the same relative orientation shift.

    ``norient`` is read from axis 2 (axis 3 when ``use_2D``). If that is the
    last axis, the mean over it is returned. Otherwise the last two axes are
    treated as an (norient, norient) orientation pair and reduced with the
    cached ``calc_iso_orient`` matrix, which drops the final axis.
    """
    dims = list(x.shape)
    orient_axis = 3 if use_2D else 2
    norient = dims[orient_axis]
    if len(dims) == orient_axis + 1:
        return self.bk_reduce_mean(x, -1)
    if norient not in self._iso_orient:
        self.calc_iso_orient(norient)
    table = self._iso_orient_C if self.bk_is_complex(x) else self._iso_orient
    lmat = table[norient]
    # Collapse every axis except the final orientation pair into rows.
    nrows = dims[0]
    for size in dims[1:-2]:
        nrows *= size
    flat = self.bk_reshape(x, [nrows, norient * norient])
    return self.bk_reshape(self.backend.matmul(flat, lmat), dims[:-1])
[docs]
def fft_ang(self, x, nharm=1, imaginary=False, use_2D=False):
    """Project the orientation axes of ``x`` onto angular Fourier harmonics.

    ``norient`` is read from axis 2 (axis 3 when ``use_2D``). That axis and
    every following axis (one to three in total) are orientation axes. Each
    is replaced by ``1 + nharm`` harmonics, or ``1 + 2*nharm`` when
    ``imaginary``, using the matrices cached by ``calc_fft_orient``.

    Raises:
        ValueError: if ``x`` has more than three orientation axes.
    """
    shape = list(x.shape)
    i_orient = 3 if use_2D else 2
    norient = shape[i_orient]
    n_axes = len(shape) - i_orient
    nout = 1 + nharm * 2 if imaginary else 1 + nharm

    tables = {
        1: (self._fft_1_orient, self._fft_1_orient_C),
        2: (self._fft_2_orient, self._fft_2_orient_C),
        3: (self._fft_3_orient, self._fft_3_orient_C),
    }
    if n_axes not in tables:
        raise ValueError(
            "fft_ang supports 1 to 3 orientation axes, got %d" % n_axes
        )

    # Leading axes are flattened into rows, orientation axes into columns.
    nrows = 1
    for size in shape[:i_orient]:
        nrows *= size
    ncols = norient
    for size in shape[i_orient:-1]:
        ncols *= size
    out_shape = shape[:i_orient] + [nout] * n_axes

    # The caches are keyed by (norient, nharm, imaginary). The old check used
    # (norient, nharm), which never matched, so the matrices were rebuilt on
    # every call.
    key = (norient, nharm, imaginary)
    if key not in self._fft_1_orient:
        self.calc_fft_orient(norient, nharm, imaginary)

    real_table, complex_table = tables[n_axes]
    lmat = (complex_table if self.bk_is_complex(x) else real_table)[key]
    return self.bk_reshape(
        self.backend.matmul(self.bk_reshape(x, [nrows, ncols]), lmat), out_shape
    )
[docs]
def constant(self, data):
    """Wrap ``data`` in ``tf.constant`` for TensorFlow; return it unchanged otherwise."""
    if self.BACKEND != self.TENSORFLOW:
        return data
    return self.backend.constant(data)
[docs]
def bk_reduce_mean(self, data, axis=None):
    """Mean of ``data`` over all elements (``axis`` None) or along ``axis``."""
    whole = axis is None
    if self.BACKEND == self.TENSORFLOW:
        if whole:
            return self.backend.reduce_mean(data)
        return self.backend.reduce_mean(data, axis=axis)
    if self.BACKEND == self.TORCH:
        return self.backend.mean(data) if whole else self.backend.mean(data, axis)
    if self.BACKEND == self.NUMPY:
        return np.mean(data) if whole else np.mean(data, axis)
[docs]
def bk_reduce_min(self, data, axis=None):
    """Minimum of ``data`` over all elements (``axis`` None) or along ``axis``.

    Every backend returns only the values.
    """
    if self.BACKEND == self.TENSORFLOW:
        if axis is None:
            return self.backend.reduce_min(data)
        return self.backend.reduce_min(data, axis=axis)
    if self.BACKEND == self.TORCH:
        if axis is None:
            return self.backend.min(data)
        # torch.min(data, dim) returns a (values, indices) named tuple;
        # amin returns just the values, like the other backends.
        return self.backend.amin(data, dim=axis)
    if self.BACKEND == self.NUMPY:
        if axis is None:
            return np.min(data)
        return np.min(data, axis)
[docs]
def bk_random_seed(self, value):
    """Seed the backend's global random number generator with ``value``."""
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.random.set_seed(value)
    if self.BACKEND == self.TORCH:
        # torch has no random.set_seed; manual_seed is the global seeding call.
        return self.backend.manual_seed(value)
    if self.BACKEND == self.NUMPY:
        return np.random.seed(value)
[docs]
def bk_reduce_std(self, data, axis=None):
    """Standard deviation of ``data`` over all elements or along ``axis``.

    For complex ``data`` the result is returned as ``bk_complex(r, 0*r)``;
    for real data the backend result is returned unchanged. The no-axis case
    used to wrap the result as complex even for real input, unlike the
    ``axis`` case.
    """
    if self.BACKEND == self.TENSORFLOW:
        if axis is None:
            r = self.backend.math.reduce_std(data)
        else:
            r = self.backend.math.reduce_std(data, axis=axis)
    if self.BACKEND == self.TORCH:
        r = self.backend.std(data) if axis is None else self.backend.std(data, axis)
    if self.BACKEND == self.NUMPY:
        r = np.std(data) if axis is None else np.std(data, axis)
    if self.bk_is_complex(data):
        return self.bk_complex(r, 0 * r)
    return r
[docs]
def bk_sqrt(self, data):
    """Square root of ``|data|`` (the absolute value is taken first)."""
    magnitude = self.backend.abs(data)
    return self.backend.sqrt(magnitude)
[docs]
def bk_abs(self, data):
    """Element-wise absolute value (modulus for complex input)."""
    absolute = self.backend.abs
    return absolute(data)
[docs]
def bk_is_complex(self, data):
    """Return whether ``data`` has a complex dtype.

    numpy arrays, and any data under the numpy backend, are tested against
    complex64/complex128. TensorFlow/torch tensors use ``dtype.is_complex``.
    Returns None for an unknown backend.
    """
    if self.BACKEND not in (self.TENSORFLOW, self.TORCH, self.NUMPY):
        return None
    if self.BACKEND == self.NUMPY or isinstance(data, np.ndarray):
        return data.dtype == "complex64" or data.dtype == "complex128"
    return data.dtype.is_complex
[docs]
def bk_distcomp(self, data):
    """Squared modulus: ``re**2 + im**2`` for complex data, ``data**2`` otherwise."""
    if not self.bk_is_complex(data):
        return self.bk_square(data)
    real_sq = self.bk_square(self.bk_real(data))
    imag_sq = self.bk_square(self.bk_imag(data))
    return real_sq + imag_sq
[docs]
def bk_norm(self, data):
    """Modulus of ``data``: ``sqrt(re**2 + im**2)`` if complex, ``|data|`` otherwise."""
    if not self.bk_is_complex(data):
        return self.bk_abs(data)
    squared = self.bk_square(self.bk_real(data)) + self.bk_square(self.bk_imag(data))
    return self.bk_sqrt(squared)
[docs]
def bk_square(self, data):
    """Element-wise square of ``data``."""
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        return self.backend.square(data)
    if self.BACKEND == self.NUMPY:
        return data * data
[docs]
def bk_log(self, data):
    """Element-wise natural logarithm."""
    kind = self.BACKEND
    if kind == self.NUMPY:
        return np.log(data)
    if kind == self.TORCH:
        return self.backend.log(data)
    if kind == self.TENSORFLOW:
        return self.backend.math.log(data)
[docs]
def bk_matmul(self, a, b):
    """Matrix product ``a @ b`` (``np.dot`` for numpy)."""
    if self.BACKEND == self.NUMPY:
        return np.dot(a, b)
    if self.BACKEND in (self.TENSORFLOW, self.TORCH):
        return self.backend.matmul(a, b)
[docs]
def bk_tensor(self, data):
    """Convert ``data`` to a backend tensor.

    Returns a ``tf.constant`` for TensorFlow, a tensor on ``torch_device``
    for torch, or ``data`` itself for numpy.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.constant(data)
    if self.BACKEND == self.TORCH:
        # torch.constant does not exist; torch.tensor copies data into a tensor.
        return self.backend.tensor(data).to(self.torch_device)
    if self.BACKEND == self.NUMPY:
        return data
[docs]
def bk_shape_tensor(self, shape):
    """Zero-filled tensor of the given ``shape``.

    The numpy branch returns ``np.zeros(shape)``. The TensorFlow and torch
    branches used to call ``tf.tensor(shape=...)`` / ``torch.tensor(shape=...)``,
    which do not exist, so they now return zeros too (torch on ``torch_device``).
    The default dtypes are each backend's own.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.zeros(shape)
    if self.BACKEND == self.TORCH:
        return self.backend.zeros(shape).to(self.torch_device)
    if self.BACKEND == self.NUMPY:
        return np.zeros(shape)
[docs]
def bk_complex(self, real, imag):
    """Build the complex tensor ``real + 1j*imag`` (placed on ``torch_device`` for torch)."""
    kind = self.BACKEND
    if kind == self.NUMPY:
        return real + 1j * imag
    if kind == self.TORCH:
        combined = self.backend.complex(real, imag)
        return combined.to(self.torch_device)
    if kind == self.TENSORFLOW:
        return self.backend.dtypes.complex(real, imag)
[docs]
def bk_exp(self, data):
    """Element-wise exponential."""
    exponential = self.backend.exp
    return exponential(data)
[docs]
def bk_min(self, data):
    """Global minimum of ``data``.

    ``reduce_min`` exists only in TensorFlow; torch and numpy use ``min``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.reduce_min(data)
    if self.BACKEND == self.TORCH:
        return self.backend.min(data)
    if self.BACKEND == self.NUMPY:
        return np.min(data)
[docs]
def bk_argmin(self, data):
    """Index of the minimum, via the backend's ``argmin`` with default arguments."""
    find_argmin = self.backend.argmin
    return find_argmin(data)
[docs]
def bk_tanh(self, data):
    """Element-wise hyperbolic tangent.

    Only TensorFlow exposes it as ``math.tanh``. ``np.math`` (when present)
    is Python's scalar ``math`` module, and torch has no ``math`` submodule,
    so the other backends use their top-level ``tanh``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.tanh(data)
    return self.backend.tanh(data)
[docs]
def bk_max(self, data):
    """Global maximum of ``data``.

    ``reduce_max`` exists only in TensorFlow; torch and numpy use ``max``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.reduce_max(data)
    if self.BACKEND == self.TORCH:
        return self.backend.max(data)
    if self.BACKEND == self.NUMPY:
        return np.max(data)
[docs]
def bk_argmax(self, data):
    """Index of the maximum, via the backend's ``argmax`` with default arguments."""
    find_argmax = self.backend.argmax
    return find_argmax(data)
[docs]
def bk_reshape(self, data, shape):
    """Reshape ``data`` to ``shape``.

    With the torch backend, both numpy arrays and tensors are reshaped with
    their own ``.reshape``. ``Tensor.view`` was used before and fails on
    non-contiguous tensors; ``reshape`` returns a view when possible and
    copies otherwise.
    """
    if self.BACKEND == self.TORCH:
        return data.reshape(shape)
    return self.backend.reshape(data, shape)
[docs]
def bk_repeat(self, data, nn, axis=0):
    """Repeat each element of ``data`` ``nn`` times along ``axis`` (like ``np.repeat``).

    torch has no ``torch.repeat`` function; ``repeat_interleave`` has the
    same semantics.
    """
    if self.BACKEND == self.TORCH:
        return self.backend.repeat_interleave(data, nn, dim=axis)
    return self.backend.repeat(data, nn, axis=axis)
[docs]
def bk_tile(self, data, nn, axis=0):
    """Tile ``data`` ``nn`` times with the backend's ``tile``.

    TensorFlow receives the multiples as ``[nn]``. ``axis`` is accepted but
    ignored.
    """
    multiples = [nn] if self.BACKEND == self.TENSORFLOW else nn
    return self.backend.tile(data, multiples)
[docs]
def bk_roll(self, data, nn, axis=0):
    """Circularly shift ``data`` by ``nn`` positions along ``axis``."""
    roll = self.backend.roll
    return roll(data, nn, axis=axis)
[docs]
def bk_expand_dims(self, data, axis=0):
    """Insert a length-1 axis at position ``axis``.

    For torch, numpy input is first converted with ``from_numpy``.
    """
    kind = self.BACKEND
    if kind == self.NUMPY:
        return np.expand_dims(data, axis)
    if kind == self.TORCH:
        tensor = self.backend.from_numpy(data) if isinstance(data, np.ndarray) else data
        return self.backend.unsqueeze(tensor, axis)
    if kind == self.TENSORFLOW:
        return self.backend.expand_dims(data, axis=axis)
[docs]
def bk_transpose(self, data, thelist):
    """Permute the axes of ``data`` according to ``thelist``.

    ``torch.transpose`` only swaps two given dimensions, so the torch branch
    uses ``Tensor.permute``, which takes a full permutation like the other
    backends.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.transpose(data, thelist)
    if self.BACKEND == self.TORCH:
        return data.permute(tuple(thelist))
    if self.BACKEND == self.NUMPY:
        return np.transpose(data, thelist)
[docs]
def bk_concat(self, data, axis=None):
    """Concatenate a sequence of tensors along ``axis``.

    numpy uses axis 0 when ``axis`` is None. For TensorFlow/torch, ``axis``
    is forwarded only when given. If the first element has the configured
    complex type, real and imaginary parts are concatenated separately and
    recombined with ``bk_complex``.
    """
    if self.BACKEND not in (self.TENSORFLOW, self.TORCH):
        return np.concatenate(data, axis=0 if axis is None else axis)
    extra = {} if axis is None else {"axis": axis}
    if data[0].dtype != self.all_cbk_type:
        return self.backend.concat(data, **extra)
    real_joined = self.backend.concat([self.bk_real(item) for item in data], **extra)
    imag_joined = self.backend.concat([self.bk_imag(item) for item in data], **extra)
    return self.bk_complex(real_joined, imag_joined)
[docs]
def bk_zeros(self, shape, dtype=None):
    """Tensor of zeros of ``shape`` and ``dtype`` (placed on ``torch_device`` for torch)."""
    kind = self.BACKEND
    if kind == self.NUMPY:
        return np.zeros(shape, dtype=dtype)
    if kind == self.TORCH:
        zeros = self.backend.zeros(shape, dtype=dtype)
        return zeros.to(self.torch_device)
    if kind == self.TENSORFLOW:
        return self.backend.zeros(shape, dtype=dtype)
[docs]
def bk_gather(self, data, idx):
    """Select entries of ``data`` at indices ``idx`` along the first axis."""
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.gather(data, idx)
    if self.BACKEND in (self.TORCH, self.NUMPY):
        return data[idx]
[docs]
def bk_reverse(self, data, axis=0):
    """Reverse the order of elements of ``data`` along ``axis``.

    ``torch.reverse`` and ``np.reverse`` do not exist; both libraries
    provide ``flip``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.reverse(data, axis=[axis])
    if self.BACKEND == self.TORCH:
        return self.backend.flip(data, dims=[axis])
    if self.BACKEND == self.NUMPY:
        return np.flip(data, axis=axis)
[docs]
def bk_fft(self, data):
    """1-D discrete Fourier transform over the last axis.

    In modern torch ``torch.fft`` is a module, so the transform is ``torch.fft.fft``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.fft(data)
    if self.BACKEND == self.TORCH:
        return self.backend.fft.fft(data)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.fft(data)
[docs]
def bk_fftn(self, data, dim=None):
    """N-dimensional FFT over the axes in ``dim`` (all axes when None).

    The TensorFlow branch casts ``data`` to complex first and requires
    ``dim`` to be given. The numpy branch now honours ``dim``; it used to
    transform every axis regardless.
    """
    if self.BACKEND == self.TENSORFLOW:
        # Equivalent of torch.fft.fftn(x, dim=dims) in TensorFlow
        x = self.bk_complex(data, 0 * data)
        return self.backend.signal.fftnd(
            x, fft_length=tuple(x.shape[d] for d in dim), axes=dim
        )
    if self.BACKEND == self.TORCH:
        return self.backend.fft.fftn(data, dim=dim)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.fftn(data, axes=dim)
[docs]
def bk_ifftn(self, data, dim=None, norm=None):
    """N-dimensional inverse FFT over the axes in ``dim`` with normalisation ``norm``.

    The TensorFlow branch requires ``dim``. The numpy branch now honours both
    ``dim`` and ``norm``; it used to ignore them.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.ifftnd(
            data, fft_length=tuple(data.shape[d] for d in dim), axes=dim, norm=norm
        )
    if self.BACKEND == self.TORCH:
        return self.backend.fft.ifftn(data, dim=dim, norm=norm)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.ifftn(data, axes=dim, norm=norm)
[docs]
def bk_rfft(self, data):
    """1-D FFT of real input over the last axis (non-negative frequencies only).

    ``torch.rfft`` was removed from torch; the function is ``torch.fft.rfft``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.rfft(data)
    if self.BACKEND == self.TORCH:
        return self.backend.fft.rfft(data)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.rfft(data)
[docs]
def bk_irfft(self, data):
    """Inverse of ``bk_rfft`` over the last axis.

    ``torch.irfft`` was removed from torch; the function is ``torch.fft.irfft``.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.signal.irfft(data)
    if self.BACKEND == self.TORCH:
        return self.backend.fft.irfft(data)
    if self.BACKEND == self.NUMPY:
        return self.backend.fft.irfft(data)
[docs]
def bk_conjugate(self, data):
    """Complex conjugate of ``data``."""
    kind = self.BACKEND
    if kind == self.NUMPY:
        return data.conjugate()
    if kind == self.TORCH:
        return self.backend.conj(data)
    if kind == self.TENSORFLOW:
        return self.backend.math.conj(data)
[docs]
def bk_real(self, data):
    """Real part of ``data``."""
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.real(data)
    if self.BACKEND in (self.TORCH, self.NUMPY):
        return data.real
[docs]
def bk_imag(self, data):
    """Imaginary part of ``data``.

    For torch, real tensors yield the Python int 0. Any complex tensor now
    yields ``data.imag``; before, complex tensors of a precision other than
    ``all_cbk_type`` wrongly returned 0.
    """
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.math.imag(data)
    if self.BACKEND == self.TORCH:
        if data.is_complex():
            return data.imag
        return 0
    if self.BACKEND == self.NUMPY:
        return data.imag
[docs]
def bk_relu(self, x):
    """Rectified linear unit ``max(x, 0)``.

    With TensorFlow, input of the configured complex type has ReLU applied
    to its real and imaginary parts separately.
    """
    kind = self.BACKEND
    if kind == self.NUMPY:
        return (x > 0) * x
    if kind == self.TORCH:
        return self.backend.relu(x)
    if kind == self.TENSORFLOW:
        relu = self.backend.nn.relu
        if x.dtype != self.all_cbk_type:
            return relu(x)
        return self.bk_complex(relu(self.bk_real(x)), relu(self.bk_imag(x)))
[docs]
def bk_clip_by_value(self, x, xmin, xmax):
    """Clamp ``x`` into ``[xmin, xmax]``.

    numpy arrays are first clipped with ``np.clip`` whatever the backend.
    For torch, any of ``x``/``xmin``/``xmax`` that is not already a tensor
    is converted to a float32 tensor before ``clamp``.
    """
    if isinstance(x, np.ndarray):
        x = np.clip(x, xmin, xmax)
    kind = self.BACKEND
    if kind == self.TENSORFLOW:
        return self.backend.clip_by_value(x, xmin, xmax)
    if kind == self.TORCH:

        def as_tensor(value):
            # Leave tensors untouched; wrap everything else as float32.
            if isinstance(value, self.backend.Tensor):
                return value
            return self.backend.tensor(value, dtype=self.backend.float32)

        x = as_tensor(x)
        xmin = as_tensor(xmin)
        xmax = as_tensor(xmax)
        return self.backend.clamp(x, min=xmin, max=xmax)
    if kind == self.NUMPY:
        return self.backend.clip(x, xmin, xmax)
[docs]
def bk_cast(self, x):
    """Cast ``x`` to the configured precision.

    * Real scalars (Python or numpy) become ``np.float32`` / ``np.float64``.
    * Complex scalars become ``np.complex64`` / ``np.complex128``.
    * Integer scalars (including ``bool``) become real floats of that precision.
    * Arrays/tensors are cast with the backend: complex data to
      ``all_cbk_type``, everything else to ``all_bk_type``. For torch, numpy
      input is first converted with ``from_numpy``, and the result is placed
      on ``self.torch_device``.
    """
    # all_bk_type holds a backend dtype object (np.float32, torch.float32,
    # tf.float32, ...). numpy and torch dtypes never compare equal to the
    # string "float32", so the scalar branches used to ignore the configured
    # precision. Compare against the backend's own float32 instead.
    single = self.all_bk_type == self.backend.float32
    if isinstance(x, (complex, np.complexfloating)):
        return np.complex64(x) if single else np.complex128(x)
    if isinstance(x, (float, np.floating)):
        return np.float32(x) if single else np.float64(x)
    if isinstance(x, (int, np.integer)):
        return np.float32(x) if single else np.float64(x)
    if self.bk_is_complex(x):
        out_type = self.all_cbk_type
    else:
        out_type = self.all_bk_type
    if self.BACKEND == self.TENSORFLOW:
        return self.backend.cast(x, out_type)
    if self.BACKEND == self.TORCH:
        if isinstance(x, np.ndarray):
            x = self.backend.from_numpy(x).to(self.torch_device)
        if x.dtype.is_complex:
            out_type = self.all_cbk_type
        else:
            out_type = self.all_bk_type
        return x.type(out_type).to(self.torch_device)
    if self.BACKEND == self.NUMPY:
        return x.astype(out_type)
[docs]
def bk_variable(self, x):
    """Wrap ``x`` as a ``tf.Variable`` for TensorFlow; otherwise return ``bk_cast(x)``."""
    if self.BACKEND != self.TENSORFLOW:
        return self.bk_cast(x)
    return self.backend.Variable(x)
[docs]
def bk_assign(self, x, y):
    """Assign ``y`` into the TensorFlow variable ``x`` in place.

    For the other backends this does nothing: the original trailing
    ``x = y`` only rebound a local name, which callers cannot observe.
    Always returns None.
    """
    if self.BACKEND == self.TENSORFLOW:
        x.assign(y)
[docs]
def bk_constant(self, x):
    """Return ``tf.constant(x)`` for TensorFlow; otherwise ``bk_cast(x)``."""
    if self.BACKEND != self.TENSORFLOW:
        return self.bk_cast(x)
    return self.backend.constant(x)
[docs]
def to_numpy(self, x):
    """Convert ``x`` to a numpy array.

    numpy arrays, and anything under the numpy backend, are returned
    unchanged. TensorFlow tensors use ``.numpy()``. torch tensors are
    detached and moved to the CPU first; without ``detach()``, tensors that
    require grad cannot be converted.
    """
    if isinstance(x, np.ndarray) or self.BACKEND == self.NUMPY:
        return x
    if self.BACKEND == self.TENSORFLOW:
        return x.numpy()
    if self.BACKEND == self.TORCH:
        return x.detach().cpu().numpy()