Skip to content

Commit

Permalink
Add pagination to /aws
Browse files Browse the repository at this point in the history
Add paginate method to crud.py
Add optional pagination params to /aws
Add necessary tests
  • Loading branch information
F-X64 committed Jul 2, 2024
1 parent 7011f8c commit f081789
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 40 deletions.
58 changes: 49 additions & 9 deletions cid/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from packaging.version import Version
from sqlalchemy import desc
from sqlalchemy.orm import Session
from sqlalchemy.orm.query import Query

from cid.config import CLOUD_PROVIDERS
from cid.database import engine
Expand Down Expand Up @@ -376,18 +377,22 @@ def find_aws_images(
version: Optional[str] = None,
name: Optional[str] = None,
region: Optional[str] = None,
) -> list:
"""Return all AWS images that match the given criteria.
page: int = 1,
page_size: int = 100,
) -> dict:
"""Return paginated AWS images that match the given criteria.
Args:
db (Session): database session
arch (Optional[str]): architecture to search
version (Optional[str]): RHEL version to search
name (Optional[str]): image name to search
region (Optional[str]): AWS region to search
db (Session): database session
arch (Optional[str]): architecture to search
version (Optional[str]): RHEL version to search
name (Optional[str]): image name to search
region (Optional[str]): AWS region to search
page (int): page number
page_size (int): number of images per page
Returns:
list: list of images that match the given criteria
list: list of images that match the given criteria
"""
query = db.query(AwsImage).order_by(AwsImage.creationDate.desc())

Expand All @@ -400,4 +405,39 @@ def find_aws_images(
if region:
query = query.filter(AwsImage.region == region)

return query.all()
return paginate(query, page, page_size)


def paginate(
query: Query,
page: int = 1,
page_size: int = 100,
) -> dict:
"""Paginate a query and return the results.
Args:
query: SQLAlchemy query object
page (int): page number
page_size (int): number of items per page
Returns:
dict: paginated results
"""
if page < 1:
page = 1

if page_size < 1:
page_size = 1

total_count = query.count()
total_pages = (total_count + page_size - 1) // page_size

results = query.limit(page_size).offset((page - 1) * page_size).all()

return {
"results": results,
"page": page,
"page_size": page_size,
"total_count": total_count,
"total_pages": total_pages,
}
8 changes: 5 additions & 3 deletions cid/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def all_aws_images(
version: Optional[str] = None,
name: Optional[str] = None,
region: Optional[str] = None,
) -> list:
result = crud.find_aws_images(db, arch, version, name, region)
return list(jsonable_encoder(result))
page: int = 1,
page_size: int = 100,
) -> dict:
result = crud.find_aws_images(db, arch, version, name, region, page, page_size)
return dict(jsonable_encoder(result))


@app.get("/aws/latest")
Expand Down
125 changes: 108 additions & 17 deletions tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,28 +522,119 @@ def test_find_aws_images(db):
db.commit()

result = crud.find_aws_images(db, None, None, None, None)
assert len(result) == 5
assert len(result["results"]) == 5

result = crud.find_aws_images(db, "arm64", None, None, None)
assert len(result) == 1
assert result[0].name == "RHEL-9.5.0"
assert result[0].arch == "arm64"
assert result[0].region == "us-west-2"
assert len(result["results"]) == 1
assert result["results"][0].name == "RHEL-9.5.0"
assert result["results"][0].arch == "arm64"
assert result["results"][0].region == "us-west-2"

result = crud.find_aws_images(db, None, "9.5.0", None, None)
assert len(result) == 2
assert result[0].name == "RHEL-9.5.0"
assert result[0].arch == "x86_64"
assert result[0].region == "us-west-1"
assert result[1].name == "RHEL-9.5.0"
assert result[1].arch == "arm64"
assert result[1].region == "us-west-2"
assert len(result["results"]) == 2
assert result["results"][0].name == "RHEL-9.5.0"
assert result["results"][0].arch == "x86_64"
assert result["results"][0].region == "us-west-1"
assert result["results"][1].name == "RHEL-9.5.0"
assert result["results"][1].arch == "arm64"
assert result["results"][1].region == "us-west-2"

result = crud.find_aws_images(db, None, None, "10.0.0", None)
assert len(result) == 1
assert result[0].name == "RHEL-10.0.0"
assert result[0].arch == "x86_64"
assert result[0].region == "us-west-2"
assert len(result["results"]) == 1
assert result["results"][0].name == "RHEL-10.0.0"
assert result["results"][0].arch == "x86_64"
assert result["results"][0].region == "us-west-2"

result = crud.find_aws_images(db, None, None, None, "us-west-1")
assert len(result) == 3
assert len(result["results"]) == 3


def test_find_aws_images_paginated(db):
images = [
AwsImage(
id="ami-a",
name="RHEL-8.2.0",
version="8.2.0",
arch="x86_64",
region="us-west-1",
),
AwsImage(
id="ami-b",
name="RHEL-7.9.0",
version="7.9.0",
arch="x86_64",
region="us-west-1",
),
AwsImage(
id="ami-c",
name="RHEL-9.5.0",
version="9.5.0",
arch="x86_64",
region="us-west-1",
),
AwsImage(
id="ami-d",
name="RHEL-10.0.0",
version="10.0.0",
arch="x86_64",
region="us-west-2",
),
AwsImage(
id="ami-e",
name="RHEL-9.5.0",
version="9.5.0",
arch="arm64",
region="us-west-2",
),
]
db.add_all(images)
db.commit()

result = crud.find_aws_images(db, None, None, None, None)
assert len(result["results"]) == 5
assert result["page"] == 1
assert result["page_size"] == 100
assert result["total_count"] == 5
assert result["total_pages"] == 1

result = crud.find_aws_images(db, None, None, None, None, 1, 1)
assert len(result["results"]) == 1
assert result["page"] == 1
assert result["page_size"] == 1
assert result["total_count"] == 5
assert result["total_pages"] == 5

result = crud.find_aws_images(db, None, None, None, None, 2, 1)
assert len(result["results"]) == 1
assert result["page"] == 2
assert result["page_size"] == 1
assert result["total_count"] == 5
assert result["total_pages"] == 5

result = crud.find_aws_images(db, None, None, None, None, 6, 1)
assert len(result["results"]) == 0
assert result["page"] == 6
assert result["page_size"] == 1
assert result["total_count"] == 5
assert result["total_pages"] == 5

result = crud.find_aws_images(db, None, None, None, None, 1, 1000)
assert len(result["results"]) == 5
assert result["page"] == 1
assert result["page_size"] == 1000
assert result["total_count"] == 5
assert result["total_pages"] == 1

result = crud.find_aws_images(db, None, None, None, None, -1, 10)
assert len(result["results"]) == 5
assert result["page"] == 1
assert result["page_size"] == 10
assert result["total_count"] == 5
assert result["total_pages"] == 1

result = crud.find_aws_images(db, None, None, None, None, 1, -10)
assert len(result["results"]) == 1
assert result["page"] == 1
assert result["page_size"] == 1
assert result["total_count"] == 5
assert result["total_pages"] == 5
38 changes: 27 additions & 11 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,32 +123,47 @@ def test_google_versions(mock_versions):
def test_all_aws_images():
response = client.get("/aws")
assert response.status_code == 200
assert len(response.json()) == 500
assert len(response.json()["results"]) == 100
assert response.json()["page"] == 1
assert response.json()["page_size"] == 100
assert response.json()["total_count"] == 500
assert response.json()["total_pages"] == 5


def test_all_aws_images_paginated():
response = client.get("/aws?page=2&page_size=1")
assert response.status_code == 200
assert len(response.json()["results"]) == 1
assert response.json()["page"] == 2
assert response.json()["page_size"] == 1
assert response.json()["total_count"] == 500
assert response.json()["total_pages"] == 500


def test_all_aws_images_with_query():
response = client.get("/aws?version=9.4.0")
assert response.status_code == 200
assert response.json()[0]["version"] == "9.4.0"
assert response.json()["results"][0]["version"] == "9.4.0"


def test_all_aws_images_with_query_region():
response = client.get("/aws?region=af-south-1")
assert response.status_code == 200
assert response.json()[0]["region"] == "af-south-1"
assert response.json()["results"][0]["region"] == "af-south-1"


def test_all_aws_images_with_query_arch():
response = client.get("/aws?arch=x86_64")
assert response.status_code == 200
assert response.json()[0]["arch"] == "x86_64"
assert response.json()["results"][0]["arch"] == "x86_64"


def test_all_aws_images_with_query_name():
response = client.get("/aws?name=RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3")
assert response.status_code == 200
assert (
response.json()[0]["name"] == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3"
response.json()["results"][0]["name"]
== "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3"
)


Expand All @@ -161,13 +176,14 @@ def test_all_aws_images_with_query_combination():
+ "&arch=x86_64"
)
assert response.status_code == 200
assert len(response.json()) == 1
assert len(response.json()["results"]) == 1
assert (
response.json()[0]["name"] == "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3"
response.json()["results"][0]["name"]
== "RHEL_HA-9.4.0_HVM-20240605-x86_64-82-Hourly2-GP3"
)
assert response.json()[0]["region"] == "af-south-1"
assert response.json()[0]["version"] == "9.4.0"
assert response.json()[0]["arch"] == "x86_64"
assert response.json()["results"][0]["region"] == "af-south-1"
assert response.json()["results"][0]["version"] == "9.4.0"
assert response.json()["results"][0]["arch"] == "x86_64"


def test_all_aws_images_with_query_combination_no_match():
Expand All @@ -179,7 +195,7 @@ def test_all_aws_images_with_query_combination_no_match():
+ "&arch=arm64"
)
assert response.status_code == 200
assert len(response.json()) == 0
assert len(response.json()["results"]) == 0


def test_single_aws_image():
Expand Down

0 comments on commit f081789

Please sign in to comment.