
Connecting a web app to your PyTorch model using Amazon SageMaker

Author: Cami Williams

Hello fellow programmers, it is me, Cami Williams, your friendly neighborhood open source Developer Advocate, back at it again with a PyTorch blog post.

I have decided to take on the challenge of deploying my PyTorch neural network (a.k.a. model), with the goal of hooking it up to a REST API so I can access it via a web application.

If you have yet to read my other blog posts about PyTorch, take a look at them here:

In these posts, I dive into what PyTorch is, how it is organized, and how to get up and running with training and testing a model. Next step: actually do something with it.

Admittedly, I come to the table with AWS knowledge and bias, so my go-to here is to check out their resources. In doing so I stumbled upon Amazon SageMaker: a fully managed service that provides every developer and data scientist with the ability to build, train, and deploy machine learning (ML) models quickly. *Italian chef kiss* Exactly what I want.

In theory, I would have my model hosted on SageMaker, a Lambda function to take in data from the web, send it to the SageMaker instance, and send back the model’s response, and a REST API hosted through API Gateway to access that Lambda function. Here’s a semi-helpful diagram depicting the communication flow:

Before moving forward, it would be a MAJOR disservice to you if I didn’t have Future Cami tell you how much this all costs.

Future Cami: Hello PyTorch enthusiasts. The project Cami is about to build came back with a very big AWS bill of $314.23, entirely due to the fact that she used an ml.m4.4xlarge running instance of SageMaker, versus medium or small. Don’t make the same mistake. If you do, AWS is pretty forgiving about experimentation mistakes; you can talk to customer support for help if something happens. Once I adjusted the instance, the bill was just shy of $30 (after a lot of testing). You have been warned.

AWS can get PRICEY and so experimenting with it is not always the best option. If you have found other options for deploying your model and attaching it to a REST API, please let the fam know by commenting on this post or writing a blog post that we can share! Otherwise, check out 7 Ways to Reduce Your AWS Bill, and remember to turn off whatever you aren’t actively using on AWS.

SO! What are we building today? We will use our MNIST handwritten numbers model from the Intro to PyTorch blog post to create a web app that detects digits of pi. A user will enter the website, draw a number on a web canvas, and if it is a digit of pi (i.e. 3.14159) according to our model, the digit will appear on the screen in the proper position. Thanks to Future Cami, again, for the following gif of the functioning web app:

I will build this app in the following order:

If this is your cup of tea, then let’s start building!

Host MNIST model on SageMaker

When I started building this, I went through the Get Started with Amazon SageMaker Notebook Instances and SDKs tutorial to understand how to deploy my model. This will be the abridged version, appealing to those who just want to plug and chug code and keep moving. If you are interested in the nitty gritty, definitely read that tutorial. Note: I am going to be working in us-east-1. Pick whichever AWS region you desire.

Navigate to Amazon SageMaker Studio and click on Notebook Instances in the left-hand menu.

Click on “Create notebook instance”, and enter in the following fields:

  • Notebook instance name: PyTorchPi (or something similar)
  • Notebook instance type: ml.t2.medium (Future Cami: ml.m4.4xlarge is where she made the billing mistake. Use medium.)
  • Elastic inference: none
  • IAM Role: Create new role
  • S3 buckets you specify: Any S3 bucket (were you to push this app to production, you would probably want to change the permissions here, but for our example we will just make it completely open).
  • Create role
  • Root access: Enable
  • Encryption key: No Custom Encryption
  • Create notebook instance

Your notebook instance should show “Pending” for some time. While we wait, let’s put together the code for our model. There’s more documentation on how to train and deploy PyTorch models with SageMaker that you can read, but rather than doing that, we can go straight to the source: a GitHub repo! It gives us an example of how to train and use a PyTorch MNIST model on SageMaker, and get model results from it by drawing our own digits.

Cami, how did you find such a perfect example that matches exactly what we need to build this web app?

Dear reader, do not question a wonderful thing. Let’s keep going.

In this folder you will see ``input.html``, ``mnist.py`` and ``pytorch_mnist.ipynb``.

  • ``input.html`` : Code for the canvas to draw digits
  • ``mnist.py`` : Code to construct, train, and test your model with some data
  • ``pytorch_mnist.ipynb`` : A runnable notebook to train and host the model with MNIST data, and test with the ``input.html`` app.

Within SageMaker, we will host ``input.html`` and ``mnist.py``, and probably never touch them again. ``pytorch_mnist.ipynb`` is where we will interact with this code, potentially make changes, and ultimately deploy the model.
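
For a taste of what’s inside that notebook, its training-and-hosting core boils down to something like this (a sketch only; ``role`` and ``inputs`` are defined in earlier cells of the notebook, and exact argument names vary across SageMaker SDK versions, so defer to the notebook itself):

          python
          from sagemaker.pytorch import PyTorch

          # hand SageMaker the training script plus a training instance;
          # role and inputs come from earlier cells in the notebook
          estimator = PyTorch(entry_point='mnist.py',
                              role=role,
                              framework_version='1.8.0',
                              py_version='py3',
                              instance_count=1,
                              instance_type='ml.m4.xlarge')
          estimator.fit({'training': inputs})

          # host the trained model behind a real-time inference endpoint
          predictor = estimator.deploy(initial_instance_count=1,
                                       instance_type='ml.t2.medium')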

By this point, your PyTorchPi SageMaker Notebook instance should show a status of “InService”. If it does, under “Actions” select “Open Jupyter”. You should be directed to a new url: <NOTEBOOK-NAME>.notebook.<REGION>.sagemaker.aws/tree

Upload ``input.html``, ``mnist.py``, and ``pytorch_mnist.ipynb`` to the Jupyter directory. Once they are uploaded, click on ``pytorch_mnist.ipynb``.

This will open a new window with the runnable notebook. A couple quick housekeeping items:

  • Find the code block that says ``predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')``. Change it to say ``predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.t2.medium')``
  • Delete the “Cleanup” code block that says ``estimator.delete_endpoint()`` so we can eventually deploy the SageMaker endpoint created from the code (“Edit”, then “Delete Cells”).

In the top menu, select “Cell” then “Run All” and see magic happen before your eyes: your model building, training, and optimizing. Also maybe go get a snack or something because this takes a while to finish running.

Scroll through the notebook to see exactly what is going down, or skip directly to the bottom of the notebook to see the final output.

Pro tip/explanation: If a number appears next to the code block, that means it ran successfully and you can see the output. If it is a “*” then it has yet to run or is currently running.

Once the code is done running, you will see a frame where you can draw a number (as coded in ``input.html``). Draw a number in the box and then run the last code block. It will output the model’s prediction of what number your drawing represents.
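
That last cell boils down to something like this (a sketch, not the notebook’s exact code; it assumes ``data`` holds the 1x1x28x28 array captured from the canvas and ``predictor`` is the object ``estimator.deploy(...)`` returned):

          python
          import numpy as np

          # data is the 1 x 1 x 28 x 28 array captured from the input.html canvas
          image = np.array(data, dtype=np.float32)

          # send it to the hosted endpoint; predictor came from estimator.deploy(...)
          response = predictor.predict(image)

          # the model returns one score per digit 0-9; the highest score wins
          print("Predicted digit:", np.argmax(response))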

Play around with this a bit for funzies. Once you are done, we are ready to deploy our model through this prediction endpoint.

Deploy model

Fortunately, the code we needed to run to deploy our model was already in this notebook! You can read more about how this is done under “Host”.

If you ever want to update the endpoint, you will have to re-run this code.

This aside, open a new tab, or navigate to the AWS SageMaker Studio again. In the left-hand menu, click on “Endpoints”.

You should be able to see your pytorch-training endpoint. Click on the endpoint to ensure that the creation time is recent to the last time you ran the notebook. Scroll to “Endpoint runtime settings” to also ensure your instance type is ml.t2.medium.

Assuming everything looks correct, take note of the Name and URL in Endpoint settings. This will be what we reference in our Lambda function. Ideally, we could connect this URL directly to API Gateway. That said, I want to make sure that it is protected by a Lambda function: my Lambda can parse the incoming data to ensure that it is in the correct format. If the sent-over data is incorrect, rather than checking it in the model (where it would be more costly time- and dollar-wise), I can handle that error in the Lambda.
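
Before wiring up Lambda, you can sanity-check the endpoint from any machine with AWS credentials. Here is a minimal sketch using boto3, assuming the endpoint Name you just noted (an all-zeros “drawing” gives a meaningless prediction, but it proves the endpoint answers):

          python
          import boto3
          import json

          runtime = boto3.client('runtime.sagemaker', region_name='us-east-1')

          # a blank 28x28 drawing, wrapped in the 4-D shape the model expects
          blank = [[[[0] * 28 for _ in range(28)]]]

          response = runtime.invoke_endpoint(EndpointName='<ENDPOINT-NAME>',
                                             ContentType='application/json',
                                             Body=json.dumps(blank))
          print(json.loads(response['Body'].read().decode()))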

Connect to Lambda

Now go to AWS Lambda, and click “Create function”. Select “Use a blueprint” and search for “microservice-http-endpoint-python”. This sets up a Lambda function that is intended to be a REST API endpoint using API Gateway.

Click on “Configure”, and enter in the following fields:

  • Function name: PyTorchPiFunction (or something similar)
  • Execution role: Create a new role from AWS policy templates
  • Role name: PyTorchSageMakerLambdaAPIGateway (or something similar)
  • Policy templates: Ensure “Simple microservice permissions” is selected.
  • API Gateway trigger: Create an API
  • API Type: REST API
  • Security: Open (again, were you to push this app to production, you would probably want to change the permissions here, but for our example we will just make it completely open).

Then click, “Create function”. In the window that appears, make note of your Lambda ARN in the top right corner. You should see a diagram in the Designer view of your function connected to API Gateway. If you click on API Gateway here, you will see information about the API endpoint (we will configure this later). For now, click on the PyTorchPiFunction.

In the Function Code section, you should see the code for your Lambda function. Here, we want to be able to intake data from our web app, send it to the SageMaker endpoint, and return the result of the model in SageMaker. To do so, we should create an environment variable for the SageMaker endpoint name.

Scroll down to the Environment variables section and click “Manage environment variables”. Click on “Add environment variable” and set the “Key” to ENDPOINT_NAME, and the value to your SageMaker endpoint name.

Save this environment variable. Now, let’s write some code. Back in the Function code section, delete everything present, and paste the following code:

          python
          import os
          import io
          import boto3
          import json
          import csv


          # grab environment variables
          ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
          runtime = boto3.client('runtime.sagemaker')

          def lambda_handler(event, context):
              # print("Received event: " + json.dumps(event, indent=2))

              # 1: pull the image data out of the incoming event
              data = json.loads(json.dumps(event))
              payload = data['data']
              print("PAYLOAD: ", payload)

              # 2: invoke the SageMaker endpoint with the image data
              response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,
                                                 ContentType='application/json',
                                                 Body=payload)
              print(response)

              # 3: read and decode the model's response
              result = json.loads(response['Body'].read().decode())

              # 4: the highest score in the result is the winning prediction
              pred = max(result[0])
              # 5: look up which label (digit) that score belongs to
              predicted_label = result[0].index(pred)

              return predicted_label
        

This code does the following:

  • Takes in image data
  • Invokes the SageMaker endpoint
  • Prints and loads the response from the model
  • Takes the maximum from the result (i.e. the label/number that has the highest probability according to the model output)
  • Sends this information back in the response.

This code is hard to test, because we would need to get the multi-dimensional interpretation of the image… so trust me for now and we will test momentarily. If you are curious what that interpretation looks like, here is an example of the data this model would intake for a drawing that resembles an “8”:

Neato, right? We will use this to test via API Gateway next, and I will show you the code that translates an image into this multidimensional array. For now, save your Lambda function.

Before moving on to the REST endpoint, we need to update the IAM permissions of this function to allow access to SageMaker. Scroll to the top and click on “Permissions”. Under “Execution role”, click on the Role name:

This should open a new tab where we can attach the SageMaker policy. Click on “Attach policies” and search for “SageMaker”. Add “AmazonSageMakerFullAccess” and click “Attach policy”.

Now, navigate to API Gateway and click on “PyTorchPiFunction-API”, or in your Lambda function under “Configuration”, click on the API Gateway tab in the Designer view, then click on the “PyTorchPiFunction-API”.

Create REST endpoint

Because we pre-configured this endpoint when we created our Lambda function, it has some things already defined for us. For the sake of sanity, rather than explain this setup, I am going to advise that we delete it and start from scratch.

Under Resources you should have the PyTorchPiFunction. Go ahead and delete this by clicking on “Actions”, then “Delete Resource”.

Once that is deleted, under “Actions” select “Create Method”. A dropdown menu should appear, select POST and click on the little gray checkmark. Doing so should bring up a pane where you can set up the endpoint.

In the “Lambda Function” section, paste in your Lambda ARN that we noted earlier. You can also input the name of the function (“PyTorchPiFunction”), but I always do the ARN just to ensure they are matched properly. Click Save.

Now we will see all the method executions of this endpoint. Click on the Actions dropdown and select “Enable CORS”. Update to the following:

  • Access-Control-Allow-Headers : 'Content-Type,X-Amz-Date,Authorization,X-Api-Key,x-requested-with'
  • Access-Control-Allow-Origin : '*'

Click “Enable CORS” and replace existing CORS headers. If you get a dialog asking to confirm method changes, select Yes, replace existing values.

Click on the POST endpoint to see all of the Method executions. We can leave “Method Request”, “Integration Request” and “Integration Response” as is. Select “Method Response”. You should see a line item that has the HTTP status as 200. Click on the dropdown for this.

Under Response Headers for 200, add the following headers if they aren’t already present:

  • X-Requested-With
  • Access-Control-Allow-Headers
  • Access-Control-Allow-Origin
  • Access-Control-Allow-Methods

Click back to Method Execution. Now we can finally test our REST API. Click on the Test lightning bolt. You should now see the Method Test pane.

If you scroll to the bottom of this pane and click “Test”, you should get a Response Body that resembles the following:

          json
          {
            "errorMessage": "'data'",
            "errorType": "KeyError",
            "stackTrace": [
            	"  File \"/var/task/lambda_function.py\", line 16, in lambda_handler\n    payload = data['data']\n"
            ]
          }
        

If you do, this shows that we can correctly access our Lambda function! Yay! It is coming back with a response saying that it is missing our input data. In the request body, enter:

          {
          "data": "[[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]"
          }
        

Remember the image above of the 8 as a multidimensional array? This is it! Click “Test” and this time you should get back the model’s prediction. For more example multidimensional arrays, you can refer to the runnable notebook on SageMaker (or generate a blank test body with the sketch after this list):

  • Run all cells of the notebook if you haven’t already.
  • Scroll to the bottom of the notebook
  • Draw a number
  • In the last cell block, add the line ``print([data])``
  • Copy the output multidimensional array and set “data” to this array in your API Gateway test.
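
If you would rather not copy an array out of the notebook, a few lines of Python will generate a throwaway test body in the right shape (all zeros, so the prediction itself is meaningless, but it proves the plumbing):

          python
          import json

          # 1 x 1 x 28 x 28 grid of zeros -- a blank canvas in the model's input shape
          blank = [[[[0] * 28 for _ in range(28)]]]

          # the value of "data" must be the array as a *string*, hence the double dumps
          print(json.dumps({"data": json.dumps(blank)}))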

If you are getting errors with any of this, you should check out the logs on AWS CloudWatch. To do so, navigate to CloudWatch on AWS, click on Log groups on the left-hand menu, and select your Lambda function name. There should be a line item for every time your API Gateway test calls your Lambda function. Common errors I had:

  • Parameter validation failed: Make sure that “data” and the multidimensional array are in quotes.
  • Expected 4-dimensional input for 4-dimensional weight 10 1 5 5, but got 3-dimensional input of size [1, 28, 28] instead: Ensure your multidimensional array has four brackets ([[[[ and ]]]]) at the beginning and end. (A sketch for catching both of these errors early follows this list.)
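
Since malformed input otherwise costs you a SageMaker invocation, you could also have the Lambda fail fast. Here is a sketch of a guard you might call at the top of ``lambda_handler`` before ``invoke_endpoint`` (the shape check and error strings are my own invention, not part of the original function):

          python
          import json

          def validate_payload(payload):
              # return None if payload looks right, else a human-readable error
              try:
                  array = json.loads(payload)
              except (TypeError, ValueError):
                  return "data must be a JSON string (keep the quotes around the array)"
              try:
                  if (len(array) == 1 and len(array[0]) == 1
                          and len(array[0][0]) == 28 and len(array[0][0][0]) == 28):
                      return None
              except (TypeError, IndexError):
                  pass
              return "expected a 4-dimensional [[[[...]]]] array of 28x28 pixels"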

Our API is working and fully connected to our model. In the Actions menu, select Deploy API. You can use either “default” or “[New Stage]”. For my own entertainment, I am going to make a new stage and call it “prod”. Once you have made your choice, select “Deploy”. This will take you to the Stage Editor. You shouldn’t have to make any changes here, but for sanity’s sake click “Save Changes”. Make a note somewhere of the Invoke URL; this will be the REST endpoint we call in our web app!
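
To double-check the deployed stage from outside the AWS console, a quick POST to the Invoke URL works too. A sketch assuming the ``requests`` library and your own Invoke URL (including the stage, e.g. ``/prod``):

          python
          import json
          import requests

          # your Invoke URL from the Stage Editor, including the stage
          url = '<INVOKE-URL>'

          blank = [[[[0] * 28 for _ in range(28)]]]
          response = requests.post(url, json={'data': json.dumps(blank)})

          # the Lambda returns the predicted digit as a bare number
          print(response.json())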

PHEW! We did it! All the crazy AWS stuff is DONE. Now we can move on to the final piece: building the web application.

Build the web app

I don’t really want to get too in the weeds with this web app. I figure that you probably are reading this post more for the PyTorch stuff over the HTML/CSS madness. So, with that assumption, I am just going to explain the JavaScript side of things. Cool? Cool.

In the HTML, I am using a canvas, similar to what we used in the ``input.html`` we uploaded to SageMaker. There are ways to style how a user draws on the canvas; this can take some time to play with.

Once we are able to intake the drawing on the canvas properly, we need to get its pixel information. This took some playing around, but essentially, when a user presses down on the canvas, I execute the following code:

      javascript
      function getMousePos(e) {
          if (!e)
              e = window.event;

          // canvas-relative mouse coordinates (with layerX/Y as a fallback)
          if (e.offsetX) {
              mouseX = e.offsetX;
              mouseY = e.offsetY;
          } else if (e.layerX) {
              mouseX = e.layerX;
              mouseY = e.layerY;
          }

          // scale the canvas position down to the 28x28 model grid
          x = Math.floor(e.offsetY * 0.05);
          y = Math.floor(e.offsetX * 0.05) + 1;
          // mark a 2x2 block of grid cells so strokes aren't too thin
          for (var dy = 0; dy < 2; dy++){
              for (var dx = 0; dx < 2; dx++){
                  if ((x + dx < 28) && (y + dy < 28)){
                      pixels[(y+dy)+(x+dx)*28] = 1;
                  }
              }
          }
      }
        

The ``pixels`` array is then updated with 1s where the user has drawn and 0s for the negative space around the drawing. Once the user hits “Enter” or is ready to send their drawing, I use a function from ``input.html`` to convert ``pixels`` to the proper multidimensional array format for the model.

          javascript
          function set_value() {
              // serialize the 28x28 pixels array into the "[[[[ ... ]]]]"
              // string format the model expects
              let result = "[[["
              for (var i = 0; i < 28; i++) {
                  result += "["
                  for (var j = 0; j < 28; j++) {
                      result += pixels[i * 28 + j] || 0
                      if (j < 27) {
                          result += ", "
                      }
                  }
                  result += "]"
                  if (i < 27) {
                      result += ", "
                  }
              }
              result += "]]]"
              return result
          }
        

With the ``result`` from ``set_value()``, I call my REST API.

          javascript
          async function sendImageToModel(canvas) {
              let data = set_value();
              let payload = {
                  "data": data
              };
              // POST the serialized drawing to the API Gateway endpoint
              let response = await fetch('<REST API URL>', {
                  method: 'POST',
                  body: JSON.stringify(payload),
                  headers: {
                      'Content-Type': 'application/json'
                  }
              });
              let myJson = await response.json();
              return myJson;
          }

        

When I have the numerical response from calling my API, I check it against the next expected digit of pi; if it matches, I show the digit on the screen in its proper position and clear the canvas for the user to provide the next input.

Party

Huzzah! We did it! We have successfully created a web application that accesses our MNIST PyTorch model.

To those of you who have read this far, I salute you! If you just scrolled to the bottom of the post, hello! If you have any questions, please feel free to leave a comment here or to tweet at me on Twitter (@cwillycs).

Thanks for reading, and happy hacking!

To learn more about Facebook Open Source, visit our open source site, subscribe on YouTube, or follow us on Twitter and Facebook.