facenet加载pretrained_model报错问题解决

2019-04-14 16:20发布

facenet加载pretrained_model报错问题解决

问题

facenet github地址:https://github.com/davidsandberg/facenet
使用facenet自己训练模型中断后,继续运行程序: // 继续训练 python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了 报错。
注:models/20190303-192600是上次自己训练保存模型的文件夹

原因

Facenet程序调用了tf.train.Saver类,加载预训练模型时使用了Saver.restore(sess, pretrained_model)方法,该方法传入的第二个变量应该为models/20190303-192600/model-20190303-192600.ckpt-6的字符串。
对比models/20190303-192600文件夹下的内容,发现正确输入的参数并不指向预训练模型文件夹下任何一个文件。
models/20190303-192600文件夹下的内容: zzd@zzd-ubuntu-K80:20190303-192600$ ls checkpoint model-20190303-192600.ckpt-4.data-00000-of-00001 model-20190303-192600.ckpt-4.index model-20190303-192600.ckpt-5.data-00000-of-00001 model-20190303-192600.ckpt-5.index model-20190303-192600.ckpt-6.data-00000-of-00001 model-20190303-192600.ckpt-6.index model-20190303-192600.meta

解决办法

方法一:
继续运行程序时输入正确的参数即可: // 继续训练 python3 src/train_softmax.py --pretrained_models models/20190303-192600/model-20190303-192600.ckpt-6 --其他参数不写了 方法二:
修改:src/train_softmax.py 在第201行后添加一行并修改原202行,修改后为: 200 if pretrained_model: 201 print('Restoring pretrained model: %s' % pretrained_model) 202 ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrained_model)) 203 saver.restore(sess, ckpt.model_checkpoint_path)) 该方法能实现继续运行程序时,只需输入预训练模型所在文件夹的路径: // 继续训练 python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了 结束。