てがみ: qatacri at protonmail.com | 統計 | 2019

201920500

Transfer model for language understanding | TensorFlow Core

これの accuracy の値が妙なのは、可変長文字列のマスクをしていないことと、単語単位の平均正答率の計算になっているからなので、

+def accuracy(real, pred):
+  diff = tf.equal(real, tf.math.argmax(pred, -1)) | tf.equal(real, 0)
+  return tf.where(tf.reduce_all(diff, -1), 1.0, 0.0)

-train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
-    name='train_accuracy')
+train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

@@ def train_step(inp, tar):
-  train_accuracy(tar_real, predictions)
+  train_accuracy(accuracy(tar_real, predictions))

とりあえずこんな感じで修正できる。もちろん自然言語の変換のときには小さな値になるけれど、指標としては十分機能する。