tf.compat.v2.expand_dims()
可以对tensor进行维度扩展。
1 | tf.expand_dims( |
我应该在第一个维度进行扩展。
原来的维度 (2,588). —->(1,2,588)
tf 框架需要的数据类型是float32
调试以后。搭建的模型在测试的时候,可以正常的通过。之前主要是输入的shape不一样。
现在没有办法训练。经过调试以后程序在训练时一直处于decoder。肯定是这的问题。重新理解一下Decoder函数
1 | class Decoder(keras.Model):#使用子类API实现 |