train_step()

The theory is easy to grasp; putting it into practice, however, involves quite a few details that need to be understood.

Understanding the train_step() method

The training step simply wires together the components defined earlier (the encoder, the decoder, and the loss), much like what we did before when testing them individually. Two hedged sketches follow the listing below: a possible loss_function and a driver loop that calls train_step().


@tf.function
def train_step(inp, targ, encoding_hidden):
    loss = 0
    with tf.GradientTape() as tape:
        # ----------------------- run the encoder -----------------------
        encoding_outputs, encoding_hidden = encoder(inp, encoding_hidden)
        # The encoder's final hidden state initializes the decoder.
        decoding_hidden = encoding_hidden
        # Teacher forcing: feed the ground-truth token at step t and
        # score the prediction against the token at step t + 1.
        for t in range(0, targ.shape[1] - 1):
            decoding_input = tf.expand_dims(targ[:, t], 1)
            predictions, decoding_hidden, _ = decoder(decoding_input, decoding_hidden, encoding_outputs)
            # ----------------------- compute the loss -----------------------
            # targ[:, t + 1] is the ground-truth token for the next step,
            # not the prediction.
            loss += loss_function(targ[:, t + 1], predictions)
    # Normalize the accumulated loss (here by batch size, as written).
    batch_loss = loss / int(targ.shape[0])
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss
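
loss_function is referenced above but not defined in this section. What follows is a minimal sketch of the masked sparse-categorical-crossentropy commonly used for seq2seq training; the choice of Adam as the optimizer and the assumption that padding tokens have id 0 are mine, not taken from the original.

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Mask out padding positions (assumed to be token id 0) so they
    # do not contribute to the loss.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    # Average over the batch for this time step.
    return tf.reduce_mean(loss_)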
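
And here is a minimal sketch of how train_step() might be driven. It assumes a tf.data dataset of (inp, targ) batches, an EPOCHS constant, a steps_per_epoch count, and that the encoder exposes an initialize_hidden_state() helper; all of these names are assumptions and do not appear in this section.

EPOCHS = 10  # assumed hyperparameter
for epoch in range(EPOCHS):
    total_loss = 0
    # Each epoch starts from a fresh (zero) encoder hidden state.
    encoding_hidden = encoder.initialize_hidden_state()
    for batch, (inp, targ) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, encoding_hidden)
        total_loss += batch_loss
        if batch % 100 == 0:
            print(f'Epoch {epoch + 1} Batch {batch} Loss {batch_loss.numpy():.4f}')
    print(f'Epoch {epoch + 1} Loss {total_loss / steps_per_epoch:.4f}')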