Skip to content

Commit

Permalink
refactor: rename variable, improve duplicate id error message
Browse files Browse the repository at this point in the history
  • Loading branch information
ewanharris committed Dec 18, 2024
1 parent 0a02932 commit b1feeda
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
12 changes: 7 additions & 5 deletions openfga_sdk/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,18 +707,20 @@ async def batch_check(self, body: ClientBatchCheckRequest, options=None):
elif isinstance(options["max_batch_size"], int):
max_batch_size = options["max_batch_size"]

check_to_id: dict[str, ClientBatchCheckItem] = {}
id_to_check: dict[str, ClientBatchCheckItem] = {}

def track_and_transform(checks):
transformed = []
for check in checks:
if check.correlation_id is None:
check.correlation_id = str(uuid.uuid4())

if check.correlation_id in check_to_id:
raise FgaValidationException("Duplicate correlation_id provided")
if check.correlation_id in id_to_check:
raise FgaValidationException(
f"Duplicate correlation_id ({check.correlation_id}) provided"
)

check_to_id[check.correlation_id] = check
id_to_check[check.correlation_id] = check

transformed.append(construct_batch_item(check))
return transformed
Expand All @@ -734,7 +736,7 @@ def track_and_transform(checks):
sem = asyncio.Semaphore(max_parallel_requests)

def map_response(id, result):
check = check_to_id[id]
check = id_to_check[id]
return ClientBatchCheckSingleResponse(
allowed=result.allowed,
request=check,
Expand Down
12 changes: 7 additions & 5 deletions openfga_sdk/sync/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,18 +693,20 @@ def batch_check(self, body: ClientBatchCheckRequest, options=None):
elif isinstance(options["max_batch_size"], int):
max_batch_size = options["max_batch_size"]

check_to_id: dict[str, ClientBatchCheckItem] = {}
id_to_check: dict[str, ClientBatchCheckItem] = {}

def track_and_transform(checks):
transformed = []
for check in checks:
if check.correlation_id is None:
check.correlation_id = str(uuid.uuid4())

if check.correlation_id in check_to_id:
raise FgaValidationException("Duplicate correlation_id provided")
if check.correlation_id in id_to_check:
raise FgaValidationException(
f"Duplicate correlation_id ({check.correlation_id}) provided"
)

check_to_id[check.correlation_id] = check
id_to_check[check.correlation_id] = check

transformed.append(construct_batch_item(check))
return transformed
Expand All @@ -717,7 +719,7 @@ def track_and_transform(checks):
]

def map_response(id, result):
check = check_to_id[id]
check = id_to_check[id]
return ClientBatchCheckSingleResponse(
allowed=result.allowed,
request=check,
Expand Down
5 changes: 4 additions & 1 deletion test/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2415,11 +2415,14 @@ async def test_batch_check_errors_dupe_cor_id(self):
configuration = self.configuration
configuration.store_id = store_id
async with OpenFgaClient(configuration) as api_client:
with self.assertRaises(FgaValidationException):
with self.assertRaises(FgaValidationException) as error:
await api_client.batch_check(
body=body,
options={"authorization_model_id": "01GXSA8YR785C4FYS3C0RTG7B1"},
)
self.assertEqual(
"Duplicate correlation_id (1) provided", str(error.exception)
)
await api_client.close()

@patch.object(rest.RESTClientObject, "request")
Expand Down
5 changes: 4 additions & 1 deletion test/sync/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,11 +2417,14 @@ def test_batch_check_errors_dupe_cor_id(self):
configuration = self.configuration
configuration.store_id = store_id
with OpenFgaClient(configuration) as api_client:
with self.assertRaises(FgaValidationException):
with self.assertRaises(FgaValidationException) as error:
api_client.batch_check(
body=body,
options={"authorization_model_id": "01GXSA8YR785C4FYS3C0RTG7B1"},
)
self.assertEqual(
"Duplicate correlation_id (1) provided", str(error.exception)
)
api_client.close()

@patch.object(rest.RESTClientObject, "request")
Expand Down

0 comments on commit b1feeda

Please sign in to comment.