z[i, j] = x[i, j, y[i, j]] for_each (i, j).
を TensorFlow で記述したい。
z = tf.gather(x, y, batch_dims = 2)
これで良さそうなものだけれど batch_dims < len(y.shape)
でないと駄目らしい。
z = tf.gather(x, y[:, :, tf.newaxis], batch_dims = 2)[:, :, 0]
これで一応動くけれど…うーん。
tf.gather should support batch_dims == rank(indices) #32158
当然ながら Issue があった。でもかなり最近。