
Build Tensorflow Dataset Iterator That Produces Batches With Special Structure

As mentioned in the title, I need batches with a special structure: 1111 5555 2222. Each digit represents a feature vector, so each batch contains N=4 vectors from each of the classes {1, 2, 5} (M=3), giving a batch size of N * M = 12.

Solution 1:

If you have the list of files ordered by class, you can interleave the datasets:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)

dataset = tf.data.Dataset.from_tensor_slices(record_files)
# Consider tf.contrib.data.parallel_interleave for parallelization
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=M, block_length=N)
# Consider passing num_parallel_calls or using tf.contrib.data.map_and_batch for performance
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)
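To see why this yields the desired layout, here is a minimal pure-Python sketch (no TensorFlow) of the round-robin order that interleave produces; for simplicity it assumes cycle_length equals the number of sources, as in the snippet above:

```python
from itertools import islice

def interleave(sources, block_length):
    """Round-robin over `sources`, yielding `block_length` items
    from each source before moving on to the next one."""
    iterators = [iter(s) for s in sources]
    while iterators:
        for it in list(iterators):
            block = list(islice(it, block_length))
            if not block:
                iterators.remove(it)  # this source is exhausted
                continue
            yield from block

# Three "classes", each a stream of its own label
sources = [['1'] * 8, ['5'] * 8, ['2'] * 8]
order = ''.join(interleave(sources, block_length=4))
print(order[:12])  # first batch of N * M = 12 elements: '111155552222'
```

Batching this stream in groups of N * M then gives exactly one block of N elements per class in every batch.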

EDIT:

If you also need shuffling, you can add it in the interleaving step:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda record_file: tf.data.TFRecordDataset(record_file).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=M, block_length=N)
dataset = dataset.map(parse_function)
dataset = dataset.batch(N * M)

NOTE: Both interleave and batch will produce "partial" outputs if there are not enough remaining elements (see the docs), so you have to take special care if every batch must have the same shape and structure. For batching you can use tf.contrib.data.batch_and_drop_remainder, but as far as I know there is no similar alternative for interleaving, so you would either have to make sure that all of your files have the same number of examples or add repeat to the interleaving transformation.
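As a plain-Python illustration of the remainder issue (a sketch, not the TF API): batching 10 elements in groups of 4 leaves a short final batch, which drop-remainder-style batching discards:

```python
def batch(items, size, drop_remainder=False):
    """Group `items` into lists of length `size`; optionally drop a short tail."""
    batches = [items[i:i + size] for i in range(0, len(items), size)]
    if drop_remainder and batches and len(batches[-1]) < size:
        batches.pop()
    return batches

elems = list(range(10))
print(batch(elems, 4))                       # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch(elems, 4, drop_remainder=True))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```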

EDIT 2:

Here is a proof of concept of what I think you want:

import tensorflow as tf

NUM_EXAMPLES = 12
NUM_CLASSES = 9
records = [[str(i)] * NUM_EXAMPLES for i in range(NUM_CLASSES)]
M = 3
N = 4

dataset = tf.data.Dataset.from_tensor_slices(records)
# Round-robin over all classes, taking N consecutive elements from each
dataset = dataset.interleave(tf.data.Dataset.from_tensor_slices,
                             cycle_length=NUM_CLASSES, block_length=N)
# Collect one full block of N elements per class, dropping any partial block
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
# Shuffle the class rows and split them into groups of M classes each
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N))), NUM_CLASSES // M)))
# Flatten each group of M classes into a single batch of M * N elements
dataset = dataset.map(lambda data: tf.reshape(data, (M * N,)))
batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            b = sess.run(batch)
            print(b''.join(b).decode())
        except tf.errors.OutOfRangeError:
            break

Output:

888866663333
555544447777
222200001111
222288887777
666655553333
000044441111
888822225555
666600004444
777733331111

The equivalent with record files would look like this (assuming each record is a one-dimensional vector):

import tensorflow as tf

NUM_CLASSES = 9
record_files = ['class{}.tfrecord'.format(i) for i in range(NUM_CLASSES)]
M = 3
N = 4
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda file_name: tf.data.TFRecordDataset(file_name).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N, -1))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N, -1)))

This works by reading N elements of every class each time, then shuffling and splitting the resulting block. It assumes that the number of classes is divisible by M and that all of the files contain the same number of records.
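The reshape → shuffle → split step can be sketched in plain Python, mirroring the tensor ops above (assuming one scalar class label per element):

```python
import random

NUM_CLASSES, M, N = 9, 3, 4
# One block as produced upstream: N consecutive elements per class
block = [c for c in range(NUM_CLASSES) for _ in range(N)]

# reshape to (NUM_CLASSES, N): one row per class
rows = [block[i * N:(i + 1) * N] for i in range(NUM_CLASSES)]
random.shuffle(rows)  # like tf.random_shuffle, permutes the class rows
# split into NUM_CLASSES // M groups of M rows, flatten each to M * N elements
batches = [sum(rows[g * M:(g + 1) * M], []) for g in range(NUM_CLASSES // M)]

for b in batches:
    # each batch holds N examples from each of M distinct classes
    assert len(b) == M * N and len(set(b)) == M
```

Because the shuffle permutes whole rows, every batch keeps the "N of each of M classes" structure, and the batches together cover all NUM_CLASSES classes exactly once per block.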
