# train_step()

Posted on 2019-08-21

The theory is easy to understand; putting it into practice, there are quite a few details that need to be worked through.

## Understanding the train_step() method

The training step is really just a matter of wiring together the components defined earlier (encoder, decoder, loss function, optimizer), much like what we did before when testing them:

```python
import tensorflow as tf

@tf.function
def train_step(inp, targ, encoding_hidden):
    # encoder, decoder, loss_function, and optimizer are defined in earlier posts.
    loss = 0
    with tf.GradientTape() as tape:
        # Run the encoder over the whole input batch.
        encoding_outputs, encoding_hidden = encoder(inp, encoding_hidden)
        # The encoder's final hidden state initializes the decoder.
        decoding_hidden = encoding_hidden
        # Teacher forcing: feed the ground-truth token at step t and
        # compare the prediction against the ground-truth token at step t+1.
        for t in range(0, targ.shape[1] - 1):
            decoding_input = tf.expand_dims(targ[:, t], 1)
            predictions, decoding_hidden, _ = decoder(
                decoding_input, decoding_hidden, encoding_outputs)
            # targ[:, t+1] is the label for this step, not a prediction.
            loss += loss_function(targ[:, t + 1], predictions)
    # Normalize the accumulated loss by the batch size for reporting.
    batch_loss = loss / int(targ.shape[0])
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return batch_loss
```
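The `loss_function` called above is not shown in this post. Here is a minimal sketch of what it typically looks like in this kind of seq2seq setup, assuming zero-padded targets and raw logits coming out of the decoder; the name `loss_object` and the padding id `0` are assumptions, not taken from the post:

```python
import tensorflow as tf

# Hypothetical loss_function; the post assumes one was defined earlier.
# Sparse categorical cross-entropy with a mask so padded positions
# (assumed token id 0) do not contribute to the loss.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # real: (batch_size,) ground-truth token ids for one timestep
    # pred: (batch_size, vocab_size) decoder logits for that timestep
    mask = tf.math.logical_not(tf.math.equal(real, 0))  # True where not padding
    loss_ = loss_object(real, pred)                     # per-example loss
    loss_ *= tf.cast(mask, dtype=loss_.dtype)           # zero out padded positions
    return tf.reduce_mean(loss_)
```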
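For context, here is a sketch of the outer loop that would drive `train_step()`. `EPOCHS`, `steps_per_epoch`, `dataset`, and `encoder.initialize_hidden_state()` are assumptions about how the surrounding series is set up, not code from this post:

```python
# Hypothetical driver loop; all names below except train_step are
# assumptions about the setup from earlier posts in this series.
EPOCHS = 10

for epoch in range(EPOCHS):
    encoding_hidden = encoder.initialize_hidden_state()  # zero initial state
    total_loss = 0

    for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, encoding_hidden)
        total_loss += batch_loss.numpy()

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                epoch + 1, batch, batch_loss.numpy()))

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / steps_per_epoch))
```

Because `train_step` is wrapped in `@tf.function`, the first call traces it into a graph and subsequent calls reuse that graph, which keeps the loop body fast.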