事先声明,本文章大部分内容来源于理解TensorFlow的Queue,并添加个人理解。
Queue相关的概念只有三个:
Queue
是TF队列和缓存机制的实现QueueRunner
是TF中对操作Queue的线程的封装Coordinator
是TF中用来协调线程运行的工具
虽然它们经常同时出现,但这三样东西在TensorFlow里面是可以单独使用的,不妨先分开来看待。
Queue
根据实现的方式不同,分成具体的几种类型,例如:
- tf.FIFOQueue 按入列顺序出列的队列
- tf.RandomShuffleQueue 随机顺序出列的队列
- tf.PaddingFIFOQueue 以固定长度批量出列的队列
- tf.PriorityQueue 带优先级出列的队列
- … …
这些类型的Queue除了自身的性质不太一样外,创建、使用的方法基本是相同的。
以FIFOQueue为例,创建函数的参数:
1
| tf.FIFOQueue(capacity, dtypes, shapes=None, names=None ...)
|
Queue主要包含入列(enqueue)和出列(dequeue)两个操作。enqueue(入队)操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义(或称为“声明”),需要放在Session中运行才能获得真正的数值。下面是一个单独使用Queue的例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| import tensorflow as tf tf.InteractiveSession()
q = tf.FIFOQueue(2, "float") init = q.enqueue_many(([0,0],))
x = q.dequeue() y = x+1 q_inc = q.enqueue([y])
init.run() q_inc.run() q_inc.run() q_inc.run() x.eval() x.eval() x.eval()
|
注意,如果一次性入列超过Queue Size的数据,enqueue操作会卡住,直到有数据(被其他线程)从队列取出。对一个已经取空的队列使用dequeue操作也会卡住,直到有新的数据(从其他线程)写入。
QueueRunner
Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程消费数据。QueueRunner就是来管理这些读写队列的线程的。
QueueRunner需要与Queue一起使用(这名字已经注定了它和Queue脱不开干系),但并不一定必须使用Coordinator。看下面这个例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import tensorflow as tf import sys q = tf.FIFOQueue(10, "float") counter = tf.Variable(0.0)
increment_op = tf.assign_add(counter, 1.0)
enqueue_op = q.enqueue(counter)
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)
sess = tf.InteractiveSession() tf.global_variables_initializer().run()
qr.create_threads(sess, start=True) for i in range(20): print (sess.run(q.dequeue()))
|
增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完20个数据后停止,但其他线程继续运行,程序不会结束。
Coordinator
Coordinator是个用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。例如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| import tensorflow as tf import threading, time
def loop(coord, id): t = 0 while not coord.should_stop(): print(id) time.sleep(1) t += 1 if (t >= 2 and id == 1): coord.request_stop()
coord = tf.train.Coordinator()
threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]
for t in threads: t.start() coord.join(threads)
|
将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop
方法,所有的线程都可以通过should_stop
方法感知并停止当前线程。
将QueueRunner和Coordinator一起使用,实际上就是封装了这个判断操作,从而使任何一个现成出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop
方法来停止所有子线程的执行。
案例
在TensorFlow中用Queue的经典模式有两种,都是配合了QueueRunner和Coordinator一起使用的。
第一种,显式的创建QueueRunner,然后调用它的create_threads
方法启动线程。例如下面这段代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import tensorflow as tf
data = 10 * np.random.randn(1000, 4) + 1
target = np.random.randint(0, 2, size=1000)
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])
enqueue_op = queue.enqueue_many([data, target])
data_sample, label_sample = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
with tf.Session() as sess: coord = tf.train.Coordinator() enqueue_threads = qr.create_threads(sess, coord=coord, start=True) for step in range(100): if coord.should_stop(): break data_batch, label_batch = sess.run([data_sample, label_sample]) coord.request_stop() coord.join(enqueue_threads)
|
第二种,使用全局的start_queue_runners
方法启动线程。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| import tensorflow as tf
filename_queue = tf.train.string_input_producer(["data1.csv","data2.csv"]) reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(filename_queue)
with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for _ in range(100): features, labels = sess.run([data_batch, label_batch]) coord.request_stop() coord.join(threads)
|
在这个例子中,tf.train.string_input_produecer
会将一个隐含的QueueRunner添加到全局图中(类似的操作还有tf.train.shuffle_batch
、tf.train.slice_input_producer
等)。
由于没有显式地返回QueueRunner来用create_threads启动线程,这里使用了tf.train.start_queue_runners
方法直接启动tf.GraphKeys.QUEUE_RUNNERS
集合中的所有队列线程。
这两种方式在效果上是等效的。
参考
理解TensorFlow的Queue