## Building a neural network training framework with the learn API
For simplicity, most of the examples here create sessions manually and ignore saving and loading checkpoints, but this is not how things are usually done in practice. You will most likely want to use the learn API to take care of session management and logging. We provide a simple but practical [framework](https://github.com/vahidk/TensorflowFramework/tree/master) for training neural networks using TensorFlow. In this item we explain how this framework works.
When experimenting with neural network models you usually have a training/test split. You want to train your model on the training set, and once in a while evaluate it on the test set and compute some metrics. You also need to store the model parameters as a checkpoint, and ideally you want to be able to stop and resume training. TensorFlow's learn API is designed to make this job easier, letting us focus on developing the actual model.
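To make that workflow concrete, here is a minimal plain-Python sketch of the loop the learn API automates: train, periodically evaluate, write a checkpoint, and resume from it on the next run. It does not use TensorFlow at all; the function name `train_with_checkpoints`, the JSON checkpoint file, and the toy "model" (a single weight pulled toward 1.0) are assumptions made purely for illustration.

```python
import json
import os
import tempfile


def train_with_checkpoints(model_dir, max_steps, eval_every=5):
    """Toy train/eval loop whose state is checkpointed as JSON (illustrative only)."""
    ckpt_path = os.path.join(model_dir, "checkpoint.json")

    # Resume from the last checkpoint if one exists, otherwise start fresh.
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)
    else:
        state = {"step": 0, "weight": 0.0}

    eval_log = []
    while state["step"] < max_steps:
        # One "training step": move the weight halfway toward the target 1.0.
        state["weight"] += 0.5 * (1.0 - state["weight"])
        state["step"] += 1

        if state["step"] % eval_every == 0:
            # Periodic "evaluation" (distance to the target) plus a checkpoint.
            eval_log.append((state["step"], abs(1.0 - state["weight"])))
            with open(ckpt_path, "w") as f:
                json.dump(state, f)

    # Always save the final state so the next run can pick up from here.
    with open(ckpt_path, "w") as f:
        json.dump(state, f)
    return state, eval_log


model_dir = tempfile.mkdtemp()
state, _ = train_with_checkpoints(model_dir, max_steps=10)    # first run
state, log = train_with_checkpoints(model_dir, max_steps=20)  # resumes at step 10
```

The second call finds the checkpoint written by the first and continues from step 10 rather than from scratch, which is the stop-and-resume behavior an estimator provides when it is pointed at the same model directory.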
The most basic way of using the learn API is to use the `tf.estimator.Estimator` object directly. You need to define a model function that defines a loss function, a train op, one or a set of predictions, and optionally a set of metric ops for evaluation:
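```python
import tensorflow as tf

def model_fn(features, labels, mode, params):
    predictions = ...
    loss = ...
    train_op = ...
    metric_ops = ...
    # The model function must return an EstimatorSpec bundling these pieces.
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss,
        train_op=train_op, eval_metric_ops=metric_ops)

params = ...
# Checkpoints and logs are written to model_dir.
run_config = tf.estimator.RunConfig(model_dir=FLAGS.output_dir)
estimator = tf.estimator.Estimator(
    model_fn=model_fn, config=run_config, params=params)
```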
To train the model, call the estimator's `train()` method with an input function that reads the data, e.g. `estimator.train(input_fn=input_fn)`. To evaluate the model, call `evaluate()` in the same way, e.g. `estimator.evaluate(input_fn=input_fn)`.
The input function returns two tensors (or dictionaries of tensors) providing the features and labels to be passed to the model:
```python
def input_fn():
    features = ...
    labels = ...
    return features, labels
```
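Putting the pieces together, the following is a minimal end-to-end sketch rather than the framework's actual code. It assumes the TF 1.x-style graph API (reached here through `tf.compat.v1`, in a TensorFlow release that still ships `tf.estimator`), uses random toy data, and stands in a single dense layer for the model. The `params` keys `num_classes` and `learning_rate` are hypothetical names chosen for this example. The input function returns a `tf.data.Dataset` of `(features, labels)` pairs, which the estimator also accepts.

```python
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API (assumption: still available)


def input_fn():
    # Hypothetical toy data: 64 random 4-d points, label 1 if the sum is positive.
    rng = np.random.RandomState(0)
    x = rng.randn(64, 4).astype(np.float32)
    y = (x.sum(axis=1) > 0).astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(16)


def model_fn(features, labels, mode, params):
    # A single dense layer stands in for a real model.
    logits = tf1.layers.dense(features, params["num_classes"])
    predictions = tf.argmax(logits, axis=1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels are available in predict mode, so return predictions only.
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf1.train.AdamOptimizer(params["learning_rate"]).minimize(
        loss, global_step=tf1.train.get_global_step())
    metric_ops = {"accuracy": tf1.metrics.accuracy(labels, predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss,
        train_op=train_op, eval_metric_ops=metric_ops)


# Checkpoints and event files go to a temporary directory in this sketch.
run_config = tf.estimator.RunConfig(model_dir=tempfile.mkdtemp())
estimator = tf.estimator.Estimator(
    model_fn=model_fn, config=run_config,
    params={"num_classes": 2, "learning_rate": 0.01})

estimator.train(input_fn=input_fn, steps=20)
# The dataset repeats forever, so evaluation needs an explicit step count.
metrics = estimator.evaluate(input_fn=input_fn, steps=4)
```

`metrics` is a dictionary holding the loss, the global step, and every entry of `eval_metric_ops`. Calling `train()` again with the same `model_dir` would resume from the latest checkpoint instead of starting over.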
See [mnist.py](https://github.com/vahidk/TensorflowFramework/blob/master/dataset/mnist.py) for an example of how to read your data with the dataset API. To learn about various ways of reading your data in TensorFlow refer to [this item](#data).
The framework also includes an example model, a simple convolutional network classifier, in [alexnet.py](https://github.com/vahidk/TensorflowFramework/blob/master/model/alexnet.py).
And that's it! This is all you need to get started with the TensorFlow learn API. I recommend having a look at the framework [source code](https://github.com/vahidk/TensorFlowFramework) and the official Python API documentation to learn more about the learn API.