facenet加载pretrained_model报错问题解决
问题
facenet github地址:
https://github.com/davidsandberg/facenet
使用facenet自己训练模型中断后,继续运行程序:
// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了
报错。
注:models/20190303-192600是上次自己训练保存模型的文件夹
原因
Facenet程序调用了tf.train.Saver类,加载预训练模型时使用了Saver.restore(sess, pretrained_model)方法,该方法传入的第二个变量应该为models/20190303-192600/model-20190303-192600.ckpt-6的字符串。
对比models/20190303-192600文件夹下的内容,发现正确输入的参数并不指向预训练模型文件夹下任何一个文件。
models/20190303-192600文件夹下的内容:
zzd@zzd-ubuntu-K80:20190303-192600$ ls
checkpoint
model-20190303-192600.ckpt-4.data-00000-of-00001
model-20190303-192600.ckpt-4.index
model-20190303-192600.ckpt-5.data-00000-of-00001
model-20190303-192600.ckpt-5.index
model-20190303-192600.ckpt-6.data-00000-of-00001
model-20190303-192600.ckpt-6.index
model-20190303-192600.meta
解决办法
方法一:
继续运行程序时输入正确的参数即可:
// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/model-20190303-192600.ckpt-6 --其他参数不写了
方法二:
修改:src/train_softmax.py 在第201行后添加一行并修改原202行,修改后为:
200 if pretrained_model:
201 print('Restoring pretrained model: %s' % pretrained_model)
202 ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrained_model))
203 saver.restore(sess, ckpt.model_checkpoint_path))
该方法能实现继续运行程序时,只需输入预训练模型所在文件夹的路径:
// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了
结束。