Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWT API key authentication to YOLOv4 #65

Merged
merged 3 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tt-metal-yolov4/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This implementation supports YOLOv4 execution on Grayskull and Wormhole.

## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)

Expand All @@ -19,6 +20,15 @@ docker compose --env-file tt-metal-yolov4/.env.default -f tt-metal-yolov4/docker
This will start the default Docker container with the entrypoint command set to `server/run_uvicorn.sh`. The next section describes how to override the container's default command with an interactive shell via `bash`.


### JWT_TOKEN Authorization

To authenticate requests, send the token in the `Authorization` header using the `Bearer` scheme. The JWT can be generated with the script `jwt_util.py`, for example:
```bash
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```


## Development
Inside the container you can then start the server with:
```bash
Expand Down
1 change: 1 addition & 0 deletions tt-metal-yolov4/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
1 change: 1 addition & 0 deletions tt-metal-yolov4/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# inference server requirements
fastapi==0.85.1
uvicorn==0.19.0
pyjwt==2.7.0
python-multipart==0.0.5

-f https://download.pytorch.org/whl/cpu/torch_stable.html
59 changes: 57 additions & 2 deletions tt-metal-yolov4/server/fast_api_yolov4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
# SPDX-License-Identifier: Apache-2.0
import os
import logging
from fastapi import FastAPI, File, UploadFile
from fastapi import FastAPI, File, HTTPException, Request, status, UploadFile
from functools import wraps
from io import BytesIO
import jwt
from PIL import Image
from models.demos.yolov4.tests.yolov4_perfomant_webdemo import Yolov4Trace2CQ
import ttnn
from typing import Optional

import numpy as np
import torch
Expand Down Expand Up @@ -199,8 +202,60 @@ def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False):
return np.array(keep)


def normalize_token(token: str) -> list:
    """Split an Authorization header value into its scheme and parameters.

    Note that scheme is case insensitive for the authorization header, so it
    is returned lower-cased.
    See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization#directives

    Args:
        token: Raw header value, e.g. ``"Bearer <credentials>"``.

    Returns:
        A two-item list ``[scheme, parameters]``. A missing part is returned
        as an empty string.
    """  # noqa: E501
    # Split only on the first run of whitespace so that extra spaces between
    # the scheme and the credentials do not end up inside the credentials.
    parts = token.split(None, 1)
    scheme = parts[0].lower() if parts else ""
    parameters = parts[1].strip() if len(parts) > 1 else ""
    return [scheme, parameters]


def read_authorization(
    headers,
) -> Optional[dict]:
    """Validate the bearer JWT carried in ``headers`` and return its payload.

    Args:
        headers: Mapping of request headers; the value is looked up under
            the key ``"authorization"``.

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException: 401 when the header is missing, the scheme is not
            ``bearer`` or the decoded payload is empty; 400 when the token
            cannot be decoded or verified; 500 when the ``JWT_SECRET``
            environment variable is not set on the server.
    """
    authorization = headers.get("authorization")
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Must provide Authorization header.",
        )
    [scheme, parameters] = normalize_token(authorization)
    if scheme != "bearer":
        user_error_msg = f"Authorization scheme was '{scheme}' instead of bearer"
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=user_error_msg
        )

    # Fail with an explicit message instead of handing ``None`` (or an empty
    # key) to the JWT library when the server is misconfigured.
    secret = os.getenv("JWT_SECRET")
    if not secret:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="JWT_SECRET is not configured on the server.",
        )

    # Keep the try body limited to the one call that can raise a JWT error.
    try:
        payload = jwt.decode(parameters, secret, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        user_error_msg = f"JWT payload decode error: {exc}"
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=user_error_msg
        ) from exc
    if not payload:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
    return payload


def api_key_required(f):
    """Decorates an endpoint to require API key validation.

    The decorated coroutine must receive the incoming request as the keyword
    argument ``request``; its headers are checked with
    ``read_authorization`` before the endpoint itself is awaited.

    Args:
        f: The async endpoint function to protect.

    Returns:
        An async wrapper with the same metadata as ``f``.

    Raises:
        RuntimeError: At call time, when no ``request`` keyword argument was
            supplied to the wrapped endpoint.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        request = kwargs.get("request")
        if request is None:
            # Without this guard the failure would surface as an opaque
            # AttributeError on ``None.headers``.
            raise RuntimeError(
                f"Endpoint '{f.__name__}' must accept a 'request' keyword "
                "argument to use api_key_required."
            )
        # Raises on any authorization failure; the payload itself is unused.
        read_authorization(request.headers)

        return await f(*args, **kwargs)

    return wrapper


@app.post("/objdetection_v2")
async def objdetection_v2(file: UploadFile = File(...)):
@api_key_required
async def objdetection_v2(request: Request, file: UploadFile = File(...)):
contents = await file.read()
# Load and convert the image to RGB
image = Image.open(BytesIO(contents)).convert("RGB")
Expand Down
52 changes: 52 additions & 0 deletions tt-metal-yolov4/server/scripts/jwt_util.py
milank94 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

#!/usr/bin/env python3
import argparse
import json
import os

import sys

import jwt


def _exit_with_error(message):
    """Print ``message`` to stderr and terminate with exit status 1."""
    # Errors go to stderr so they are never captured as a token by callers
    # that use command substitution on this script's stdout.
    print(message, file=sys.stderr)
    sys.exit(1)


def main():
    """Encode a JSON payload into a signed JWT, or decode a JWT, from the CLI.

    The secret is read from the ``JWT_SECRET`` environment variable, falling
    back to ``--secret`` when the variable is unset or empty. The result is
    printed to stdout; every failure prints an ``ERROR:`` line to stderr and
    exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Generate signed JWT payload")
    parser.add_argument("mode", type=str, help="'encode' or 'decode'")
    parser.add_argument(
        "payload", type=str, help="JSON string if 'encode', token if 'decode'"
    )
    parser.add_argument(
        "--secret",
        type=str,
        dest="secret",
        help="JWT secret if not provided as environment variable JWT_SECRET",
    )
    args = parser.parse_args()

    # ``os.environ.get`` never raises KeyError, so check the value explicitly.
    jwt_secret = os.environ.get("JWT_SECRET") or args.secret
    if not jwt_secret:
        _exit_with_error(
            "ERROR: Expected JWT_SECRET environment variable to be provided"
        )

    if args.mode == "encode":
        try:
            json_payload = json.loads(args.payload)
        except json.JSONDecodeError:
            _exit_with_error("ERROR: Expected payload to be a valid JSON string")
        # JWT claims must be a mapping; reject lists, strings, numbers, ...
        if not isinstance(json_payload, dict):
            _exit_with_error("ERROR: Expected payload to be a JSON object")
        encoded_jwt = jwt.encode(json_payload, jwt_secret, algorithm="HS256")
        print(encoded_jwt)
    elif args.mode == "decode":
        try:
            # ``algorithms`` is a list of accepted algorithm names.
            decoded_jwt = jwt.decode(args.payload, jwt_secret, algorithms=["HS256"])
        except jwt.InvalidTokenError as exc:
            _exit_with_error(f"ERROR: Could not decode token: {exc}")
        print(decoded_jwt)
    else:
        _exit_with_error("ERROR: Expected mode to be 'encode' or 'decode'")


if __name__ == "__main__":
    main()
22 changes: 6 additions & 16 deletions tt-metal-yolov4/tests/locustfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,16 @@
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import io
import requests
from PIL import Image
from locust import HttpUser, task
from utils import get_auth_header, sample_file

# Load the sample file in memory once at import time; the value is shared by
# every task invocation below. The previous inline download/resize/encode
# steps produced a value that was immediately overwritten by this call.
file = sample_file()


class HelloWorldUser(HttpUser):
    """Locust user that load-tests the ``/objdetection_v2`` endpoint."""

    @task
    def hello_world(self):
        """POST the in-memory sample file together with the auth header."""
        # Only the authenticated request is sent: a POST without the
        # Authorization header would just add a failing request per task.
        headers = get_auth_header()
        self.client.post("/objdetection_v2", files=file, headers=headers)
47 changes: 47 additions & 0 deletions tt-metal-yolov4/tests/test_inference_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

from http import HTTPStatus
import os
import pytest
import requests
from utils import get_auth_header, sample_file


# Host of the deployed inference server that the tests talk to.
DEPLOY_URL = "http://127.0.0.1"
# Port can be overridden through the SERVICE_PORT environment variable.
SERVICE_PORT = int(os.environ.get("SERVICE_PORT", 7000))
API_BASE_URL = ":".join((DEPLOY_URL, str(SERVICE_PORT)))
# Endpoints exercised by the tests in this module.
API_URL = API_BASE_URL + "/objdetection_v2"
HEALTH_URL = API_BASE_URL + "/health"


def test_valid_api_call():
    """An authenticated POST of the sample image returns 200 and a JSON list."""
    # sample image first, then auth headers (evaluated left to right)
    response = requests.post(
        API_URL,
        files=sample_file(),
        headers=get_auth_header(),
    )
    # perform status and value checking
    assert response.status_code == HTTPStatus.OK
    body = response.json()
    assert isinstance(body, list)


def test_invalid_api_call():
    """A POST whose Authorization header is not a valid key is rejected with 401."""
    image_payload = sample_file()
    # start from the real headers, then overwrite the credential with garbage
    bad_headers = get_auth_header()
    bad_headers["Authorization"] = "INVALID API KEY"
    response = requests.post(API_URL, files=image_payload, headers=bad_headers)
    # the request must be refused as unauthorized
    assert response.status_code == HTTPStatus.UNAUTHORIZED


@pytest.mark.skip(
    reason="Not implemented, see https://github.com/tenstorrent/tt-inference-server/issues/63"
)
def test_get_health():
    """The health endpoint answers a header-less GET with HTTP 200."""
    response = requests.get(HEALTH_URL, headers={}, timeout=35)
    assert response.status_code == 200
34 changes: 34 additions & 0 deletions tt-metal-yolov4/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import io
import os
from PIL import Image
import requests


def get_auth_header():
    """Build the request headers from the AUTHORIZATION environment variable.

    Returns:
        A dict with a single ``"Authorization"`` entry holding the variable's
        value.

    Raises:
        RuntimeError: If AUTHORIZATION is unset or empty.
    """
    authorization_header = os.getenv("AUTHORIZATION", None)
    # guard clause: unset and empty values are both treated as undefined
    if not authorization_header:
        raise RuntimeError("AUTHORIZATION environment variable is undefined.")
    return {"Authorization": authorization_header}


# save image as JPEG in-memory
def sample_file():
    """Download the sample image and return it as a multipart upload payload.

    The image is fetched over HTTP, resized to 320x320 and re-encoded as JPEG
    entirely in memory.

    Returns:
        A dict ``{"file": <JPEG bytes>}`` suitable for the ``files=`` argument
        of an HTTP client.

    Raises:
        requests.RequestException: If the download fails, times out or the
            server answers with an error status.
    """
    # load sample image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps the test run from hanging forever on a stalled download,
    # and the status check surfaces HTTP errors before the image is parsed.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Reading ``content`` consumes the whole body, so no streamed response is
    # left open behind the image object.
    pil_image = Image.open(io.BytesIO(response.content))
    pil_image = pil_image.resize((320, 320))  # Resize to target dimensions
    # convert to bytes, formatted as JPEG
    buf = io.BytesIO()
    pil_image.save(
        buf,
        format="JPEG",
    )
    byte_im = buf.getvalue()
    file = {"file": byte_im}
    return file