class BatchMatrixMultiplication(BinaryOp):
    """Batched matrix product of the two operands returned by ``get_value()``.

    The einsum subscripts fix both operands to rank 3 with a shared leading
    batch axis: ``a`` of shape ``(n, j, k)`` times ``b`` of shape
    ``(n, k, z)`` gives an output of shape ``(n, j, z)`` with
    ``out[i] = a[i] @ b[i]``.
    """

    # Same result as np.stack([a[i] @ b[i] for i in range(a.shape[0])]),
    # computed in a single vectorized call.
    _SUBSCRIPTS = "ijk,ikz->ijz"

    def forward(self):
        """Compute the batched product and wrap it in a non-leaf Tensor.

        Returns:
            Tensor holding ``einsum("ijk,ikz->ijz", a, b)``, created with
            ``self.parents``, ``self.track_gradient`` and this op's repr
            as ``op_name``.
        """
        data_a, data_b = self.get_value()
        return Tensor(
            np.einsum(self._SUBSCRIPTS, data_a, data_b),
            parents=self.parents,
            is_leaf=False,
            track_gradient=self.track_gradient,
            op_name=self.__repr__(),
        )

    def backward(self, gradient=None):
        """Propagate the upstream gradient to both operands.

        Per batch element, for C = A @ B:
        dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC.

        Args:
            gradient: upstream gradient; a Tensor whose ``.numpy()`` value
                has the output shape ``(n, j, z)``.

        Raises:
            ValueError: if ``gradient`` is None. A matrix product has no
                implicit upstream gradient.
        """
        if gradient is None:
            raise ValueError(
                "BatchMatrixMultiplication.backward requires an upstream gradient"
            )
        data_a, data_b = self.get_value()
        # Handle both operands identically, as plain arrays, just as forward()
        # hands them to np.einsum.
        # NOTE(review): assumes get_value() yields numpy-compatible array-likes;
        # confirm against BinaryOp if it can return framework tensors.
        a = np.asarray(data_a)
        b = np.asarray(data_b)
        grad_np = gradient.numpy()
        # Swap the two matrix axes of each batch element (per-batch transpose).
        grad_a = np.einsum(self._SUBSCRIPTS, grad_np, np.transpose(b, (0, 2, 1)))
        grad_b = np.einsum(self._SUBSCRIPTS, np.transpose(a, (0, 2, 1)), grad_np)
        self._set_gradients(Tensor(grad_a), Tensor(grad_b))

    def __repr__(self):
        return "BatchMatrixMultiplication(BinaryOp)"