From 918844d75818217fba75346277f19662e00989f9 Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Sat, 17 Aug 2024 16:46:09 -0500 Subject: [PATCH] refactor(batch): add from __future__ import annotations (#4993) * refactor(batch): add from __future__ import annotations and update code according to ruff rules TCH, UP006, UP007, UP037 and FA100. * Fixing types in Python 3.8 and 3.9 --------- Co-authored-by: Leandro Damascena --- aws_lambda_powertools/utilities/batch/base.py | 57 +++++++++++-------- .../utilities/batch/decorators.py | 28 ++++----- .../utilities/batch/exceptions.py | 6 +- .../batch/sqs_fifo_partial_processor.py | 14 +++-- .../utilities/batch/types.py | 6 +- 5 files changed, 63 insertions(+), 48 deletions(-) diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 74f9ddc4796..b4756db8b72 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - """ Batch processing utilities """ +from __future__ import annotations + import asyncio import copy import inspect @@ -11,14 +12,14 @@ import sys from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Tuple, Union, overload from aws_lambda_powertools.shared import constants from aws_lambda_powertools.utilities.batch.exceptions import ( BatchProcessingError, ExceptionInfo, ) -from aws_lambda_powertools.utilities.batch.types import BatchTypeModels, PartialItemFailureResponse, PartialItemFailures +from aws_lambda_powertools.utilities.batch.types import BatchTypeModels from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( DynamoDBRecord, ) @@ -26,7 +27,13 @@ KinesisStreamRecord, ) from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord -from aws_lambda_powertools.utilities.typing import LambdaContext + +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.batch.types import ( + PartialItemFailureResponse, + PartialItemFailures, + ) + from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -41,7 +48,7 @@ class EventType(Enum): # and depending on what EventType it's passed it'll correctly map to the right record # When using Pydantic Models, it'll accept any subclass from SQS, DynamoDB and Kinesis EventSourceDataClassTypes = Union[SQSRecord, KinesisStreamRecord, DynamoDBRecord] -BatchEventTypes = Union[EventSourceDataClassTypes, "BatchTypeModels"] +BatchEventTypes = Union[EventSourceDataClassTypes, BatchTypeModels] SuccessResponse = Tuple[str, Any, BatchEventTypes] FailureResponse = Tuple[str, str, BatchEventTypes] @@ -54,9 +61,9 @@ class BasePartialProcessor(ABC): lambda_context: LambdaContext def __init__(self): - self.success_messages: List[BatchEventTypes] = [] - self.fail_messages: List[BatchEventTypes] = [] - self.exceptions: List[ExceptionInfo] = [] + self.success_messages: list[BatchEventTypes] = [] + self.fail_messages: list[BatchEventTypes] = [] + self.exceptions: list[ExceptionInfo] = [] @abstractmethod def _prepare(self): @@ -79,7 +86,7 @@ def _process_record(self, record: dict): """ raise NotImplementedError() - def process(self) -> List[Tuple]: + def process(self) -> list[tuple]: """ Call instance's handler for each record. """ @@ -92,7 +99,7 @@ async def _async_process_record(self, record: dict): """ raise NotImplementedError() - def async_process(self) -> List[Tuple]: + def async_process(self) -> list[tuple]: """ Async call instance's handler for each record. @@ -135,13 +142,13 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self._clean() - def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None): + def __call__(self, records: list[dict], handler: Callable, lambda_context: LambdaContext | None = None): """ Set instance attributes before execution Parameters ---------- - records: List[dict] + records: list[dict] List with objects to be processed. handler: Callable Callable to process "records" entries. @@ -222,14 +229,14 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: class BasePartialBatchProcessor(BasePartialProcessor): # noqa DEFAULT_RESPONSE: PartialItemFailureResponse = {"batchItemFailures": []} - def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = None): + def __init__(self, event_type: EventType, model: BatchTypeModels | None = None): """Process batch and partially report failed items Parameters ---------- event_type: EventType Whether this is a SQS, DynamoDB Streams, or Kinesis Data Stream event - model: Optional["BatchTypeModels"] + model: BatchTypeModels | None Parser's data model using either SqsRecordModel, DynamoDBStreamRecordModel, KinesisDataStreamRecord Exceptions @@ -294,7 +301,7 @@ def _has_messages_to_report(self) -> bool: def _entire_batch_failed(self) -> bool: return len(self.exceptions) == len(self.records) - def _get_messages_to_report(self) -> List[PartialItemFailures]: + def _get_messages_to_report(self) -> list[PartialItemFailures]: """ Format messages to use in batch deletion """ @@ -343,13 +350,13 @@ def _to_batch_type( self, record: dict, event_type: EventType, - model: "BatchTypeModels", - ) -> "BatchTypeModels": ... # pragma: no cover + model: BatchTypeModels, + ) -> BatchTypeModels: ... # pragma: no cover @overload def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes: ... # pragma: no cover - def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None): + def _to_batch_type(self, record: dict, event_type: EventType, model: BatchTypeModels | None = None): if model is not None: # If a model is provided, we assume Pydantic is installed and we need to disable v2 warnings return model.model_validate(record) @@ -363,7 +370,7 @@ def _register_model_validation_error_record(self, record: dict): # and downstream we can correctly collect the correct message id identifier and make the failed record available # see https://github.com/aws-powertools/powertools-lambda-python/issues/2091 logger.debug("Record cannot be converted to customer's model; converting without model") - failed_record: "EventSourceDataClassTypes" = self._to_batch_type(record=record, event_type=self.event_type) + failed_record: EventSourceDataClassTypes = self._to_batch_type(record=record, event_type=self.event_type) return self.failure_handler(record=failed_record, exception=sys.exc_info()) @@ -453,7 +460,7 @@ def record_handler(record: DynamoDBRecord): logger.info(record.dynamodb.new_image) payload: dict = json.loads(record.dynamodb.new_image.get("item")) # alternatively: - # changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001 + # changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001 # payload = change.get("Message") -> "" ... @@ -481,7 +488,7 @@ def lambda_handler(event, context: LambdaContext): async def _async_process_record(self, record: dict): raise NotImplementedError() - def _process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: + def _process_record(self, record: dict) -> SuccessResponse | FailureResponse: """ Process a record with instance's handler @@ -490,7 +497,7 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons record: dict A batch record to be processed. """ - data: Optional["BatchTypeModels"] = None + data: BatchTypeModels | None = None try: data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) if self._handler_accepts_lambda_context: @@ -602,7 +609,7 @@ async def record_handler(record: DynamoDBRecord): logger.info(record.dynamodb.new_image) payload: dict = json.loads(record.dynamodb.new_image.get("item")) # alternatively: - # changes: Dict[str, Any] = record.dynamodb.new_image # noqa: ERA001 + # changes: dict[str, Any] = record.dynamodb.new_image # noqa: ERA001 # payload = change.get("Message") -> "" ... @@ -630,7 +637,7 @@ def lambda_handler(event, context: LambdaContext): def _process_record(self, record: dict): raise NotImplementedError() - async def _async_process_record(self, record: dict) -> Union[SuccessResponse, FailureResponse]: + async def _async_process_record(self, record: dict) -> SuccessResponse | FailureResponse: """ Process a record with instance's handler @@ -639,7 +646,7 @@ async def _async_process_record(self, record: dict) -> Union[SuccessResponse, Fa record: dict A batch record to be processed. """ - data: Optional["BatchTypeModels"] = None + data: BatchTypeModels | None = None try: data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) if self._handler_accepts_lambda_context: diff --git a/aws_lambda_powertools/utilities/batch/decorators.py b/aws_lambda_powertools/utilities/batch/decorators.py index e24c1159205..f23d64d0ce3 100644 --- a/aws_lambda_powertools/utilities/batch/decorators.py +++ b/aws_lambda_powertools/utilities/batch/decorators.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Awaitable, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Awaitable, Callable from typing_extensions import deprecated @@ -12,10 +12,12 @@ BatchProcessor, EventType, ) -from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse -from aws_lambda_powertools.utilities.typing import LambdaContext from aws_lambda_powertools.warnings import PowertoolsDeprecationWarning +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.batch.types import PartialItemFailureResponse + from aws_lambda_powertools.utilities.typing import LambdaContext + @lambda_handler_decorator @deprecated( @@ -24,7 +26,7 @@ ) def async_batch_processor( handler: Callable, - event: Dict, + event: dict, context: LambdaContext, record_handler: Callable[..., Awaitable[Any]], processor: AsyncBatchProcessor, @@ -40,7 +42,7 @@ def async_batch_processor( ---------- handler: Callable Lambda's handler - event: Dict + event: dict Lambda's Event context: LambdaContext Lambda's Context @@ -92,7 +94,7 @@ def async_batch_processor( ) def batch_processor( handler: Callable, - event: Dict, + event: dict, context: LambdaContext, record_handler: Callable, processor: BatchProcessor, @@ -108,7 +110,7 @@ def batch_processor( ---------- handler: Callable Lambda's handler - event: Dict + event: dict Lambda's Event context: LambdaContext Lambda's Context @@ -154,7 +156,7 @@ def batch_processor( def process_partial_response( - event: Dict, + event: dict, record_handler: Callable, processor: BasePartialBatchProcessor, context: LambdaContext | None = None, @@ -164,7 +166,7 @@ def process_partial_response( Parameters ---------- - event: Dict + event: dict Lambda's original event record_handler: Callable Callable to process each record from the batch @@ -202,7 +204,7 @@ def handler(event, context): * Async batch processors. Use `async_process_partial_response` instead. """ try: - records: List[Dict] = event.get("Records", []) + records: list[dict] = event.get("Records", []) except AttributeError: event_types = ", ".join(list(EventType.__members__)) docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line @@ -218,7 +220,7 @@ def handler(event, context): def async_process_partial_response( - event: Dict, + event: dict, record_handler: Callable, processor: AsyncBatchProcessor, context: LambdaContext | None = None, @@ -228,7 +230,7 @@ def async_process_partial_response( Parameters ---------- - event: Dict + event: dict Lambda's original event record_handler: Callable Callable to process each record from the batch @@ -266,7 +268,7 @@ def handler(event, context): * Sync batch processors. Use `process_partial_response` instead. """ try: - records: List[Dict] = event.get("Records", []) + records: list[dict] = event.get("Records", []) except AttributeError: event_types = ", ".join(list(EventType.__members__)) docs = "https://docs.powertools.aws.dev/lambda/python/latest/utilities/batch/#processing-messages-from-sqs" # noqa: E501 # long-line diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py index 3f4075c7d2f..2a501e034ce 100644 --- a/aws_lambda_powertools/utilities/batch/exceptions.py +++ b/aws_lambda_powertools/utilities/batch/exceptions.py @@ -6,13 +6,13 @@ import traceback from types import TracebackType -from typing import List, Optional, Tuple, Type +from typing import Optional, Tuple, Type ExceptionInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]] class BaseBatchProcessingError(Exception): - def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None): + def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None): super().__init__(msg) self.msg = msg self.child_exceptions = child_exceptions or [] @@ -30,7 +30,7 @@ def format_exceptions(self, parent_exception_str): class BatchProcessingError(BaseBatchProcessingError): """When all batch records failed to be processed""" - def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None): + def __init__(self, msg="", child_exceptions: list[ExceptionInfo] | None = None): super().__init__(msg, child_exceptions) def __str__(self): diff --git a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py index e54389718bc..d493e43bd93 100644 --- a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py +++ b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import logging -from typing import Optional, Set +from typing import TYPE_CHECKING from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse from aws_lambda_powertools.utilities.batch.exceptions import ( SQSFifoCircuitBreakerError, SQSFifoMessageGroupCircuitBreakerError, ) -from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel + +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel logger = logging.getLogger(__name__) @@ -62,13 +66,13 @@ def lambda_handler(event, context: LambdaContext): None, ) - def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False): + def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False): """ Initialize the SqsFifoProcessor. Parameters ---------- - model: Optional["BatchSqsTypeModel"] + model: BatchSqsTypeModel | None An optional model for batch processing. skip_group_on_error: bool Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures @@ -77,7 +81,7 @@ def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_er """ self._skip_group_on_error: bool = skip_group_on_error self._current_group_id = None - self._failed_group_ids: Set[str] = set() + self._failed_group_ids: set[str] = set() super().__init__(EventType.SQS, model) def _process_record(self, record): diff --git a/aws_lambda_powertools/utilities/batch/types.py b/aws_lambda_powertools/utilities/batch/types.py index d737480bf8f..ac0a7d73efa 100644 --- a/aws_lambda_powertools/utilities/batch/types.py +++ b/aws_lambda_powertools/utilities/batch/types.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import sys -from typing import List, Optional, Type, TypedDict, Union +from typing import Optional, Type, TypedDict, Union has_pydantic = "pydantic" in sys.modules @@ -25,4 +27,4 @@ class PartialItemFailures(TypedDict): class PartialItemFailureResponse(TypedDict): - batchItemFailures: List[PartialItemFailures] + batchItemFailures: list[PartialItemFailures]