【AI】Python3使用TF-Slim进行图像分类
2019-04-15 14:21发布
生成海报
机器环境
- win10
- python3.6
- tensorflow==1.7.0
Github地址
准备图片数据
- 准备好自定义的图片数据
- 放到 data_prepare/pic/train 和 data_prepare/pic/validation 中
- 自己建立分类文件夹,文件夹名为分类标签名
将图片数据转换成TF-Record格式文件
python data_convert.py -t pic/
--train-shards 2
--validation-shards 2
--num-threads 2
--dataset-name satellite
- 会生成4个tf-record文件和1个label文件
将转换生成的5个文件复制到 slimsatellitedata 下
修改 slimdatasetssatellite.py 文件
- FILE_PATTERN = ‘satellite%s_*.tfrecord’ (tf-record文件名格式)
- SPLITS_TO_SIZES = {‘train’: 16, ‘validation’: 4} (训练集和测试集文件总数)
- _NUM_CLASSES = 2 (分类类目总数)
- ‘image/format’: tf.FixedLenFeature((), tf.string, default_value=’jpg’) (图片格式,这里是jpg)
下载预训练模型Inception V3
在 slim/ 文件夹下执行如下命令,进行训练:
python train_image_classifier.py
--train_dir=satellite/train_dir
--dataset_name=satellite
--dataset_split_name=train
--dataset_dir=satellite/data
--model_name=inception_v3
--checkpoint_path=satellite/pretrained/inception_v3.ckpt
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--max_number_of_steps=100000
--batch_size=32
--learning_rate=0.001
--learning_rate_decay_type=fixed
--save_interval_secs=300
--save_summaries_secs=2
--log_every_n_steps=10
--optimizer=rmsprop
--weight_decay=0.00004
在 slim/ 文件夹下执行如下命令,进行模型能力评估:
python eval_image_classifier.py
--checkpoint_path=satellite/train_dir
--eval_dir=satellite/eval_dir
--dataset_name=satellite
--dataset_split_name=validation
--dataset_dir=satellite/data
--model_name=inception_v3
导出训练好的模型
python export_inference_graph.py
--alsologtostderr
--model_name=inception_v3
--output_file=satellite/inception_v3_inf_graph.pb
--dataset_name satellite
- 在 项目根目录 执行如下命令(需将5271改成train_dir中保存的实际的模型训练步数)
python freeze_graph.py
--input_graph slim/satellite/inception_v3_inf_graph.pb
--input_checkpoint slim/satellite/train_dir/model.ckpt-5271
--input_binary true
--output_node_names InceptionV3/Predictions/Reshape_1
--output_graph slim/satellite/frozen_graph.pb
对单张图片进行预测
python classify_image_inception_v3.py
--model_path slim/satellite/frozen_graph.pb
--label_path data_prepare/pic/label.txt
--image_file test_image.jpg
打开微信“扫一扫”,打开网页后点击屏幕右上角分享按钮