# Source code for foscat.loss_backend_torch

import torch


class loss_backend:
    """Torch implementation of the loss/gradient backend.

    Parameters
    ----------
    backend :
        Stored as ``self.bk``; not used by the methods of this class.
    curr_gpu :
        Rolling counter added to ``scat_operator.gpupos`` to choose the
        CUDA device for each ``loss`` call; incremented after each CUDA
        evaluation.
    mpi_rank :
        Stored as ``self.mpi_rank``; not used by the methods of this class
        (presumably the MPI rank of the calling process -- confirm with callers).
    """

    def __init__(self, backend, curr_gpu, mpi_rank):
        self.bk = backend
        self.curr_gpu = curr_gpu
        self.mpi_rank = mpi_rank

    def check_dense(self, data, datasz):
        """Return ``data`` unchanged.

        This backend performs no densification: ``torch.Tensor`` inputs and
        any other input are returned as-is.  ``datasz`` is ignored; it is
        presumably kept for signature parity with a TensorFlow backend
        (NOTE(review): confirm against the other backend implementations).
        """
        return data

    # -------------------------------------------------------

    def _eval_and_grad(self, x, batch, loss_function, with_info):
        """Run one forward/backward pass of ``loss_function`` on a copy of ``x``.

        Returns ``(loss, grad, info)``.  ``loss`` is detached, ``grad`` is the
        gradient accumulated on the copy, and ``info`` is ``None`` unless
        ``with_info`` is true.
        """
        # A fresh leaf tensor receives the gradient in ``l_x.grad``.  The
        # caller's ``x`` keeps its own requires_grad flag and ``.grad``.
        l_x = x.clone().detach().requires_grad_(True)
        linfo = None
        if with_info:
            l_loss, linfo = loss_function.eval(l_x, batch, return_all=True)
        else:
            l_loss = loss_function.eval(l_x, batch)
        # ``backward()`` without arguments requires a scalar loss.
        l_loss.backward()
        return l_loss.detach(), l_x.grad, linfo

    def loss(self, x, batch, loss_function, KEEP_TRACK):
        """Evaluate ``loss_function`` at ``x`` and return the loss and its gradient.

        When CUDA is available, evaluation runs inside
        ``torch.cuda.device((gpupos + self.curr_gpu) % ngpu)`` and
        ``self.curr_gpu`` is then incremented, so successive calls cycle
        through the ``ngpu`` devices.

        Parameters
        ----------
        x : torch.Tensor
            Point at which the loss is evaluated; it is copied, not modified.
        batch :
            Passed through unchanged to ``loss_function.eval``.
        loss_function :
            Object exposing ``scat_operator`` (with ``gpupos`` and ``ngpu``)
            and ``eval(x, batch, return_all=...)``.  ``eval`` must return a
            scalar tensor, or ``(loss, info)`` when ``return_all=True``.
        KEEP_TRACK :
            Any non-``None`` value requests the extra ``info`` returned by
            ``loss_function.eval(..., return_all=True)``.

        Returns
        -------
        tuple
            ``(loss, grad)``, or ``(loss, grad, info)`` when ``KEEP_TRACK`` is
            not ``None``.  ``loss`` is detached from the autograd graph.
        """
        operation = loss_function.scat_operator
        with_info = KEEP_TRACK is not None

        if torch.cuda.is_available():
            device = (operation.gpupos + self.curr_gpu) % operation.ngpu
            with torch.cuda.device(device):
                l_loss, g, linfo = self._eval_and_grad(
                    x, batch, loss_function, with_info
                )
            self.curr_gpu = self.curr_gpu + 1
        else:
            l_loss, g, linfo = self._eval_and_grad(x, batch, loss_function, with_info)

        if with_info:
            return l_loss, g, linfo
        return l_loss, g