enerzyme.models.functional.gather_nd

enerzyme.models.functional.gather_nd(params: Tensor, indices: Tensor) → Tensor

Behaves like tf.gather_nd, except that batched gather is not supported yet. indices is a k-dimensional integer tensor, best thought of as a (k-1)-dimensional tensor of indices into params, where each element defines a slice of params:

output[(i_0, …, i_{k-2})] = params[indices[(i_0, …, i_{k-2})]]
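To make the indexing rule concrete, here is a minimal, loop-based sketch that transcribes the formula literally. It is not enerzyme's implementation: gather_nd_loop is a hypothetical name, and the sketch assumes PyTorch, an integer indices tensor, and at least one index tuple.

```python
import torch


def gather_nd_loop(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: literal, loop-based reading of
    output[i_0, ..., i_{k-2}] = params[indices[i_0, ..., i_{k-2}]].
    Assumes indices is an integer tensor with at least one index tuple."""
    m = indices.shape[-1]                   # length of each index tuple
    batch_shape = indices.shape[:-1]        # [y_0, ..., y_{k-2}]
    flat = indices.reshape(-1, m)           # one index tuple per row
    # Each row (j_0, ..., j_{m-1}) selects the slice params[j_0, ..., j_{m-1}],
    # whose shape is params.shape[m:]
    slices = [params[tuple(row.tolist())] for row in flat]
    return torch.stack(slices).reshape(batch_shape + params.shape[m:])


params = torch.arange(12).reshape(3, 4)
indices = torch.tensor([[0, 1], [2, 3]])    # k = 2, m = n = 2
print(gather_nd_loop(params, indices).tolist())  # [1, 11]
```

Here each row of indices is a full index (m = n), so every output element is a scalar: params[0, 1] = 1 and params[2, 3] = 11.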

Args:

params (Tensor): tensor with n dimensions, shape [x_0, x_1, x_2, …, x_{n-1}].

indices (Tensor): integer tensor with k dimensions, shape [y_0, y_1, …, y_{k-2}, m], where m <= n.

Returns: the gathered Tensor.

shape: [y_0, y_1, …, y_{k-2}] + params.shape[m:]
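The output-shape rule can be checked with a vectorized sketch that uses PyTorch advanced indexing. This is an assumed equivalent for illustration, not necessarily how enerzyme implements gather_nd; gather_nd_sketch is a hypothetical name, and indices is assumed to be an int64 tensor with m >= 1.

```python
import torch


def gather_nd_sketch(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Split the last axis of `indices` (length m) into m index tensors, each of
    # shape [y_0, ..., y_{k-2}]. Advanced indexing on the first m axes of params
    # then yields shape [y_0, ..., y_{k-2}] + params.shape[m:].
    return params[tuple(indices.unbind(dim=-1))]


params = torch.arange(120).reshape(4, 5, 6)            # n = 3
indices = torch.tensor([[[0], [1], [3]], [[2], [2], [0]]])  # shape [2, 3, 1], m = 1
out = gather_nd_sketch(params, indices)
print(out.shape)  # torch.Size([2, 3, 5, 6])
```

With m = 1 < n, each index selects a whole [5, 6] slice of params, so the [2, 3] index shape is followed by params.shape[1:].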

Implementation based on https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/37