Description
Test command:
easy_transfer_app --mode=train --inputSchema=query:str:1,doc:str:1,label:str:1 --inputTable=./train_lcqmc.csv,./dev_lcqmc.csv --firstSequence=query --secondSequence=doc --labelName=label --labelEnumerateValues=0,1 --batchSize=32 --numEpochs=1 --optimizerType=adam --learningRate=0.001 --modelName=text_match_hcnn --checkpointDir=./hcnn_match_models --advancedParameters='first_sequence_length=40 second_sequence_length=40 pretrain_word_embedding_name_or_path=./sgns.zhihu.char.300.bin fix_embedding=true max_vocab_size=30000 embedding_size=300 hidden_size=300'
Error message:
INFO:tensorflow:Initialize word embedding from pretrained
Traceback (most recent call last):
  File "/usr/local/anaconda3/envs/tf12.3/bin/easy_transfer_app", line 8, in <module>
    sys.exit(main())
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo_cli.py", line 99, in main
    app.run()
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/app_utils.py", line 168, in wrapper
    func(*args, **kw)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/base.py", line 44, in run
    getattr(self, self.config.mode.replace("_on_the_fly", ""))()
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/base.py", line 113, in train_and_evaluate
    self.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/engines/model.py", line 608, in run_train_and_evaluate
    eval_spec=eval_spec)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 471, in train_and_evaluate
    return executor.run()
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 610, in run
    return self.run_local()
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 711, in run_local
    saving_listeners=saving_listeners)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1237, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/engines/model.py", line 530, in model_fn
    logits, labels = self.build_logits(features, mode=mode)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/app_zoo/text_match.py", line 618, in build_logits
    filter_size=self.config.filter_size)([a_embeds, b_embeds, text_a_masks, text_b_masks])
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 374, in call
    outputs = super(Layer, self).call(inputs, *args, **kwargs)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 757, in call
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/easytransfer/layers/cnn.py", line 276, in call
    (a_length / 4 / 3 / 2) * (b_length / 4 / 3 / 2)])
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6482, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "/usr/local/anaconda3/envs/tf12.3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'shape' has DataType float32 not in list of allowed values: int32, int64
The above was run with TensorFlow 1.12.3, the version specified in the official installation instructions.
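A likely cause, inferred only from the last EasyTransfer frame (easytransfer/layers/cnn.py, line 276) and not confirmed against the package source: the environment runs Python 3.6, where `/` is true division, so `(a_length / 4 / 3 / 2)` evaluates to a float, and `tf.reshape` rejects a float shape (it accepts only int32/int64). The sketch below assumes `a_length` and `b_length` are plain Python ints holding the configured sequence lengths (40 here); the two helper functions are hypothetical and only mirror the expression from the traceback to show the difference between `/` and `//`.

```python
# Hypothetical helpers mirroring the shape expression at easytransfer/layers/cnn.py line 276.
# Assumption: a_length and b_length are plain Python ints (e.g. first_sequence_length=40).


def flattened_size_true_div(a_length, b_length):
    # Python 3 "/" is true division: the result is a float even for int inputs,
    # which tf.reshape would reject as a shape value.
    return (a_length / 4 / 3 / 2) * (b_length / 4 / 3 / 2)


def flattened_size_floor_div(a_length, b_length):
    # "//" is floor division on ints (Python 2's behaviour for "/"),
    # so the result stays an int, the kind of value a reshape shape expects.
    return (a_length // 4 // 3 // 2) * (b_length // 4 // 3 // 2)


if __name__ == "__main__":
    print(type(flattened_size_true_div(40, 40)).__name__)   # float
    print(type(flattened_size_floor_div(40, 40)).__name__)  # int
    print(flattened_size_floor_div(40, 40))                 # 1
```

If this diagnosis is right, a local workaround would be to change `/` to `//` on that line of the installed cnn.py; whether the floored size matches the actual pooled feature map for lengths that are not multiples of 24 depends on the layer's pooling setup, which is not shown here.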