From f08178920eef9fed3ee5ee9d1e0f13a6b43b8ad7 Mon Sep 17 00:00:00 2001 From: Felix Kolwa Date: Tue, 2 Jul 2024 08:58:00 +0200 Subject: [PATCH] Add pagination to /aws Add paginate method to crud.py Add optional pagination params to /aws Add necessary tests --- cid/crud.py | 58 +++++++++++++++++---- cid/main.py | 8 +-- tests/test_crud.py | 125 +++++++++++++++++++++++++++++++++++++++------ tests/test_main.py | 38 ++++++++++---- 4 files changed, 189 insertions(+), 40 deletions(-) diff --git a/cid/crud.py b/cid/crud.py index a337053..bcc5d0a 100644 --- a/cid/crud.py +++ b/cid/crud.py @@ -7,6 +7,7 @@ from packaging.version import Version from sqlalchemy import desc from sqlalchemy.orm import Session +from sqlalchemy.orm.query import Query from cid.config import CLOUD_PROVIDERS from cid.database import engine @@ -376,18 +377,22 @@ def find_aws_images( version: Optional[str] = None, name: Optional[str] = None, region: Optional[str] = None, -) -> list: - """Return all AWS images that match the given criteria. + page: int = 1, + page_size: int = 100, +) -> dict: + """Return paginated AWS images that match the given criteria. Args: - db (Session): database session - arch (Optional[str]): architecture to search - version (Optional[str]): RHEL version to search - name (Optional[str]): image name to search - region (Optional[str]): AWS region to search + db (Session): database session + arch (Optional[str]): architecture to search + version (Optional[str]): RHEL version to search + name (Optional[str]): image name to search + region (Optional[str]): AWS region to search + page (int): page number + page_size (int): number of images per page Returns: - list: list of images that match the given criteria + list: list of images that match the given criteria """ query = db.query(AwsImage).order_by(AwsImage.creationDate.desc()) @@ -400,4 +405,39 @@ def find_aws_images( if region: query = query.filter(AwsImage.region == region) - return query.all() + return paginate(query, page, page_size) + + +def paginate( + query: Query, + page: int = 1, + page_size: int = 100, +) -> dict: + """Paginate a query and return the results. + + Args: + query: SQLAlchemy query object + page (int): page number + page_size (int): number of items per page + + Returns: + dict: paginated results + """ + if page < 1: + page = 1 + + if page_size < 1: + page_size = 1 + + total_count = query.count() + total_pages = (total_count + page_size - 1) // page_size + + results = query.limit(page_size).offset((page - 1) * page_size).all() + + return { + "results": results, + "page": page, + "page_size": page_size, + "total_count": total_count, + "total_pages": total_pages, + } diff --git a/cid/main.py b/cid/main.py index 8aa3aa9..9b56936 100644 --- a/cid/main.py +++ b/cid/main.py @@ -39,9 +39,11 @@ def all_aws_images( version: Optional[str] = None, name: Optional[str] = None, region: Optional[str] = None, -) -> list: - result = crud.find_aws_images(db, arch, version, name, region) - return list(jsonable_encoder(result)) + page: int = 1, + page_size: int = 100, +) -> dict: + result = crud.find_aws_images(db, arch, version, name, region, page, page_size) + return dict(jsonable_encoder(result)) @app.get("/aws/latest") diff --git a/tests/test_crud.py b/tests/test_crud.py index 5f5c07b..a5285d7 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -522,28 +522,119 @@ def test_find_aws_images(db): db.commit() result = crud.find_aws_images(db, None, None, None, None) - assert len(result) == 5 + assert len(result["results"]) == 5 result = crud.find_aws_images(db, "arm64", None, None, None) - assert len(result) == 1 - assert result[0].name == "RHEL-9.5.0" - assert result[0].arch == "arm64" - assert result[0].region == "us-west-2" + assert len(result["results"]) == 1 + assert result["results"][0].name == "RHEL-9.5.0" + assert result["results"][0].arch == "arm64" + assert result["results"][0].region == "us-west-2" result = crud.find_aws_images(db, None, "9.5.0", None, None) - assert len(result) == 2 - assert result[0].name == "RHEL-9.5.0" - assert result[0].arch == "x86_64" - assert result[0].region == "us-west-1" - assert result[1].name == "RHEL-9.5.0" - assert result[1].arch == "arm64" - assert result[1].region == "us-west-2" + assert len(result["results"]) == 2 + assert result["results"][0].name == "RHEL-9.5.0" + assert result["results"][0].arch == "x86_64" + assert result["results"][0].region == "us-west-1" + assert result["results"][1].name == "RHEL-9.5.0" + assert result["results"][1].arch == "arm64" + assert result["results"][1].region == "us-west-2" result = crud.find_aws_images(db, None, None, "10.0.0", None) - assert len(result) == 1 - assert result[0].name == "RHEL-10.0.0" - assert result[0].arch == "x86_64" - assert result[0].region == "us-west-2" + assert len(result["results"]) == 1 + assert result["results"][0].name == "RHEL-10.0.0" + assert result["results"][0].arch == "x86_64" + assert result["results"][0].region == "us-west-2" result = crud.find_aws_images(db, None, None, None, "us-west-1") - assert len(result) == 3 + assert len(result["results"]) == 3 + + +def test_find_aws_images_paginated(db): + images = [ + AwsImage( + id="ami-a", + name="RHEL-8.2.0", + version="8.2.0", + arch="x86_64", + region="us-west-1", + ), + AwsImage( + id="ami-b", + name="RHEL-7.9.0", + version="7.9.0", + arch="x86_64", + region="us-west-1", + ), + AwsImage( + id="ami-c", + name="RHEL-9.5.0", + version="9.5.0", + arch="x86_64", + region="us-west-1", + ), + AwsImage( + id="ami-d", + name="RHEL-10.0.0", + version="10.0.0", + arch="x86_64", + region="us-west-2", + ), + AwsImage( + id="ami-e", + name="RHEL-9.5.0", + version="9.5.0", + arch="arm64", + region="us-west-2", + ), + ] + db.add_all(images) + db.commit() + + result = crud.find_aws_images(db, None, None, None, None) + assert len(result["results"]) == 5 + assert result["page"] == 1 + assert result["page_size"] == 100 + assert result["total_count"] == 5 + assert result["total_pages"] == 1 + + result = crud.find_aws_images(db, None, None, None, None, 1, 1) + assert len(result["results"]) == 1 + assert result["page"] == 1 + assert result["page_size"] == 1 + assert result["total_count"] == 5 + assert result["total_pages"] == 5 + + result = crud.find_aws_images(db, None, None, None, None, 2, 1) + assert len(result["results"]) == 1 + assert result["page"] == 2 + assert result["page_size"] == 1 + assert result["total_count"] == 5 + assert result["total_pages"] == 5 + + result = crud.find_aws_images(db, None, None, None, None, 6, 1) + assert len(result["results"]) == 0 + assert result["page"] == 6 + assert result["page_size"] == 1 + assert result["total_count"] == 5 + assert result["total_pages"] == 5 + + result = crud.find_aws_images(db, None, None, None, None, 1, 1000) + assert len(result["results"]) == 5 + assert result["page"] == 1 + assert result["page_size"] == 1000 + assert result["total_count"] == 5 + assert result["total_pages"] == 1 + + result = crud.find_aws_images(db, None, None, None, None, -1, 10) + assert len(result["results"]) == 5 + assert result["page"] == 1 + assert result["page_size"] == 10 + assert result["total_count"] == 5 + assert result["total_pages"] == 1 + + result = crud.find_aws_images(db, None, None, None, None, 1, -10) + assert len(result["results"]) == 1 + assert result["page"] == 1 + assert result["page_size"] == 1 + assert result["total_count"] == 5 + assert result["total_pages"] == 5 diff --git a/tests/test_main.py b/tests/test_main.py index 781981f..2bd9eed 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -123,32 +123,47 @@ def test_google_versions(mock_versions): def test_all_aws_images(): response = client.get("/aws") assert response.status_code == 200 - assert len(response.json()) == 500 + assert len(response.json()["results"]) == 100 + assert response.json()["page"] == 1 + assert response.json()["page_size"] == 100 + assert response.json()["total_count"] == 500 + assert response.json()["total_pages"] == 5 + + +def test_all_aws_images_paginated(): + response = client.get("/aws?page=2&page_size=1") + assert response.status_code == 200 + assert len(response.json()["results"]) == 1 + assert response.json()["page"] == 2 + assert response.json()["page_size"] == 1 + assert response.json()["total_count"] == 500 + assert response.json()["total_pages"] == 500 def test_all_aws_images_with_query(): response = client.get("/aws?version=9.4.0") assert response.status_code == 200 - assert response.json()[0]["version"] == "9.4.0" + assert response.json()["results"][0]["version"] == "9.4.0" def test_all_aws_images_with_query_region(): response = client.get("/aws?region=af-south-1") assert response.status_code == 200 - assert response.json()[0]["region"] == "af-south-1" + assert response.json()["results"][0]["region"] == "af-south-1" def test_all_aws_images_with_query_arch(): response = client.get("/aws?arch=x86_64") assert response.status_code == 200 - assert response.json()[0]["arch"] == "x86_64" + assert response.json()["results"][0]["arch"] == "x86_64" def test_all_aws_images_with_query_name(): response = client.get("/aws?name=RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3") assert response.status_code == 200 assert ( - response.json()[0]["name"] == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3" + response.json()["results"][0]["name"] + == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3" ) @@ -161,13 +176,14 @@ def test_all_aws_images_with_query_combination(): + "&arch=x86_64" ) assert response.status_code == 200 - assert len(response.json()) == 1 + assert len(response.json()["results"]) == 1 assert ( - response.json()[0]["name"] == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3" + response.json()["results"][0]["name"] + == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3" ) - assert response.json()[0]["region"] == "af-south-1" - assert response.json()[0]["version"] == "9.4.0" - assert response.json()[0]["arch"] == "x86_64" + assert response.json()["results"][0]["region"] == "af-south-1" + assert response.json()["results"][0]["version"] == "9.4.0" + assert response.json()["results"][0]["arch"] == "x86_64" def test_all_aws_images_with_query_combination_no_match(): @@ -179,7 +195,7 @@ def test_all_aws_images_with_query_combination_no_match(): + "&arch=arm64" ) assert response.status_code == 200 - assert len(response.json()) == 0 + assert len(response.json()["results"]) == 0 def test_single_aws_image():