# plot_attention: draw an attention matrix as a heatmap labelled with tokens.

def plot_attention(attention_matrix, input_sentence, predicted_sentence):
    """Show an attention matrix as a heatmap with token labels on both axes.

    Args:
        attention_matrix: 2-D array-like handed to ``Axes.matshow``.
            Presumably shaped (len(predicted_sentence), len(input_sentence));
            verify against the caller.
        input_sentence: Sequence of token strings used as x-axis labels.
        predicted_sentence: Sequence of token strings used as y-axis labels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    ax.matshow(attention_matrix, cmap='viridis')  # cmap selects the color scheme

    font_dict = {'fontsize': 14}

    # Accept any sequence of tokens (list, tuple, ...), not only lists.
    input_tokens = list(input_sentence)
    predicted_tokens = list(predicted_sentence)

    # Pin exactly one tick per token before labelling. Relying on the default
    # auto-locator (and padding the labels with a leading '') misaligns or
    # drops labels as soon as the locator decides to skip integer positions.
    ax.set_xticks(list(range(len(input_tokens))))
    ax.set_yticks(list(range(len(predicted_tokens))))

    # rotation=90 turns the x labels vertical so long tokens do not overlap.
    ax.set_xticklabels(input_tokens, fontdict=font_dict, rotation=90)
    ax.set_yticklabels(predicted_tokens, fontdict=font_dict)
    plt.show()