• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
No Tags

Frequently used words (click to add to your profile)

javac++androidlinuxc#windowsobjective-ccocoa誰得qtpythonphprubygameguibathyscaphec計画中(planning stage)翻訳omegatframeworktwitterdomtestvb.netdirectxゲームエンジンbtronarduinopreviewer

TensorFlowサンプルコード


Commit MetaInfo

Revision6d18270d4073ebe203d054845675b35235aa23e8 (tree)
Time2018-01-18 20:36:56
Authorhylom <hylom@hylo...>
Commiterhylom

Log Message

add TFRecord samples

Change Summary

Incremental Difference

--- /dev/null
+++ b/tfrecords/convert_tfr.py
@@ -0,0 +1,104 @@
1+#!/usr/bin/env python
2+# -*- coding: utf-8 -*-
3+import argparse
4+import os
5+
6+import tensorflow as tf
7+
8+def main():
9+ # 引数をパースする
10+ p = argparse.ArgumentParser(description='convert images to TFRecord format')
11+ p.add_argument('dimension',
12+ type=lambda x: [int(r) for r in x.split("x")],
13+ help='image dimension. example: 10x10')
14+ p.add_argument('label',
15+ type=int,
16+ help='label.')
17+ p.add_argument('output',
18+ #type=argparse.FileType('w'),
19+ help='output file')
20+ p.add_argument('target_dir',
21+ nargs='+',
22+ help='target directory')
23+ args = p.parse_args()
24+ if len(args.dimension) != 2:
25+ raise argparse.ArgumentTypeError("dimension must be <num>x<num>")
26+
27+ # 計算グラフを構築する
28+ (width, height) = args.dimension
29+ (filepath, padded_image) = _create_comp_graph(width, height)
30+
31+ # 出力先の用意
32+ writer = tf.python_io.TFRecordWriter(args.output)
33+
34+ # セッションを実行
35+ sess = tf.Session()
36+
37+ # 指定したディレクトリ内のファイルを列挙する
38+ for dirname in args.target_dir:
39+ for filename in os.listdir(dirname):
40+ pathname = os.path.join(dirname, filename)
41+ print("process {}...".format(pathname))
42+ try:
43+ image = sess.run(padded_image, {filepath: pathname})
44+ except tf.errors.InvalidArgumentError:
45+ # 読み込みに失敗したらメッセージを出力して継続する
46+ print("{}: invalid jpeg file. ignored.".format(pathname))
47+ continue
48+
49+ features = tf.train.Features(feature={
50+ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[args.label])),
51+ 'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
52+ 'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
53+ 'raw_image': tf.train.Feature(float_list=tf.train.FloatList(value=image.reshape(width*height*3))),
54+ })
55+ example = tf.train.Example(features=features)
56+ writer.write(example.SerializeToString())
57+
58+ # 終了
59+ writer.close()
60+ print("done")
61+
62+# 指定したファイルを読み出して指定したサイズにリサイズする
63+def _create_comp_graph(width, height):
64+ filepath = tf.placeholder(tf.string, name="pathname")
65+ content = tf.read_file(filepath)
66+ raw_image = tf.image.decode_jpeg(content, channels=3)
67+ # raw_imageのデータ型はuint8
68+
69+ # 画像のサイズを[height, width, channels]の形で取得
70+ shape = tf.shape(raw_image)
71+ raw_height = tf.to_float(tf.slice(shape, [0], [1]))
72+ raw_width = tf.to_float(tf.slice(shape, [1], [1]))
73+
74+ # 画像サイズと出力サイズのアスペクト比を比較し、
75+ # 画像サイズのほうが大きければ高さを、
76+ # 小さければ幅を出力サイズにそろえるよう
77+ # 拡大縮小比を求める
78+ # 幅/高さは整数(int)型データなので、適宜float型に変換する
79+ aspect_ratio = float(height) / width
80+ raw_aspect_ratio = raw_height / raw_width
81+ scale = tf.cond(tf.reduce_any(raw_aspect_ratio > [aspect_ratio]),
82+ lambda: [height] / raw_height,
83+ lambda: [width] / raw_width
84+ )
85+
86+ # 求めた比率で画像をリサイズする
87+ # resized_imageはfloat32型となる
88+ new_size = tf.to_int32(tf.concat([raw_height, raw_width], 0) * scale)
89+ resized_image = tf.image.resize_images(raw_image, new_size)
90+
91+ # 余白を0で埋めて指定したサイズにそろえる
92+ padded_image = tf.image.resize_image_with_crop_or_pad(
93+ resized_image,
94+ height,
95+ width
96+ )
97+
98+ return (filepath, padded_image)
99+
100+
101+if __name__ == "__main__":
102+ main()
103+
104+
--- /dev/null
+++ b/tfrecords/create_data.sh
@@ -0,0 +1,10 @@
1+#!/bin/sh
2+DATA_DIR=../data2
3+
4+LABEL=0
5+for i in cat dog monkey; do
6+ echo convert $i:$LABEL
7+ ./convert_tfr.py 100x100 $LABEL ${DATA_DIR}/teach_$i.tfrecord ${DATA_DIR}/$i/teach
8+ ./convert_tfr.py 100x100 $LABEL ${DATA_DIR}/test_$i.tfrecord ${DATA_DIR}/$i/test
9+ LABEL=$(expr $LABEL + 1)
10+done
--- /dev/null
+++ b/tfrecords/neural_learning.py
@@ -0,0 +1,181 @@
1+#!/usr/bin/env python
2+# -*- coding: utf-8 -*-
3+import sys
4+import tensorflow as tf
5+
6+INPUT_WIDTH = 100
7+INPUT_HEIGHT = 100
8+INPUT_CHANNELS = 3
9+
10+INPUT_SIZE = INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS
11+W1_SIZE = 200
12+OUTPUT_SIZE = 3
13+LABEL_SIZE = OUTPUT_SIZE
14+
15+TEACH_FILES = ["../data2/teach_cat.tfrecord",
16+ "../data2/teach_dog.tfrecord",
17+ "../data2/teach_monkey.tfrecord"]
18+TEST_FILES = ["../data2/test_cat.tfrecord",
19+ "../data2/test_dog.tfrecord",
20+ "../data2/test_monkey.tfrecord"]
21+
22+MODEL_NAME = "./neural_model"
23+
24+tf.set_random_seed(1111)
25+
26+# モデルを定義
27+with tf.variable_scope('model') as scope:
28+ x1 = tf.placeholder(dtype=tf.float32)
29+ y = tf.placeholder(dtype=tf.float32)
30+
31+ # 第2層
32+ W1 = tf.get_variable("W1",
33+ shape=[INPUT_SIZE, W1_SIZE],
34+ dtype=tf.float32,
35+ initializer=tf.random_normal_initializer(stddev=0.01))
36+ b1 = tf.get_variable("b1",
37+ shape=[W1_SIZE],
38+ dtype=tf.float32,
39+ initializer=tf.random_normal_initializer(stddev=0.01))
40+ x2 = tf.sigmoid(tf.matmul(x1, W1) + b1, name="x2")
41+
42+ # 第3層
43+ W2 = tf.get_variable("W2",
44+ shape=[W1_SIZE, OUTPUT_SIZE],
45+ dtype=tf.float32,
46+ initializer=tf.random_normal_initializer(stddev=0.01))
47+ b2 = tf.get_variable("b2",
48+ shape=[OUTPUT_SIZE],
49+ dtype=tf.float32,
50+ initializer=tf.random_normal_initializer(stddev=0.01))
51+ x3 = tf.nn.softmax(tf.matmul(x2, W2) + b2, name="x3")
52+
53+ # コスト関数
54+ cross_entropy = -tf.reduce_sum(y * tf.log(x3), name="cross_entropy")
55+ tf.summary.scalar('cross_entropy', cross_entropy)
56+
57+ # 正答率
58+ # 出力テンソルの中でもっとも値が大きいもののインデックスが
59+ # 正答と等しいかどうかを計算する
60+ correct = tf.equal(tf.argmax(x3,1), tf.argmax(y,1), name="correct")
61+ accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy")
62+ tf.summary.scalar('accuracy', accuracy)
63+
64+ # 最適化アルゴリズムを定義
65+ global_step = tf.Variable(0, name='global_step', trainable=False)
66+ optimizer = tf.train.GradientDescentOptimizer(1e-4, name="optimizer")
67+ minimize = optimizer.minimize(cross_entropy, global_step=global_step, name="minimize")
68+
69+ # 学習結果を保存するためのオブジェクトを用意
70+ saver = tf.train.Saver()
71+
72+# 読み込んだデータの変換用関数
73+def map_dataset(serialized):
74+ features = {
75+ 'label': tf.FixedLenFeature([], tf.int64),
76+ 'height': tf.FixedLenFeature([], tf.int64),
77+ 'width': tf.FixedLenFeature([], tf.int64),
78+ 'raw_image': tf.FixedLenFeature([INPUT_SIZE], tf.float32),
79+ }
80+ parsed = tf.parse_single_example(serialized, features)
81+
82+ # 読み込んだデータを変換する
83+ raw_label = tf.cast(parsed['label'], tf.int32)
84+ label = tf.reshape(tf.slice(tf.eye(LABEL_SIZE),
85+ [raw_label, 0],
86+ [1, LABEL_SIZE]),
87+ [LABEL_SIZE])
88+
89+ image = parsed['raw_image']
90+ return (image, label, raw_label)
91+
92+## データセットの読み込み
93+# 読み出すデータは各データ200件ずつ×3で計600件
94+dataset = tf.data.TFRecordDataset(TEACH_FILES)\
95+ .map(map_dataset)\
96+ .batch(600)
97+
98+# データにアクセスするためのイテレータを作成
99+iterator = dataset.make_one_shot_iterator()
100+item = iterator.get_next()
101+
102+# セッションの作成
103+sess = tf.Session()
104+
105+# 変数の初期化を実行する
106+sess.run(tf.global_variables_initializer())
107+
108+# 学習結果を保存したファイルが存在するかを確認し、
109+# 存在していればそれを読み出す
110+latest_filename = tf.train.latest_checkpoint("./")
111+if latest_filename:
112+ print("load saved model {}".format(latest_filename))
113+ saver.restore(sess, latest_filename)
114+
115+# サマリを取得するための処理
116+summary_op = tf.summary.merge_all()
117+summary_writer = tf.summary.FileWriter('data', graph=sess.graph)
118+
119+# 学習用データを読み出す
120+(dataset_x, dataset_y, values_y) = sess.run(item)
121+
122+
123+steps = tf.train.global_step(sess, global_step)
124+
125+if steps == 0:
126+ # 初期状態を記録
127+ xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op], {x1: dataset_x, y: dataset_y})
128+ print("CROSS ENTROPY({}): {}".format(0, xe))
129+ print(" ACCURACY({}): {}".format(0, acc))
130+ summary_writer.add_summary(summary, global_step=0)
131+
132+# 学習を開始
133+for i in range(10):
134+ for j in range(100):
135+ sess.run(minimize, {x1: dataset_x, y: dataset_y})
136+
137+ # 途中経過を取得・保存
138+ xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op], {x1: dataset_x, y: dataset_y})
139+ print("CROSS ENTROPY({}): {}".format(steps + 100 * (i+1), xe))
140+ print(" ACCURACY({}): {}".format(steps + 100 * (i+1), acc))
141+ summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))
142+
143+# 学習終了
144+# 結果を保存する
145+save_path = saver.save(sess, MODEL_NAME, global_step=tf.train.global_step(sess, global_step))
146+print("Model saved to {}".format(save_path))
147+
148+## 結果の出力
149+
150+# 学習に使用したデータを入力した場合の
151+# 正答率を計算する
152+print("----result with teaching data----")
153+
154+print("assumed label:")
155+print(sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x}))
156+print("real label:")
157+print(sess.run(tf.argmax(y, 1), feed_dict={y: dataset_y}))
158+print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y: dataset_y}))
159+
160+
161+# テスト用データを入力した場合の
162+# 正答率を計算する
163+print("----result with test data----")
164+
165+## テスト用データセットの読み込み
166+# テストデータは50×3=150件
167+dataset2 = tf.data.TFRecordDataset(TEST_FILES)\
168+ .map(map_dataset)\
169+ .batch(150)
170+iterator2 = dataset2.make_one_shot_iterator()
171+item2 = iterator2.get_next()
172+(dataset_x, dataset_y, values_y) = sess.run(item2)
173+
174+# 正答率を出力
175+print("assumed label:")
176+print(sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x}))
177+print("real label:")
178+print(sess.run(tf.argmax(y, 1), feed_dict={y: dataset_y}))
179+print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y: dataset_y}))
180+
181+