
    Implementing Image Classification with TensorFlow-Slim

    Category: Code   Posted: 2019-12-31 12:09

    Reference: https://github.com/tensorflow/models/tree/master/slim

    Image classification with TensorFlow-Slim

    Preparation

    Installing TensorFlow

    See https://www.tensorflow.org/install/

    For example, to install TensorFlow 1.2.0 with GPU support for Python 2.7 on Ubuntu:

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

    Downloading the TF-slim image model library

    cd $WORKSPACE
    git clone https://github.com/tensorflow/models/
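    The conversion script later in this article adds ../models/slim/ to sys.path. Newer checkouts of tensorflow/models keep the slim code under research/slim instead, so that path may need adjusting. A minimal sketch, using a hypothetical helper name find_slim_dir, that returns whichever of the two locations exists in a checkout:

```python
import os


def find_slim_dir(models_root):
  """Hypothetical helper: return the slim directory inside a models checkout.

  Older checkouts keep slim at <models_root>/slim, later ones at
  <models_root>/research/slim. Returns None if neither directory exists.
  """
  for candidate in ('slim', os.path.join('research', 'slim')):
    path = os.path.join(models_root, candidate)
    if os.path.isdir(path):
      return path
  return None
```

    Run from $WORKSPACE/data, the result of find_slim_dir('../models') could be passed to sys.path.insert(0, ...) after checking that it is not None.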
    
    

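    Preparing the data

    Many public datasets are available; here we use the Flowers dataset from the official slim examples.

    The official repository provides code that downloads and converts the data. To understand how that code works and to be able to use our own data, we adapt it here rather than running it unchanged.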

    cd $WORKSPACE/data
    wget http://download.tensorflow.org/example_images/flower_photos.tgz
    tar zxf flower_photos.tgz
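    On machines without wget, the same download and extraction can be done from Python's standard library. A minimal sketch (download_and_extract is a hypothetical name) that skips the download when the archive is already present:

```python
import os
import tarfile

try:
  from urllib.request import urlretrieve  # Python 3
except ImportError:
  from urllib import urlretrieve  # Python 2


def download_and_extract(url, dest_dir):
  """Download url into dest_dir unless the archive is already there,
  then extract the .tgz into dest_dir. Returns the archive path."""
  if not os.path.isdir(dest_dir):
    os.makedirs(dest_dir)
  archive_path = os.path.join(dest_dir, url.split('/')[-1])
  if not os.path.exists(archive_path):
    urlretrieve(url, archive_path)
  with tarfile.open(archive_path, 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
      # Newer Python versions: reject unsafe members (absolute paths, '..', etc.)
      tar.extractall(dest_dir, filter='data')
    else:
      tar.extractall(dest_dir)
  return archive_path
```

    For example, download_and_extract('http://download.tensorflow.org/example_images/flower_photos.tgz', '.') run from $WORKSPACE/data should leave a flower_photos folder next to the archive.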
    
    

    The dataset folder is structured as follows:

    flower_photos
    ├── daisy
    │  ├── 100080576_f52e8ee070_n.jpg
    │  └── ...
    ├── dandelion
    ├── LICENSE.txt
    ├── roses
    ├── sunflowers
    └── tulips
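    With this layout, the class-to-label mapping can also be derived from the folder names instead of being written by hand. A sketch, assuming every sub-directory is a class and labels are assigned alphabetically (discover_classes is a hypothetical name); plain files such as LICENSE.txt are skipped:

```python
import os


def discover_classes(data_dir):
  """Map each sub-directory name of data_dir to an integer label.

  Labels follow alphabetical order; regular files (e.g. LICENSE.txt) are ignored.
  """
  class_names = sorted(
    name for name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, name)))
  return {name: idx for idx, name in enumerate(class_names)}
```

    For flower_photos, alphabetical order gives daisy 0, dandelion 1, roses 2, sunflowers 3, tulips 4, which matches the hand-written class_names_to_ids dictionary used in the next step.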
    
    

    由于實際情況中我們自己的數據集并不一定把圖片按類別放在不同的文件夾里,故我們生成list.txt來表示圖片路徑與標簽的關系。

    Python代碼:

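    import os
    
    # Map each class folder name to an integer label
    class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    data_dir = 'flower_photos/'
    output_path = 'list.txt'
    
    # Each output line has the form "<class_name>/<image_name> <label>"
    with open(output_path, 'w') as fd:
      # sorted() makes the file order deterministic, so the seeded shuffle in the
      # next step does not depend on the order in which the file system lists files
      for class_name in sorted(class_names_to_ids.keys()):
        images_list = sorted(os.listdir(os.path.join(data_dir, class_name)))
        for image_name in images_list:
          fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))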
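    Each line of list.txt is "<relative path> <label>". The TFRecord conversion later splits lines with line.split(), which breaks if an image name contains a space. A hedged sketch of a more tolerant reader (read_image_list is a hypothetical name) that splits only on the last run of whitespace:

```python
def read_image_list(list_path):
  """Return (relative_path, label) pairs from a list.txt-style file.

  rsplit(None, 1) splits only on the last run of whitespace, so file names
  containing spaces stay intact; blank lines are skipped.
  """
  pairs = []
  with open(list_path) as fd:
    for line in fd:
      line = line.strip()
      if not line:
        continue
      path, label = line.rsplit(None, 1)
      pairs.append((path, int(label)))
  return pairs
```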
    
    

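    To make it easy to look up what each label means later, you can also create a labels.txt file, where line n holds the name of label n: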

    daisy
    dandelion
    roses
    sunflowers
    tulips
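    A minimal sketch that writes labels.txt from class_names_to_ids and reads it back into an id-to-name mapping (write_labels and read_labels are hypothetical names). It assumes label ids are 0..N-1 with no gaps. Note that slim's own dataset_utils.write_label_file writes a different "id:name" per-line format; the plain one-name-per-line format here is this article's choice.

```python
def write_labels(class_names_to_ids, labels_path):
  """Write one class name per line, ordered by label id (line index == label)."""
  with open(labels_path, 'w') as fd:
    for name, _ in sorted(class_names_to_ids.items(), key=lambda kv: kv[1]):
      fd.write(name + '\n')


def read_labels(labels_path):
  """Read labels.txt back into a {label_id: class_name} dict."""
  with open(labels_path) as fd:
    names = [line.strip() for line in fd if line.strip()]
  return dict(enumerate(names))
```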
    
    

    Randomly split the list into a training set and a validation set:

    Python code:

    import random
    
    _NUM_VALIDATION = 350  # number of images held out for validation
    _RANDOM_SEED = 0       # fixed seed so the split is reproducible
    list_path = 'list.txt'
    train_list_path = 'list_train.txt'
    val_list_path = 'list_val.txt'
    
    with open(list_path) as fd:
      lines = fd.readlines()
    random.seed(_RANDOM_SEED)
    random.shuffle(lines)
    
    # The first _NUM_VALIDATION shuffled lines become the validation set,
    # the remaining lines the training set
    with open(train_list_path, 'w') as fd:
      fd.writelines(lines[_NUM_VALIDATION:])
    with open(val_list_path, 'w') as fd:
      fd.writelines(lines[:_NUM_VALIDATION])
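    Because the split above shuffles all images together, the number of validation images per class is only roughly proportional to the class size. An alternative technique, not what the original code does, is a stratified split that takes the same fraction from every class. A sketch under that assumption (stratified_split is a hypothetical name; it takes the lines of list.txt):

```python
import random
from collections import defaultdict


def stratified_split(lines, val_fraction, seed=0):
  """Split list.txt lines ("<path> <label>") into (train, val), class by class.

  About val_fraction of each class (rounded per class) goes to validation.
  """
  by_label = defaultdict(list)
  for line in lines:
    label = line.rsplit(None, 1)[1]
    by_label[label].append(line)
  rng = random.Random(seed)  # private RNG, leaves the global random state alone
  train, val = [], []
  for label in sorted(by_label):
    group = sorted(by_label[label])  # sort first so input order does not matter
    rng.shuffle(group)
    num_val = int(round(len(group) * val_fraction))
    val.extend(group[:num_val])
    train.extend(group[num_val:])
  return train, val
```

    The trade-off: validation size is now controlled by a fraction rather than the exact count _NUM_VALIDATION, and very small classes may contribute zero validation images after rounding.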
    
    

    Generating the TFRecord data:

    Python code:

    import math
    import os
    import sys
    
    # Make slim's "datasets" package importable; adjust the path to your checkout
    # (newer versions of the models repository keep slim under research/slim)
    sys.path.insert(0, '../models/slim/')
    from datasets import dataset_utils
    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session, tf.python_io)
    
    def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
      """Convert the images listed in list_path into _NUM_SHARDS TFRecord files."""
      with open(list_path) as fd:
        lines = [line.split() for line in fd]  # each entry: [relative_path, label]
      num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
      with tf.Graph().as_default():
        # The JPEG is decoded only to obtain its height and width
        decode_jpeg_data = tf.placeholder(dtype=tf.string)
        decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
        with tf.Session('') as sess:
          for shard_id in range(_NUM_SHARDS):
            output_path = os.path.join(output_dir,
              'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
            tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
            start_ndx = shard_id * num_per_shard
            end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
            for i in range(start_ndx, end_ndx):
              sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
                i + 1, len(lines), shard_id))
              sys.stdout.flush()
              # The raw encoded bytes are stored; they are decoded again at training time
              with tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb') as f:
                image_data = f.read()
              image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
              height, width = image.shape[0], image.shape[1]
              example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, int(lines[i][1]))
              tfrecord_writer.write(example.SerializeToString())
            tfrecord_writer.close()
      sys.stdout.write('\n')
      sys.stdout.flush()
    
    # Write the training and validation shards into ./train and ./val
    for split in ('train', 'val'):
      if not os.path.exists(split):
        os.makedirs(split)
      convert_dataset('list_{}.txt'.format(split), 'flower_photos', split + '/')
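    convert_dataset gives each shard ceil(N / _NUM_SHARDS) images, with the last shard taking whatever remains. A stdlib-only sketch of just that index arithmetic (shard_ranges is a hypothetical helper), handy for checking shard sizes and file names without running TensorFlow:

```python
import math


def shard_ranges(num_items, num_shards):
  """Mirror convert_dataset's index arithmetic: one (output_name, start, end)
  tuple per shard, where the shard holds items[start:end]."""
  num_per_shard = int(math.ceil(num_items / float(num_shards)))
  shards = []
  for shard_id in range(num_shards):
    # start is clamped too, so a shard with no images shows up as (n, n)
    start = min(shard_id * num_per_shard, num_items)
    end = min((shard_id + 1) * num_per_shard, num_items)
    name = 'data_{:05}-of-{:05}.tfrecord'.format(shard_id, num_shards)
    shards.append((name, start, end))
  return shards
```

    For Flowers, 3670 images minus 350 for validation leave 3320 training images, i.e. 664 per shard. With very few images, trailing shards can be empty: shard_ranges(6, 5) yields the ranges (0, 2), (2, 4), (4, 6), (6, 6), (6, 6), and convert_dataset would write the last two as empty TFRecord files.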
    
    
    