Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the generated openapi schema #3

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
14 changes: 12 additions & 2 deletions API/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Standard library imports
from enum import Enum
from typing import Union

# Third party imports
from fastapi import Depends, Header, HTTPException
from osm_login_python.core import Auth
from pydantic import BaseModel, Field

# Reader imports
from src.app import Users
from src.config import get_oauth_credentials

Expand Down Expand Up @@ -43,11 +46,18 @@ def get_osm_auth_user(access_token):
return user


def login_required(access_token: str = Header(...)):
def login_required(
access_token: str = Header(..., description="Access token from OSM API.")
):
return get_osm_auth_user(access_token)


def get_optional_user(access_token: str = Header(default=None)) -> AuthUser:
def get_optional_user(
access_token: str = Header(
default=None,
description="Allows a guest user to be used if the user is not authenticated.",
)
) -> AuthUser:
if access_token:
return get_osm_auth_user(access_token)
else:
Expand Down
39 changes: 29 additions & 10 deletions API/auth/routers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Standard library imports
import json

from fastapi import APIRouter, Depends, Request
# Third party imports
from fastapi import APIRouter, Depends, Query, Request
from pydantic import BaseModel

# Reader imports
from src.app import Users

from . import AuthUser, admin_required, login_required, osm_auth, staff_required

router = APIRouter(prefix="/auth", tags=["Auth"])


@router.get("/login/")
@router.get("/login")
def login_url(request: Request):
"""Generate Login URL for authentication using OAuth2 Application registered with OpenStreetMap.
Click on the download url returned to get access_token.
Expand All @@ -25,7 +28,7 @@ def login_url(request: Request):
return login_url


@router.get("/callback/")
@router.get("/callback")
def callback(request: Request):
"""Performs token exchange between OpenStreetMap and Raw Data API

Expand All @@ -42,7 +45,10 @@ def callback(request: Request):
return access_token


