TensorFlow is one of the most popular machine learning libraries. It's one of the few frameworks that lets you work at both the lowest and highest levels of abstraction in machine learning. You can do useful ML with high-level APIs that handle everything for you, or you can build a custom model from the ground up, controlling everything from the model architecture to the training loop.
You can find and run a complete working example of the code snippets on Kaggle.
What is TensorFlow Estimator?
Similar to Keras's Model API, TensorFlow Estimator is a high-level API that encapsulates training, evaluation, prediction, and serving. In other words, you don't have to write training, evaluation, or prediction loops; the estimator does it all for you. It also gives you a reusable model that you can save to or load from anywhere.
Why use the Estimator API?
Estimator is an easy-to-use API that handles a lot of complexity for you but still provides enough customization opportunities. You can design your model once and it will run seamlessly on your local machine, with or without GPUs, as well as on a multi-server distributed environment with GPUs and TPUs.
It safely runs a distributed training loop that takes care of loading data, saving checkpoints, and writing summaries for TensorBoard.
How does it work?
The Estimator API allows you to use pre-made as well as custom estimators. I'll go through working with pre-made estimators here and leave custom estimators for another time.
You can make any pre-made estimator work by going through the following steps.
a. Dataset input functions
These functions supply data to the estimator: each one returns a tf.data.Dataset that yields pairs of a feature dictionary and labels. A feature dictionary maps each feature name to that feature's values in the dataset.
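To make the term concrete, here is a minimal, framework-free sketch of splitting tabular rows into a feature dictionary and a list of labels. The sample rows and the helper split_features_and_labels are made up for illustration; only the column names come from the example in this post.

```python
# Hypothetical sample rows; the values are invented for illustration only.
rows = [
    {"Total_Bags": 100.0, "type": "conventional", "AveragePrice": 1.0},
    {"Total_Bags": 250.0, "type": "organic", "AveragePrice": 1.5},
    {"Total_Bags": 175.0, "type": "conventional", "AveragePrice": 1.25},
]


def split_features_and_labels(rows, label_name):
    """Turn a list of row dicts into (feature dictionary, labels)."""
    labels = [row[label_name] for row in rows]
    feature_names = [name for name in rows[0] if name != label_name]
    # Feature dictionary: feature name -> that feature's values, one per row.
    features = {name: [row[name] for row in rows] for name in feature_names}
    return features, labels


features, labels = split_features_and_labels(rows, "AveragePrice")
print(features)
print(labels)
```

In the real input function, dict(df) and df.pop('AveragePrice') play the same two roles for a pandas dataframe.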
    import tensorflow as tf

    def make_input_fn(df, epochs=500, shuffle=True, batch_size=32):
        df = df.copy()  # work on a copy so the caller's dataframe is untouched
        labels = df.pop('AveragePrice')  # extract label from dataframe

        def input_function():
            # create dataset from in-memory pandas dataframe
            dataset = tf.data.Dataset.from_tensor_slices((dict(df), labels))
            if shuffle:
                dataset = dataset.shuffle(buffer_size=len(df))
            dataset = dataset.batch(batch_size).repeat(epochs)
            return dataset

        return input_function

    train_input_fn = make_input_fn(train_df)
    eval_input_fn = make_input_fn(eval_df, epochs=1, shuffle=False)

In the example above, make_input_fn is an input-function creator: it returns input_function, which is the actual dataset input function.
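The batch(batch_size).repeat(epochs) call determines how many batches the input function yields. The sketch below imitates that behavior with plain Python lists rather than tf.data, so it is only a model of the idea, not the library's implementation. The helper batch_then_repeat is hypothetical.

```python
import math


def batch_then_repeat(items, batch_size, epochs):
    """Imitate dataset.batch(batch_size).repeat(epochs) on a plain list."""
    # Batch first: the last batch may be smaller than batch_size.
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    # Then replay the whole sequence of batches once per epoch.
    for _ in range(epochs):
        for batch in batches:
            yield batch


examples = list(range(10))  # stand-in for 10 rows of a dataframe
batches = list(batch_then_repeat(examples, batch_size=4, epochs=2))

# Each epoch yields ceil(10 / 4) batches, and the sequence repeats per epoch.
batches_per_epoch = math.ceil(len(examples) / 4)
print(batches_per_epoch, len(batches))
```

With the defaults in make_input_fn (epochs=500, batch_size=32), training therefore runs for roughly 500 * ceil(len(train_df) / 32) steps, unless a steps or max_steps limit is passed to train.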
b. Define feature columns
In the second step, we create a list of all the feature columns we want to train our model on. Each feature is defined as a tf.feature_column, which requires a feature name and optionally takes a datatype and a preprocessing function.
    DENSE_COLUMNS = ['Total_Bags', 'Small_Bags', 'Large_Bags', 'XLarge_Bags']
    SPARSE_COLUMNS = ['type', 'year', 'region']

    feature_columns = []

    # dense features are used as plain numeric values
    for feature in DENSE_COLUMNS:
        feature_columns.append(tf.feature_column.numeric_column(feature))

    # sparse features are mapped to a vocabulary and one-hot encoded
    # (data is assumed to be the full dataframe, loaded earlier)
    for feature in SPARSE_COLUMNS:
        vocab = data[feature].unique()
        categorical_feature = tf.feature_column.categorical_column_with_vocabulary_list(feature, vocab)
        feature_columns.append(tf.feature_column.indicator_column(categorical_feature))

In this example I have a dataset with some dense (continuous) and some sparse (categorical) features. For each feature, a tf.feature_column of the matching type is created and added to the feature_columns list. Our model will then use this list to process actual data from the dataset.
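An indicator column over a vocabulary list is, in effect, a one-hot encoding of a categorical value. The following sketch shows the idea in plain Python without TensorFlow. The helper one_hot and the sample vocabulary are assumptions for illustration, not the library's implementation.

```python
def one_hot(value, vocab):
    """Return a one-hot list for value, given an ordered vocabulary."""
    return [1.0 if value == entry else 0.0 for entry in vocab]


# Hypothetical vocabulary, similar to what data['type'].unique() might hold.
vocab = ["conventional", "organic"]

encoded = one_hot("organic", vocab)
print(encoded)

# In this sketch, a value outside the vocabulary encodes to all zeros.
unknown = one_hot("unknown", vocab)
print(unknown)
```

Each categorical feature therefore contributes as many model inputs as its vocabulary has entries, which is worth keeping in mind for high-cardinality columns such as region.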
c. Instantiate the estimator and train
There are many pre-made estimators available. You just pass feature columns to one, and you're good to go. If you want to tune hyperparameters, you can pass those as arguments as well. To start training, call train on your estimator and pass your dataset input function as an argument.
    linear_regressor = tf.estimator.LinearRegressor(
        feature_columns=feature_columns,
        model_dir=output_dir
    )
    linear_regressor.train(train_input_fn)

That's all! Your model will start training and keep saving model checkpoints and summaries in output_dir. You can also point TensorBoard to this directory to monitor training.
TensorFlow comes packaged with TensorBoard, an awesome tool for visualizing different model metrics.