refactor(batch): add from __future__ import annotations (#4993)
* refactor(batch): add from __future__ import annotations

and update code according to ruff rules TCH, UP006, UP007, UP037 and
FA100.

* Fixing types in Python 3.8 and 3.9

---------

Co-authored-by: Leandro Damascena <[email protected]>
ericbn and leandrodamascena authored Aug 17, 2024
1 parent 689072f commit 918844d
Showing 5 changed files with 63 additions and 48 deletions.
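
For context, the ruff rules named in the commit message all push toward one pattern: FA100 asks for `from __future__ import annotations`, which makes every annotation lazily evaluated; UP006 and UP007 rewrite `List[X]`/`Dict[X, Y]` and `Optional[X]`/`Union[X, Y]` as `list[X]`/`dict[X, Y]` and `X | None`/`X | Y`; UP037 drops quotes from forward references; and TCH moves imports needed only for annotations behind a `TYPE_CHECKING` guard. A minimal sketch of the resulting style, using a hypothetical `myapp.models` module rather than anything from this diff:

from __future__ import annotations  # FA100: annotations become lazily-evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # TCH: only type checkers execute this import, so it adds no runtime cost
    # and cannot create an import cycle.
    from myapp.models import Record  # hypothetical


# UP006: List[str] -> list[str]; UP007: Optional[Record] -> Record | None;
# UP037: "Record" -> Record (safe, since annotations are no longer evaluated at runtime)
def load(ids: list[str], cache: dict[str, Record] | None = None) -> list[Record]:
    return [] if cache is None else [cache[i] for i in ids if i in cache]
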
57 changes: 32 additions & 25 deletions aws_lambda_powertools/utilities/batch/base.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

"""
Batch processing utilities
"""
+from __future__ import annotations

import asyncio
import copy
import inspect
@@ -11,22 +12,28 @@
import sys
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple, Union, overload
+from typing import TYPE_CHECKING, Any, Callable, Tuple, Union, overload

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.utilities.batch.exceptions import (
BatchProcessingError,
ExceptionInfo,
)
-from aws_lambda_powertools.utilities.batch.types import BatchTypeModels, PartialItemFailureResponse, PartialItemFailures
+from aws_lambda_powertools.utilities.batch.types import BatchTypeModels
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
DynamoDBRecord,
)
from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import (
KinesisStreamRecord,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
-from aws_lambda_powertools.utilities.typing import LambdaContext

+if TYPE_CHECKING:
+    from aws_lambda_powertools.utilities.batch.types import (
+        PartialItemFailureResponse,
+        PartialItemFailures,
+    )
+    from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

@@ -41,7 +48,7 @@ class EventType(Enum):
# and depending on what EventType it's passed it'll correctly map to the right record
# When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis
EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord]
-BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"]
+BatchEventTypes = Union[EventSourceDataClassTypes, BatchTypeModels]
SuccessResponse = Tuple[str, Any, BatchEventTypes]
FailureResponse = Tuple[str, str, BatchEventTypes]

@@ -54,9 +61,9 @@ class BasePartialProcessor(ABC):
lambda_context: LambdaContext

def __init__(self):
-self.success_messages: List[BatchEventTypes] = []
-self.fail_messages: List[BatchEventTypes] = []
-self.exceptions: List[ExceptionInfo] = []
+self.success_messages: list[BatchEventTypes] = []
+self.fail_messages: list[BatchEventTypes] = []
+self.exceptions: list[ExceptionInfo] = []

@abstractmethod
def _prepare(self):
@@ -79,7 +86,7 @@ def _process_record(self, record: dict):
"""
raise NotImplementedError()

-def process(self) -> List[Tuple]:
+def process(self) -> list[tuple]:
"""
Call instance's handler for each record.
"""
@@ -92,7 +99,7 @@ async def _async_process_record(self, record: dict):
"""
raise NotImplementedError()

-def async_process(self) -> List[Tuple]:
+def async_process(self) -> list[tuple]:
"""
Async call instance's handler for each record.
@@ -135,13 +142,13 @@ def __enter__(self):
def __exit__(self, exception_type, exception_value, traceback):
self._clean()

-def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None):
+def __call__(self, records: list[dict], handler: Callable, lambda_context: LambdaContext | None = None):
"""
Set instance attributes before execution
Parameters
----------
-records: List[dict]
+records: list[dict]
List with objects to be processed.
handler: Callable
Callable to process "records" entries.
@@ -222,14 +229,14 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
class BasePartialBatchProcessor(BasePartialProcessor): # noqa
DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []}

-def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None):
+def __init__(self, event_type: EventType, model: BatchTypeModels | None = None):
"""Process batch and partially report failed items
Parameters
----------
event_type: EventType
Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event
-model: Optional["BatchTypeModels"]
+model: BatchTypeModels | None
Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord
Exceptions
@@ -294,7 +301,7 @@ def _has_messages_to_report(self) -> bool:
def _entire_batch_failed(self) -> bool:
return len(self.exceptions) == len(self.records)

-def _get_messages_to_report(self) -> List[PartialItemFailures]:
+def _get_messages_to_report(self) -> list[PartialItemFailures]:
"""
Format messages to use in batch deletion
"""
@@ -343,13 +350,13 @@ def _to_batch_type(
self,
record: dict,
event_type: EventType,
model: "BatchTypeModels",
) -> "BatchTypeModels": ... # pragma: no cover
model: BatchTypeModels,
) -> BatchTypeModels: ... # pragma: no cover

@overload
def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: ... # pragma: no cover

-def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):
+def _to_batch_type(self, record: dict, event_type: EventType, model: BatchTypeModels | None = None):
if model is not None:
# If a model is provided, we assume Pydantic is installed and we need to disable v2 warnings
return model.model_validate(record)
@@ -363,7 +370,7 @@ def _register_model_validation_error_record(self, record: dict):
# and downstream we can correctly collect the correct message id identifier and make the failed record available
# see https://github.com/aws-powertools/powertools-lambda-python/issues/2091
logger.debug("Record cannot be converted to customer's model; converting without model")
failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type)
failed_record: EventSourceDataClassTypes = self._to_batch_type(record=record, event_type=self.event_type)
return self.failure_handler(record=failed_record, exception=sys.exc_info())


@@ -453,7 +460,7 @@ def record_handler(record: DynamoDBRecord):
logger.info(record.dynamodb.new_image)
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
# alternatively:
-# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
+# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# payload = change.get("Message") -> "<payload>"
...
@@ -481,7 +488,7 @@ def lambda_handler(event, context: LambdaContext):
async def _async_process_record(self, record: dict):
raise NotImplementedError()

-def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
+def _process_record(self, record: dict) -> SuccessResponse | FailureResponse:
"""
Process a record with instance's handler
@@ -490,7 +497,7 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
record: dict
A batch record to be processed.
"""
-data: Optional["BatchTypeModels"] = None
+data: BatchTypeModels | None = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
@@ -602,7 +609,7 @@ async def record_handler(record: DynamoDBRecord):
logger.info(record.dynamodb.new_image)
payload: dict = json.loads(record.dynamodb.new_image.get("item"))
# alternatively:
-# changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
+# changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001
# payload = change.get("Message") -> "<payload>"
...
@@ -630,7 +637,7 @@ def lambda_handler(event, context: LambdaContext):
def _process_record(self, record: dict):
raise NotImplementedError()

-async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
+async def _async_process_record(self, record: dict) -> SuccessResponse | FailureResponse:
"""
Process a record with instance's handler
@@ -639,7 +646,7 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]:
record: dict
A batch record to be processed.
"""
-data: Optional["BatchTypeModels"] = None
+data: BatchTypeModels | None = None
try:
data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model)
if self._handler_accepts_lambda_context:
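One detail worth noting in base.py: `Tuple` and `Union` are still imported from typing even after this change, because aliases such as `SuccessResponse = Tuple[str, Any, BatchEventTypes]` are ordinary assignments executed at import time. `from __future__ import annotations` only defers annotations; it does nothing for runtime expressions, and on Python 3.8 subscripting the builtin `tuple` raises TypeError (as does `X | None` before 3.10), which is presumably what the "Fixing types in Python 3.8 and 3.9" follow-up in the commit message addressed. A small sketch of the distinction, with illustrative names not taken from the diff:

from __future__ import annotations

from typing import Optional, Tuple

# Annotation position: deferred, so builtin generics are fine even on Python 3.8.
def head(pairs: list[tuple[str, int]]) -> tuple[str, int] | None:
    return pairs[0] if pairs else None

# Runtime position: the right-hand side is evaluated at import time, so it must
# keep typing.Tuple on 3.8/3.9 -- tuple[str, Optional[str]] raises TypeError on
# 3.8, and str | None raises TypeError before 3.10.
Pair = Tuple[str, Optional[str]]  # illustrative alias
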
28 changes: 15 additions & 13 deletions aws_lambda_powertools/utilities/batch/decorators.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
-from typing import Any, Awaitable, Callable, Dict, List
+from typing import TYPE_CHECKING, Any, Awaitable, Callable

from typing_extensions import deprecated

@@ -12,10 +12,12 @@
BatchProcessor,
EventType,
)
-from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
-from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning

+if TYPE_CHECKING:
+    from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse
+    from aws_lambda_powertools.utilities.typing import LambdaContext


@lambda_handler_decorator
@deprecated(
@@ -24,7 +26,7 @@
)
def async_batch_processor(
handler: Callable,
-event: Dict,
+event: dict,
context: LambdaContext,
record_handler: Callable[..., Awaitable[Any]],
processor: AsyncBatchProcessor,
@@ -40,7 +42,7 @@ def async_batch_processor(
----------
handler: Callable
Lambda's handler
-event: Dict
+event: dict
Lambda's Event
context: LambdaContext
Lambda's Context
@@ -92,7 +94,7 @@ def async_batch_processor(
)
def batch_processor(
handler: Callable,
-event: Dict,
+event: dict,
context: LambdaContext,
record_handler: Callable,
processor: BatchProcessor,
@@ -108,7 +110,7 @@ def batch_processor(
----------
handler: Callable
Lambda's handler
-event: Dict
+event: dict
Lambda's Event
context: LambdaContext
Lambda's Context
@@ -154,7 +156,7 @@ def batch_processor(


def process_partial_response(
-event: Dict,
+event: dict,
record_handler: Callable,
processor: BasePartialBatchProcessor,
context: LambdaContext | None = None,
Expand All @@ -164,7 +166,7 @@ def process_partial_response(
Parameters
----------
-event: Dict
+event: dict
Lambda's original event
record_handler: Callable
Callable to process each record from the batch
@@ -202,7 +204,7 @@ def handler(event, context):
* Async batch processors. Use `async_process_partial_response` instead.
"""
try:
-records: List[Dict] = event.get("Records", [])
+records: list[dict] = event.get("Records", [])
except AttributeError:
event_types = ", ".join(list(EventType.__members__))
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line
@@ -218,7 +220,7 @@ def handler(event, context):


def async_process_partial_response(
-event: Dict,
+event: dict,
record_handler: Callable,
processor: AsyncBatchProcessor,
context: LambdaContext | None = None,
@@ -228,7 +230,7 @@ def handler(event, context):
Parameters
----------
event: Dict
event: dict
Lambda's original event
record_handler: Callable
Callable to process each record from the batch
@@ -266,7 +268,7 @@ def handler(event, context):
* Sync batch processors. Use `process_partial_response` instead.
"""
try:
-records: List[Dict] = event.get("Records", [])
+records: list[dict] = event.get("Records", [])
except AttributeError:
event_types = ", ".join(list(EventType.__members__))
docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line
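The decorator signatures above are easier to read next to a call site. A minimal usage sketch following the documented Powertools batch pattern (the record handler body is a placeholder):

from aws_lambda_powertools.utilities.batch import (
    BatchProcessor,
    EventType,
    process_partial_response,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

processor = BatchProcessor(event_type=EventType.SQS)


def record_handler(record: SQSRecord):
    print(record.body)  # placeholder: raise here to mark the record as failed


def lambda_handler(event: dict, context: LambdaContext):
    # Returns {"batchItemFailures": [...]} so SQS retries only the failed records.
    return process_partial_response(
        event=event,
        record_handler=record_handler,
        processor=processor,
        context=context,
    )
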
6 changes: 3 additions & 3 deletions aws_lambda_powertools/utilities/batch/exceptions.py
@@ -6,13 +6,13 @@

import traceback
from types import TracebackType
-from typing import List, Optional, Tuple, Type
+from typing import Optional, Tuple, Type

ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]


class BaseBatchProcessingError(Exception):
-def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
+def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
super().__init__(msg)
self.msg = msg
self.child_exceptions = child_exceptions or []
@@ -30,7 +30,7 @@ def format_exceptions(self, parent_exception_str):
class BatchProcessingError(BaseBatchProcessingError):
"""When all batch records failed to be processed"""

-def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
+def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None):
super().__init__(msg, child_exceptions)

def __str__(self):
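The same runtime-versus-annotation split explains why this file drops only `List`: `ExceptionInfo` is a module-level alias evaluated at import time, so its `Optional`, `Tuple`, and `Type` imports have to stay until Python 3.8/3.9 support is dropped.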
aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
-from typing import Optional, Set
+from typing import TYPE_CHECKING

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse
from aws_lambda_powertools.utilities.batch.exceptions import (
SQSFifoCircuitBreakerError,
SQSFifoMessageGroupCircuitBreakerError,
)
-from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

+if TYPE_CHECKING:
+    from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

logger = logging.getLogger(__name__)

@@ -62,13 +66,13 @@ def lambda_handler(event, context: LambdaContext):
None,
)

-def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
+def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
"""
Initialize the SqsFifoProcessor.
Parameters
----------
-model: Optional["BatchSqsTypeModel"]
+model: BatchSqsTypeModel | None
An optional model for batch processing.
skip_group_on_error: bool
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
@@ -77,7 +81,7 @@ def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
"""
self._skip_group_on_error: bool = skip_group_on_error
self._current_group_id = None
-self._failed_group_ids: Set[str] = set()
+self._failed_group_ids: set[str] = set()
super().__init__(EventType.SQS, model)

def _process_record(self, record):
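For the FIFO processor above, `skip_group_on_error` decides whether one failure stops the remainder of the batch or only the affected message group. A usage sketch, assuming the `SqsFifoPartialProcessor` class this file exports in the upstream library:

from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor, process_partial_response

# With skip_group_on_error=True, a failure short-circuits only records that share
# the failed record's MessageGroupId; other groups keep processing in order.
processor = SqsFifoPartialProcessor(skip_group_on_error=True)


def record_handler(record):
    ...  # placeholder: process a single FIFO message


def lambda_handler(event, context):
    return process_partial_response(
        event=event,
        record_handler=record_handler,
        processor=processor,
        context=context,
    )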