Skip to content

Commit

Permalink
fix(appsync): make contextual data accessible for async functions (#5317
Browse files Browse the repository at this point in the history
)

* Making contextual data accessible for async functions

* V4 comment

* Reverting test
  • Loading branch information
leandrodamascena authored Oct 24, 2024
1 parent c11d25c commit ccfbc94
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 3 deletions.
8 changes: 7 additions & 1 deletion aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ def lambda_handler(event, context):
Router.current_event = data_model(event)
response = self._call_single_resolver(event=event, data_model=data_model)

self.clear_context()
# We don't clear the context for coroutines because we don't have control over the event loop.
# If we clean the context immediately, it might not be available when the coroutine is actually executed.
# For single async operations, the context should be cleaned up manually after the coroutine completes.
# See: https://github.com/aws-powertools/powertools-lambda-python/issues/5290
# REVIEW: Review this support in Powertools V4
if not asyncio.iscoroutine(response):
self.clear_context()

return response

Expand Down
4 changes: 2 additions & 2 deletions docs/core/event_handler/appsync.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ Let's assume you have `split_operation.py` as your Lambda function entrypoint an

You can use `append_context` when you want to share data between your App and Router instances. Any data you share will be available via the `context` dictionary available in your App or Router context.

???+ info
For safety, we always clear any data available in the `context` dictionary after each invocation.
???+ warning
For safety, we clear the context after each invocation, except for async single resolvers. For these, use `app.context.clear()` before returning the function.

???+ tip
This can also be useful for middlewares injecting contextual information before a request is processed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,41 @@ def get_user(event: List) -> List:

# THEN the resolver must be able to return a field in the batch_current_event
assert ret[0] == mock_event[0]["identity"]["sub"]


def test_context_is_accessible_in_sync_batch_resolver():
mock_event = load_event("appSyncBatchEvent.json")

# GIVEN An instance of AppSyncResolver and a resolver function registered with the app
app = AppSyncResolver()

@app.batch_resolver(field_name="createSomething")
def get_user(event: List) -> List:
return [app.context.get("project_name")]

# WHEN we resolve the event
app.append_context(project_name="powertools")
ret = app.resolve(mock_event, {})

# THEN the resolver must be able to return a field in the batch_current_event
assert app.context == {}
assert ret[0] == "powertools"


def test_context_is_accessible_in_async_batch_resolver():
mock_event = load_event("appSyncBatchEvent.json")

# GIVEN An instance of AppSyncResolver and a resolver function registered with the app
app = AppSyncResolver()

@app.async_batch_resolver(field_name="createSomething")
async def get_user(event: List) -> List:
return [app.context.get("project_name")]

# WHEN we resolve the event
app.append_context(project_name="powertools")
ret = app.resolve(mock_event, {})

# THEN the resolver must be able to return a field in the batch_current_event
assert app.context == {}
assert ret[0] == "powertools"
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,43 @@ def get_user(id: str) -> dict: # noqa AA03 VNE003

# THEN the resolver must be able to return a field in the current_event
assert ret == mock_event["identity"]["sub"]


def test_route_context_is_not_cleared_after_resolve_async():
# GIVEN
app = AppSyncResolver()
event = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}}

@app.resolver(field_name="listLocations")
async def get_locations(name: str):
return f"get_locations#{name}"

# WHEN event resolution kicks in
app.append_context(is_admin=True)
app.resolve(event, {})

# THEN context should be empty
assert app.context == {"is_admin": True}


def test_route_context_is_manually_cleared_after_resolve_async():
# GIVEN
# GIVEN
app = AppSyncResolver()

mock_event = {"typeName": "Customer", "fieldName": "field", "arguments": {}}

@app.resolver(field_name="field")
async def get_async():
app.context.clear()
await asyncio.sleep(0.0001)
return "value"

# WHEN
mock_context = LambdaContext()
app.append_context(is_admin=True)
result = app.resolve(mock_event, mock_context)

# THEN
assert asyncio.run(result) == "value"
assert app.context == {}

0 comments on commit ccfbc94

Please sign in to comment.