てがみ: qatacri at protonmail.com | 統計 | 2019

201920501

z[i, j] = x[i, j, y[i, j]] for_each (i, j).

を TensorFlow で記述したい。

z = tf.gather(x, y, batch_dims = 2)

これで良さそうなものだけれど batch_dims < len(y.shape) でないと駄目らしい。

z = tf.gather(x, y[:, :, tf.newaxis], batch_dims = 2)[:, :, 0]

これで一応動くけれど…うーん。

tf.gather should support batch_dims == rank(indices) #32158

当然ながら Issue があった。でもかなり最近。