
Running Inference With BERT Using TensorFlow Serving

BERT is a powerful natural language processing tool with a wide range of capabilities, but the size and complexity of its architecture make it challenging to deploy in a production environment. One practical way to meet the memory and latency constraints of production is to serve it with TensorFlow Serving.

What is BERT, anyway?

When Google released BERT, it kicked off a frenzy in the natural language processing (NLP) space. BERT — short for Bidirectional Encoder Representations from Transformers — is a breakthrough NLP tool that can handle a wide range of tasks, including named entity recognition, sentiment analysis, and classification.

BERT made it possible for a neural network to understand the intricacies of language through a simple strategy known as word masking. Using this approach, words in sentences were randomly masked and the model was asked to predict what each masked word was.

During this pre-training process, the model learned how words combine to form sentences and how grammatical rules operate within a language. (Next sentence prediction served as a second pre-training objective alongside masking.) Once pre-trained, the model could be fine-tuned for a wide range of downstream tasks, such as natural language inference.

Using BERT with TensorFlow Serving

These days, developers have many options for putting machine learning models into production, where memory efficiency and low latency matter. Perhaps the most popular is TensorFlow Serving, a high-performance serving system written in C++, which makes it well suited to production environments.

In this post, we’ll look at how BERT can be used as a language model that outputs word-level probabilities for any sentence. We’ll see how to wrap it in TensorFlow Serving so that it’s ready for production. And we’ll look at how to serve TensorFlow models that are written using TensorFlow’s high-level Estimator class.

Using BERT as a language model

BERT is a masked language model, or MLM, meaning that it was trained by masking words and attempting to predict them. That makes it challenging to use as a conventional language model, since it needs the words both before and after the masked word to generate a prediction. By contrast, a sequential language model that predicts the next word in a sequence only needs the words that come before the target word.

Let’s take a look at how the authors of the base repository, bert-as-language-model, get around this.

To predict how probable a given word is in a sentence, the authors use a repeated mask-and-predict technique. For any input sentence, they generate as many copies as the sentence has words, so a 14-word sentence yields 14 copies.

With each copy, a different word is masked. In the first copy, just the first word would be masked. In the second copy, just the second word would be masked. And so on. All of these sentences are then submitted to the model to generate the probability of the masked word.
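The copy-and-mask step can be illustrated with a short sketch. This is not the repository’s code: the function name make_masked_copies is hypothetical, and the whitespace split stands in for BERT’s WordPiece tokenizer, which would normally break some words into several tokens.

```python
from typing import List

MASK_TOKEN = "[MASK]"  # BERT's mask token


def make_masked_copies(sentence: str) -> List[List[str]]:
    """Return one copy of the sentence per word, each with a different word masked.

    Assumption: words are separated by whitespace (a stand-in for WordPiece).
    """
    tokens = sentence.split()
    copies = []
    for position in range(len(tokens)):
        masked = list(tokens)          # copy so the original tokens stay intact
        masked[position] = MASK_TOKEN  # hide exactly one word in this copy
        copies.append(masked)
    return copies


copies = make_masked_copies("the cat sat")
for copy in copies:
    print(" ".join(copy))
```

Each copy then becomes one row in the batch sent to the model, so a sentence of N words costs N forward passes (or one batch of size N).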

Google’s pretrained BERT model doesn’t function as a language model in the way we just described. To make it work, we need to add a layer on top of the final layer in the encoder stack of the architecture. This layer transforms the output from (batch_size, max_seq_length, hidden_units) to (batch_size, max_seq_length, vocab_size). To get the probability scores for each word, we run the softmax function over this transformation using the get_masked_predictions() function in the file.
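To make the shape change concrete, here is a minimal sketch of what such an output layer does for a single position: it maps one hidden vector of size hidden_units to vocab_size logits, then applies softmax. It is written in plain Python with toy sizes and is not the repository’s get_masked_predictions() function; the names project_to_vocab, output_weights, and output_bias are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one vector of vocabulary logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def project_to_vocab(hidden: List[float],
                     weights: List[List[float]],
                     bias: List[float]) -> List[float]:
    """Map one hidden vector (hidden_units) to logits (vocab_size)."""
    return [
        sum(h * w for h, w in zip(hidden, row)) + b
        for row, b in zip(weights, bias)
    ]


# Toy sizes (assumed for illustration): hidden_units = 2, vocab_size = 3.
hidden_vector = [1.0, -1.0]
output_weights = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]  # one row per vocabulary entry
output_bias = [0.0, 0.0, 0.0]

logits = project_to_vocab(hidden_vector, output_weights, output_bias)
probabilities = softmax(logits)  # one probability per vocabulary entry
print(probabilities)
```

In the real model the same projection is applied to every position in the batch, which is what turns (batch_size, max_seq_length, hidden_units) into (batch_size, max_seq_length, vocab_size).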

We won’t get too deep into the codebase here, but it’s important to identify the inputs that are used for serving the model:

input_ids: Words in the sentence transformed into their dictionary indices
input_mask: List of 1s and 0s after padding to max_seq_length
  • 1 denotes that the word exists
  • 0 signifies that it’s a padding token
segment_ids: ID used for downstream applications that would typically use two sentences for input, such as natural language inference (this is syntactically needed, but not necessary for this language model)
masked_lm_positions: Indicates the position of the masked word inside the original sentence
masked_lm_ids: Indicates the ID of the actual word that was masked
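As an illustration of how these five inputs fit together, the sketch below builds them for one masked copy of a short sentence. Several things here are assumptions rather than the repository’s behavior: the toy vocabulary, the build_features helper, a single masked position per example, and counting the masked position after the leading [CLS] token.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; a real run would use BERT's WordPiece vocab file.
VOCAB = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[MASK]": 3, "the": 4, "cat": 5, "sat": 6}


def build_features(tokens: List[str], masked_index: int,
                   max_seq_length: int) -> Dict[str, List[int]]:
    """Build the five serving inputs for one copy with a single masked word."""
    if len(tokens) + 2 > max_seq_length:
        raise ValueError("sentence does not fit in max_seq_length")

    original_id = VOCAB[tokens[masked_index]]  # ID of the word being hidden
    masked = list(tokens)
    masked[masked_index] = "[MASK]"
    wrapped = ["[CLS]"] + masked + ["[SEP]"]

    input_ids = [VOCAB[token] for token in wrapped]
    input_mask = [1] * len(input_ids)            # 1 = real token
    padding = max_seq_length - len(input_ids)
    input_ids += [VOCAB["[PAD]"]] * padding
    input_mask += [0] * padding                  # 0 = padding token
    segment_ids = [0] * max_seq_length           # single sentence, so all zeros

    return {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        # Assumption: positions are counted with [CLS] at index 0.
        "masked_lm_positions": [masked_index + 1],
        "masked_lm_ids": [original_id],
    }


features = build_features(["the", "cat", "sat"], masked_index=1, max_seq_length=8)
print(features)
```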

Saving your TensorFlow model in SavedModel format

Before we can serve any TensorFlow model, we need to save it in the SavedModel format. In this case there’s an extra wrinkle: because we’re adding a layer on top of the model, we first need to run the prediction loop so that the weights in the added layer are initialized (fine-tuning on your own dataset accomplishes the same thing). Only then can we save the model. Without this step, the export may fail with tracebacks or the model may behave unexpectedly.

To transform the final model into the SavedModel format, the Estimator class exposes an export_savedmodel function. This function uses the serving_input_receiver_fn() function, which indicates the shapes and data types of all the input tensors needed by the final, servable model.

Here’s what a function that builds the serving_input_receiver_fn would look like:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.estimator)

def serving_input_rec_fn():
    # Shape and dtype of every input tensor; None leaves the batch size flexible.
    serving_features = {
        "input_ids": tf.placeholder(shape=[None, max_seq_length], dtype=tf.int32),
        "input_mask": tf.placeholder(shape=[None, max_seq_length], dtype=tf.int32),
        "segment_ids": tf.placeholder(shape=[None, max_seq_length], dtype=tf.int32),
        "masked_lm_positions": tf.placeholder(shape=[None, max_predictions_per_seq], dtype=tf.int32),
        "masked_lm_ids": tf.placeholder(shape=[None, max_predictions_per_seq], dtype=tf.int32),
    }
    # Returns a receiver function that passes these tensors to the model as-is.
    return tf.estimator.export.build_raw_serving_input_receiver_fn(features=serving_features)

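Now that we’ve defined the receiver function, we want to ensure that the model is executed once and that the extra layers we defined for the language model are included in the checkpoint. Because estimator.predict returns a generator, the graph only runs once the results are consumed. To do that, add these lines:

save_hook = tf.train.CheckpointSaverHook(output_dir, save_secs=1)
result = estimator.predict(input_fn=predict_input_fn, hooks=[save_hook])
for _ in result:  # predict() is lazy; iterate so the graph actually runs
    pass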

Next, we want to save the model in SavedModel format in our working directory, using these lines of code:

estimator._export_to_tpu = False
export_path = estimator.export_savedmodel(os.getcwd(), serving_input_rec_fn())

Running inference on the served model

And with that, we’ve saved the model in the format we need for TensorFlow Serving. The next step is to host the servable model in a Docker container. We can then use it to run predictions with very low latency.

First, we need to set up a Docker container that has TensorFlow Serving as the base image, with the following command:

docker pull tensorflow/serving:1.12.0

For now, we’ll call the served model tf-serving-bert. We can use this command to spin up this model on a Docker container with tensorflow-serving as the base image:

docker run -p 8500:8500 -p 8501:8501 --mount type=bind,source=$(pwd)/exported-model,target=/models/tf-serving-bert -e MODEL_NAME=tf-serving-bert -t tensorflow/serving:1.12.0

In this command, $(pwd)/exported-model is the location where we saved the SavedModel. It contains the graph in .pb format and the variables folder that contains the .data-00000-of-00001 and .index files.
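Before starting the container, it can help to confirm that the mounted directory has a layout TensorFlow Serving can load. TensorFlow Serving generally looks for numbered version subdirectories under the model’s base path, and Estimator exports write into a timestamped subdirectory, so the sketch below looks for the highest numbered folder containing saved_model.pb and a variables folder. The helper name find_servable_version is hypothetical, and the demonstration uses a throwaway directory rather than a real export.

```python
import os
import tempfile
from typing import Optional


def find_servable_version(model_dir: str) -> Optional[str]:
    """Return the highest numbered version directory that looks like a SavedModel."""
    if not os.path.isdir(model_dir):
        return None
    versions = [
        name for name in os.listdir(model_dir)
        if name.isdigit() and os.path.isdir(os.path.join(model_dir, name))
    ]
    for name in sorted(versions, key=int, reverse=True):
        version_dir = os.path.join(model_dir, name)
        has_graph = os.path.isfile(os.path.join(version_dir, "saved_model.pb"))
        has_variables = os.path.isdir(os.path.join(version_dir, "variables"))
        if has_graph and has_variables:
            return version_dir
    return None


# Demonstration on a temporary directory that mimics an export (not a real model).
with tempfile.TemporaryDirectory() as root:
    version_dir = os.path.join(root, "1550000000")
    os.makedirs(os.path.join(version_dir, "variables"))
    open(os.path.join(version_dir, "saved_model.pb"), "wb").close()
    layout_ok = find_servable_version(root) == version_dir
    print("servable version found:", layout_ok)
```

If the function returns None for your export folder, the container is likely to start but report that no servable versions were found.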

At this point, the model is set up. Now we just need to send a REST API request to the served model. When you’ve hosted it on your local system, the model should be running at this endpoint:


Looking for an example? Check out this sample Python script that accepts a .tsv file as its input, sends a request to the above URL, and outputs word-level probabilities for each line in the .tsv file. (Be sure to run it in the same repository that has the script. And of course, make any necessary changes to the path and other parameters in the file provided above.)
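For readers without the script at hand, a request could be sketched roughly as follows, using only the Python standard library. The URL is an assumption based on TensorFlow Serving’s REST convention (/v1/models/<name>:predict on the REST port) combined with the port and model name used in the docker command; the helper names and the example feature values are hypothetical. The feature keys match the receiver function defined earlier, and the structure of the response depends on the model’s outputs, so it is returned unparsed.

```python
import json
import urllib.request
from typing import Dict, List

# Assumed endpoint: REST port 8501 and model name tf-serving-bert from the docker command.
PREDICT_URL = "http://localhost:8501/v1/models/tf-serving-bert:predict"


def build_predict_request(instances: List[Dict[str, List[int]]],
                          url: str = PREDICT_URL) -> urllib.request.Request:
    """Wrap feature dictionaries in the row-format body: {"instances": [...]}."""
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_predict_request(request: urllib.request.Request, timeout: float = 10.0) -> dict:
    """Send the request; this only works while the serving container is running."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


# Hypothetical feature values for one masked copy, padded to a length of 8.
example_instance = {
    "input_ids": [1, 4, 3, 6, 2, 0, 0, 0],
    "input_mask": [1, 1, 1, 1, 1, 0, 0, 0],
    "segment_ids": [0, 0, 0, 0, 0, 0, 0, 0],
    "masked_lm_positions": [2],
    "masked_lm_ids": [5],
}

request = build_predict_request([example_instance])
payload = json.loads(request.data.decode("utf-8"))
print(request.full_url)
# With the container running: result = send_predict_request(request)
```

Keeping request construction separate from sending makes it possible to inspect the payload without a running server. In practice the padded length must match the max_seq_length the model was exported with.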

By using BERT with TensorFlow Serving, you can get results dramatically faster than with traditional methods.

This post is the first in a two-part series on how to implement heavy models like BERT in low-latency environments. In Part 2, we’ll look at how you can take this model to the next level by running it on a GPU.