@router.get("/me/", response_model=AuthUser)
@router.get(
"/me",
response_model=AuthUser,
)
def my_data(user_data: AuthUser = Depends(login_required)):
"""Read the access token and provide user details from OSM user's API endpoint,
also integrated with underpass .
Expand All @@ -64,8 +70,11 @@ class User(BaseModel):


# Create user
@router.post("/users/", response_model=dict)
async def create_user(params: User, user_data: AuthUser = Depends(admin_required)):
@router.post("/users", response_model=dict)
async def create_user(
params: User,
user_data: AuthUser = Depends(admin_required),
):
"""
Creates a new user and returns the user's information.
User Role :
Expand Down Expand Up @@ -136,8 +145,14 @@ async def update_user(


# Delete user by osm_id
@router.delete("/users/{osm_id}", response_model=dict)
async def delete_user(osm_id: int, user_data: AuthUser = Depends(admin_required)):
@router.delete(
"/users/{osm_id}",
response_model=dict,
)
async def delete_user(
osm_id: int,
user_data: AuthUser = Depends(admin_required),
):
"""
Deletes a user based on the given osm_id.

Expand All @@ -155,9 +170,13 @@ async def delete_user(osm_id: int, user_data: AuthUser = Depends(admin_required)


# Get all users
@router.get("/users/", response_model=list)
@router.get("/users", response_model=list)
async def read_users(
skip: int = 0, limit: int = 10, user_data: AuthUser = Depends(staff_required)
skip: int = Query(0, description="The number of users to skip (for pagination)"),
limit: int = Query(
10, description="The maximum number of users to retrieve (for pagination)"
),
user_data: AuthUser = Depends(staff_required),
):
"""
Retrieves a list of users with optional pagination.
Expand Down
2 changes: 1 addition & 1 deletion API/custom_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
router = APIRouter(prefix="/custom", tags=["Custom Exports"])


@router.post("/snapshot/")
@router.post("/snapshot")
@limiter.limit(f"{RATE_LIMIT_PER_MIN}/minute")
@version(1)
async def process_custom_requests(
Expand Down
2 changes: 1 addition & 1 deletion API/hdx.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def read_hdx_list(
return hdx_list


@router.get("/search/", response_model=List[dict])
@router.get("/search", response_model=List[dict])
@limiter.limit(f"{RATE_LIMIT_PER_MIN}/minute")
@version(1)
async def search_hdx(
Expand Down
41 changes: 32 additions & 9 deletions API/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Standard library imports
# Humanitarian OpenStreetmap Team
# 1100 13th Street NW Suite 800 Washington, D.C. 20005
# <[email protected]>
import time

# Third party imports
import psycopg2
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -26,6 +28,7 @@
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

# Reader imports
from src.config import (
ENABLE_CUSTOM_EXPORTS,
ENABLE_HDX_EXPORTS,
Expand Down Expand Up @@ -57,6 +60,7 @@
from .hdx import router as hdx_router

if SENTRY_DSN:
# Third party imports
import sentry_sdk

# only use sentry if it is specified in config blocks
Expand All @@ -71,11 +75,38 @@

if LOG_LEVEL.lower() == "debug":
# This is used for local setup for auth login
# Standard library imports
import os

os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

app = FastAPI(title="Raw Data API ", swagger_ui_parameters={"syntaxHighlight": False})
description = """
Raw Data API is a set of high-performant APIs for transforming and exporting OpenStreetMap (OSM) data in different GIS file formats.

## Auth
Enables handling authentication and authorization of users.

## Extract
Facilitates getting and checking data in the database.

## Tasks
Paths for managing task queues.
"""

app = FastAPI(
title="Raw Data API ",
swagger_ui_parameters={"syntaxHighlight": False},
description=description,
)

app.openapi = {
"info": {
"title": "Raw Data API",
"version": "1.0",
},
"security": [{"OAuth2PasswordBearer": []}],
}

app.include_router(auth_router)
app.include_router(raw_data_router)
app.include_router(tasks_router)
Expand All @@ -90,14 +121,6 @@
if USE_S3_TO_UPLOAD:
app.include_router(s3_router)

app.openapi = {
"info": {
"title": "Raw Data API",
"version": "1.0",
},
"security": [{"OAuth2PasswordBearer": []}],
}

app = VersionedFastAPI(
app, enable_latest=False, version_format="{major}", prefix_format="/v{major}"
)
Expand Down
43 changes: 35 additions & 8 deletions API/raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Third party imports
import redis
from area import area
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from fastapi_versioning import version

Expand Down Expand Up @@ -54,15 +54,19 @@
redis_client = redis.StrictRedis.from_url(CELERY_BROKER_URL)


@router.get("/status/", response_model=StatusResponse)
@router.get(
"/status",
response_model=StatusResponse,
description="Gives status about how recent the osm data is , it will give the last time that database was updated completely.",
)
@version(1)
def check_database_last_updated():
"""Gives status about how recent the osm data is , it will give the last time that database was updated completely"""
result = RawData().check_status()
return {"last_updated": result}


@router.post("/snapshot/", response_model=SnapshotResponse)
@router.post("/snapshot", response_model=SnapshotResponse)
@limiter.limit(f"{export_rate_limit}/minute")
@version(1)
def get_osm_current_snapshot_as_file(
Expand Down Expand Up @@ -464,7 +468,7 @@ def get_osm_current_snapshot_as_file(
)


@router.post("/snapshot/plain/")
@router.post("/snapshot/plain")
@version(1)
def get_osm_current_snapshot_as_plain_geojson(
request: Request,
Expand All @@ -474,10 +478,13 @@ def get_osm_current_snapshot_as_plain_geojson(
"""Generates the Plain geojson for the polygon within 30 Sqkm and returns the result right away

Args:

request (Request): _description_

params (RawDataCurrentParamsBase): Same as /snapshot excpet multiple output format options and configurations

Returns:

Featurecollection: Geojson
"""
area_m2 = area(json.loads(params.geometry.model_dump_json()))
Expand All @@ -496,14 +503,34 @@ def get_osm_current_snapshot_as_plain_geojson(
return result


@router.get("/countries/")
@router.get("/countries")
@version(1)
def get_countries(q: str = ""):
def get_countries(
q: str = Query("", description="A query string to filter the list of countries.")
):
"""Get a list of countries.

Args:
q (str, optional): A query string to filter the list of countries. Defaults to "".

Returns:
Any: The list of countries.
"""
result = RawData().get_countries_list(q)
return result


@router.get("/osm_id/")
@router.get("/osm_id")
@version(1)
def get_osm_feature(osm_id: int):
def get_osm_feature(
osm_id: int = Query(..., description="The ID of the OpenStreetMap feature.")
):
"""Get an OpenStreetMap feature by its ID.

Args:
osm_id (int): The ID of the OpenStreetMap feature.

Returns:
Any: The OpenStreetMap feature.
"""
return RawData().get_osm_feature(osm_id)
41 changes: 40 additions & 1 deletion API/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
paginator = s3.get_paginator("list_objects_v2")


@router.get("/files/")
@router.get("/files")
@limiter.limit(f"{RATE_LIMIT_PER_MIN}/minute")
@version(1)
async def list_s3_files(
Expand All @@ -42,6 +42,19 @@ async def list_s3_files(
default=False, description="Display size & date in human-readable format"
),
):
"""List files in an S3 bucket.

Args:
request (Request): The FastAPI request object.
folder (str, optional): The folder in the S3 bucket to list files from. Defaults to "/HDX".
prettify (bool, optional): If True, the size and last modified date of each file will be displayed in a human-readable format. Defaults to False.

Returns:
StreamingResponse: A streaming response containing the details of the files in the S3 bucket.

Raises:
HTTPException: If AWS credentials are not available, a 500 error is raised.
"""
bucket_name = BUCKET_NAME
folder = folder.strip("/")
prefix = f"{folder}/"
Expand Down Expand Up @@ -114,6 +127,18 @@ async def head_s3_file(
request: Request,
file_path: str = Path(..., description="The path to the file or folder in S3"),
):
"""Head request for an S3 file.

Args:
request (Request): The FastAPI request object.
file_path (str): The path to the file or folder in S3.

Returns:
Response: A response with headers including Last-Modified and Content-Length.

Raises:
HTTPException: If there is an AWS error, a 500 error is raised. If the file is not found, a 404 error is raised.
"""
bucket_name = BUCKET_NAME
encoded_file_path = quote(file_path.strip("/"))
try:
Expand Down Expand Up @@ -151,6 +176,20 @@ async def get_s3_file(
description="Whether to read and deliver the content of .json file",
),
):
"""Get an S3 file or folder.

Args:
request (Request): The FastAPI request object.
file_path (str): The path to the file or folder in S3.
expiry (int, optional): The expiry time for the presigned URL in seconds. Defaults to 3600 (1 hour).
read_meta (bool, optional): If True, reads and delivers the content of a .json file. Defaults to True.

Returns:
RedirectResponse: A redirect to the presigned URL for the S3 file or folder.

Raises:
HTTPException: If the file or folder is not found, a 404 error is raised.
"""
bucket_name = BUCKET_NAME
file_path = file_path.strip("/")
encoded_file_path = quote(file_path)
Expand Down
2 changes: 1 addition & 1 deletion API/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
router = APIRouter(prefix="/stats", tags=["Stats"])


@router.post("/polygon/")
@router.post("/polygon")
@limiter.limit(f"{POLYGON_STATISTICS_API_RATE_LIMIT}/minute")
@version(1)
async def get_polygon_stats(
Expand Down
Loading