How to Properly Combine TensorFlow's Dataset API and Keras?

Keras' fit_generator() model method expects a generator which produces tuples of the shape (input, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)

Here's the error I get:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/", line 916, in _bootstrap_inner
  File "/home/jsaporta/anaconda3/lib/python3.6/", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./", line 25, in create_data_generator
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 895, in run
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/", line 2011, in fit_generator
    generator_output = next(output_generator)

Strangely enough, adding a line containing next(datagen) directly after where I initialize datagen causes the code to run just fine, with no errors.

Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?
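On the second question, one mechanism visible in the traceback is Python's lazy generator execution: calling create_data_generator() runs none of its body, so the Dataset and iterator are only built when Keras first calls next() on it, and that happens in a background thread (Thread-1, inside data_generator_task). Adding next(datagen) in the main script forces that setup to run in the main thread instead. The plain-Python sketch below shows only this ordering effect; it uses no TensorFlow, does not model graph handling, and make_gen and first_run_thread are hypothetical names.

```python
import threading

def make_gen(log):
    # This line runs on the first next() call, not when make_gen() is called.
    log.append(threading.current_thread().name)
    while True:
        yield "batch"

def first_run_thread(prime_in_main_thread):
    """Return the name of the thread in which the generator body first ran."""
    log = []
    gen = make_gen(log)
    if prime_in_main_thread:
        next(gen)  # the equivalent of adding next(datagen)
    # Simulate Keras pulling batches from the generator in a worker thread.
    worker = threading.Thread(target=lambda: next(gen), name="Thread-1")
    worker.start()
    worker.join()
    return log[0]

print(first_run_thread(prime_in_main_thread=False))  # Thread-1
print(first_run_thread(prime_in_main_thread=True))   # MainThread
```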


There is indeed a more efficient way to use Dataset without having to convert the tensors into NumPy arrays. However, it is not (yet?) in the official documentation. According to the release notes, it's a feature introduced in Keras 2.0.7, so you may have to install keras>=2.0.7 in order to use it.

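import numpy as np
from keras.layers import Dense, Input
from keras.models import Model
from tensorflow.contrib.data import Dataset

x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

# The iterator outputs are wired directly into the model as input and target
input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
# No x/y arguments: Keras reads the data from the tensors above
model.fit(steps_per_epoch=1, epochs=5, verbose=2)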

Several differences:

  1. Supply the tensor argument to the Input layer. Keras will read values from this tensor, and use it as the input to fit the model.
  2. Supply the target_tensors argument to Model.compile().
  3. Remember to convert both x and y into float32. Under normal usage, Keras will do this conversion for you. But now you'll have to do it yourself.
  4. Batch size is specified during the construction of Dataset. Use steps_per_epoch and epochs to control when to stop model fitting.
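Point 4 can be pictured without TensorFlow: because the dataset is repeated indefinitely, it never signals the end of an epoch, so Keras needs steps_per_epoch to know how many batches make up one. The sketch below is a plain-Python analogue (for illustration only, not TensorFlow code) of applying .repeat() and then .batch(n); repeat_then_batch is a hypothetical helper.

```python
from itertools import cycle, islice

def repeat_then_batch(elements, batch_size):
    """Plain-Python analogue of Dataset.from_tensor_slices(...).repeat().batch(n)."""
    stream = cycle(elements)  # repeat() with no count: loops forever
    while True:
        # batch(n) groups consecutive elements, so a batch may span two passes
        yield list(islice(stream, batch_size))

batches = repeat_then_batch([0, 1, 2, 3, 4, 5], batch_size=4)
steps_per_epoch = 2  # without this, an infinite stream would never end an epoch
for step in range(steps_per_epoch):
    print(next(batches))
# [0, 1, 2, 3]
# [4, 5, 0, 1]
```

With the four-element data used above and batch size 4, every step yields the same full batch.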

In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) if your data are to be read from tensors.

Update June 09, 2018
  • Starting from TensorFlow 1.9, one can pass a tf.data.Dataset object directly into keras.Model.fit(), and it acts similarly to fit_generator.
  • A complete example can be found on this gist.
import tensorflow as tf

# Load MNIST training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train, is_training=True)

model = ...  # your compiled tf.keras model here
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    verbose=1)
  • tfdata_generator is a function that returns a shuffled, batched, repeating tf.data.Dataset:
import tensorflow as tf

_NUM_CLASSES = 10  # MNIST has 10 digit classes

def tfdata_generator(images, labels, is_training, batch_size=128):
  '''Construct a data generator using `tf.data.Dataset`. '''

  def map_fn(image, label):
    '''Preprocess raw data to trainable input. '''
    x = tf.reshape(tf.cast(image, tf.float32), (28, 28, 1))
    y = tf.one_hot(tf.cast(label, tf.uint8), _NUM_CLASSES)
    return x, y

  dataset = tf.data.Dataset.from_tensor_slices((images, labels))

  if is_training:
    dataset = dataset.shuffle(1000)  # depends on sample size
  dataset = dataset.map(map_fn)
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat()
  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

  return dataset
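To see what map_fn produces for each example, here is a NumPy mirror of it. It is an illustration only: map_fn_numpy is a hypothetical name, NUM_CLASSES = 10 stands in for the answer's _NUM_CLASSES (10 digit classes for MNIST), and the last line reproduces the steps_per_epoch=len(x_train) // 128 arithmetic for MNIST's 60,000 training images.

```python
import numpy as np

NUM_CLASSES = 10  # assumption: MNIST digits 0-9

def map_fn_numpy(image, label):
    """NumPy mirror of map_fn: float32 (28, 28, 1) image and one-hot float32 label."""
    x = np.asarray(image, dtype=np.float32).reshape(28, 28, 1)
    y = np.eye(NUM_CLASSES, dtype=np.float32)[int(label)]  # tf.one_hot defaults to float32
    return x, y

image = np.zeros((28, 28), dtype=np.uint8)  # dummy stand-in for one MNIST image
x, y = map_fn_numpy(image, 3)
print(x.shape, x.dtype)  # (28, 28, 1) float32
print(y)                 # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

# steps_per_epoch for 60,000 training images and batch size 128
print(60000 // 128)      # 468
```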
Old Solution:

In addition to @Yu-Yang's answer, you can also turn a tf.data.Dataset into a generator for fit_generator as follows:

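import tensorflow as tf
from keras import backend as K
from tensorflow.contrib.learn.python.learn.datasets import mnist

image_shape = (28, 28, 1)  # assumed target shape for MNIST images
num_classes = 10           # MNIST digit classes

def tfdata_generator(images, labels, batch_size=128, shuffle=True):
    def map_func(image, label):
        '''A transformation function'''
        x_train = tf.reshape(tf.cast(image, tf.float32), image_shape)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        return x_train, y_train

    dataset  = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset  = dataset.map(map_func)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)  # buffer size is an assumption
    dataset  = dataset.batch(batch_size).repeat()
    iterator = dataset.make_one_shot_iterator()

    next_batch = iterator.get_next()
    while True:
        # Evaluate the next batch in Keras's session and hand NumPy arrays to Keras
        yield K.get_session().run(next_batch)

data   = mnist.load_mnist()
model  = ...  # your compiled Keras model
model.fit_generator(generator = tfdata_generator(data.train.images, data.train.labels),
                    steps_per_epoch = len(data.train.images) // 128,  # required for plain generators
                    workers = 0 , # This is important
                    verbose = 1)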

@Yu-Yang's and @Dat-Nguyen's solutions both work fine. It is also possible to make @Yu-Yang's solution support a validation set during training by using feedable iterators and passing the validation set's handle as the validation "data". It's a bit convoluted, but it works.

You can also convert the Keras model to an Estimator; Estimators support datasets:

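import tensorflow as tf

# `model` is a compiled tf.keras model; `train_set` and `test_set` are
# (batched) tf.data.Datasets of (features, label) pairs
estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
input_name = model.input_names[0]  # features must be keyed by the model's input name

def input_fn(dataset):
    dataset = dataset.map(lambda X, y: ({input_name: X}, y))
    return dataset.make_one_shot_iterator().get_next()

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn(train_set), max_steps=100)
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn(test_set))

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)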

The other answers are good; however, it is important to note that using from_tensor_slices directly with large NumPy arrays can quickly fill up your memory because, IIRC, the values are copied into the graph as tf.constants. In my experience, this causes a silent failure: training eventually starts, but no improvement in loss etc. occurs.
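A rough way to judge whether an array is at risk is to compute its raw size before embedding it. The TensorFlow 1.x guide on importing data gives the same warning about embedding arrays as tf.constant ops and also mentions the 2 GB limit on a serialized tf.GraphDef protocol buffer. The sketch below only does that arithmetic for one copy of the data; constant_bytes is a hypothetical helper, and the second shape is a made-up example.

```python
import numpy as np

GRAPHDEF_LIMIT = 2**31  # roughly the 2 GB protocol-buffer limit for a tf.GraphDef

def constant_bytes(shape, dtype):
    """Bytes needed to store one copy of an array of this shape and dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

examples = {
    "MNIST train (60000, 28, 28, 1) float32": ((60000, 28, 28, 1), "float32"),
    "hypothetical (100000, 128, 128, 3) float32": ((100000, 128, 128, 3), "float32"),
}
for name, (shape, dtype) in examples.items():
    n = constant_bytes(shape, dtype)
    print(f"{name}: {n / 1024**2:,.0f} MiB, over limit: {n > GRAPHDEF_LIMIT}")
```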

A better way is to use placeholders. For example, here is my code to create a generator for images and their one-hot targets:

import tensorflow as tf
from keras import backend as K

def create_generator_tf_dataset(images, onehots, batch_size, map_fn=None,
                                shuffle_buffer=1000):
    # Get shapes (batch dimension left as None)
    img_size = images.shape
    img_size = (None, img_size[1], img_size[2], img_size[3])
    onehot_size = onehots.shape
    onehot_size = (None, onehot_size[1])

    # Placeholders: the arrays are fed when the iterator is initialized
    # instead of being embedded in the graph as constants
    images_tensor = tf.placeholder(tf.float32, shape=img_size)
    onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)

    # Dataset
    dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))
    # Map function (e.g. augmentation)
    if map_fn is not None:
        dataset = dataset.map(lambda x, y: (map_fn(x), y))
    # Combined shuffle and infinite repeat (shuffle_buffer is an assumed default)
    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(shuffle_buffer, None))
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)

    # Make the iterator
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
        while True:
            inputs, labels = sess.run(next_val)
            yield inputs, labels

Here is a solution if you are creating a TensorFlow Dataset using the Pandas library. Note that this code will not work without tf.reshape(), because the tensors coming out of tf.py_func() carry no static shape information. So this doesn't work with tuples. Does anybody have a workaround?

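import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input

file_names = ['data_1.csv', 'data_2.csv']  # hypothetical list of CSV files

def _get_input_data_for_dataset(file_name):
    # py_func passes the file name as bytes; each CSV is assumed to hold one numeric column
    df_input = pd.read_csv(file_name.decode())
    X_data = df_input.values
    return X_data.astype('float32', copy=False)

X_dataset = tf.data.Dataset.from_tensor_slices(file_names)
# py_func output has no static shape, so reshape to (rows, 1) before slicing into rows
X_dataset = X_dataset.flat_map(lambda file_name: tf.data.Dataset.from_tensor_slices(
    tf.reshape(tf.py_func(_get_input_data_for_dataset, [file_name], tf.float32), [-1, 1])))

X_dataset = X_dataset.batch(5)
X_iter = X_dataset.make_one_shot_iterator()
X_batch = X_iter.get_next()
input_X1 = Input(tensor=X_batch, name='input_X1')

y1 = Dense(units=64, activation='relu', kernel_initializer=tf.keras.initializers.Constant(1), name='layer_FC1')(input_X1)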
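The flat_map step can be pictured in plain NumPy: each file becomes an (n, 1) float32 array (the shape that tf.reshape(..., [-1, 1]) restores), and the per-file rows are chained into one stream before batch(5). This sketch uses made-up in-memory values instead of CSV files; files and rows_from_file are hypothetical names, and it mirrors only the data flow, not TensorFlow's shape inference.

```python
from itertools import chain

import numpy as np

def rows_from_file(values):
    """Stand-in for the py_func step: one file's column as an (n, 1) float32 array."""
    return np.asarray(values, dtype=np.float32).reshape(-1, 1)

# Hypothetical per-file contents (in the answer these are read from CSV files)
files = {"a.csv": [1.0, 2.0, 3.0], "b.csv": [4.0, 5.0]}

# flat_map analogue: map each file to its rows, then flatten into one stream of rows
rows = list(chain.from_iterable(rows_from_file(v) for v in files.values()))

# batch(5) analogue: stack consecutive (1,)-shaped rows into (<=5, 1) arrays
batches = [np.stack(rows[i:i + 5]) for i in range(0, len(rows), 5)]
print(len(rows), batches[0].shape)  # 5 (5, 1)
print(batches[0].ravel())           # [1. 2. 3. 4. 5.]
```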

One important observation from my recent experience: use tf.keras instead of standalone Keras. This works for me with TF > 1.12.

Hope it can help others too.
