diff --git a/a_sync/a_sync/_flags.pxd b/a_sync/a_sync/_flags.pxd index d8d74588..ab6c1bfa 100644 --- a/a_sync/a_sync/_flags.pxd +++ b/a_sync/a_sync/_flags.pxd @@ -1 +1 @@ -cdef bint negate_if_necessary(str flag, object flag_value) \ No newline at end of file +cdef bint negate_if_necessary(str flag, bint flag_value) \ No newline at end of file diff --git a/a_sync/a_sync/_flags.pyx b/a_sync/a_sync/_flags.pyx index 811dae53..a0078aa5 100644 --- a/a_sync/a_sync/_flags.pyx +++ b/a_sync/a_sync/_flags.pyx @@ -17,7 +17,7 @@ from a_sync import exceptions from a_sync.a_sync.flags import AFFIRMATIVE_FLAGS, NEGATIVE_FLAGS -cdef bint negate_if_necessary(str flag, object flag_value): +cdef bint negate_if_necessary(str flag, bint flag_value): """Negate the flag value if necessary based on the flag type. This function checks if the provided flag is in the set of affirmative or negative flags @@ -31,6 +31,7 @@ cdef bint negate_if_necessary(str flag, object flag_value): The potentially negated flag value. Raises: + TypeError: If the provided flag value is not of bint type. exceptions.InvalidFlag: If the flag is not recognized. Examples: @@ -43,12 +44,10 @@ cdef bint negate_if_necessary(str flag, object flag_value): See Also: - :func:`validate_flag_value`: Validates that the flag value is a boolean. """ - if not isinstance(flag_value, bool): - raise exceptions.InvalidFlagValue(flag, flag_value) if flag in AFFIRMATIVE_FLAGS: - return flag_value + return flag_value elif flag in NEGATIVE_FLAGS: - return not flag_value + return not flag_value raise exceptions.InvalidFlag(flag) @@ -72,11 +71,6 @@ cdef bint validate_flag_value(str flag, object flag_value): >>> validate_flag_value('sync', True) True - >>> validate_flag_value('asynchronous', 'yes') - Traceback (most recent call last): - ... - exceptions.InvalidFlagValue: Invalid flag value for 'asynchronous': 'yes' - See Also: - :func:`negate_if_necessary`: Negates the flag value if necessary based on the flag type. """ diff --git a/a_sync/a_sync/_kwargs.pyx b/a_sync/a_sync/_kwargs.pyx index 61d0e811..02f78d62 100644 --- a/a_sync/a_sync/_kwargs.pyx +++ b/a_sync/a_sync/_kwargs.pyx @@ -60,6 +60,10 @@ cdef bint is_sync(str flag, dict kwargs, bint pop_flag): :func:`get_flag_name`: Retrieves the name of the flag present in the kwargs. """ if pop_flag: + # NOTE: we should techincally raise InvalidFlagValue here but I dont want to set flag_value to a var return negate_if_necessary(flag, kwargs.pop(flag)) else: - return negate_if_necessary(flag, kwargs[flag]) \ No newline at end of file + try: + return negate_if_necessary(flag, kwargs[flag]) + except TypeError as e: + raise exceptions.InvalidFlagValue(flag, kwargs[flag]) from e.__cause__ \ No newline at end of file diff --git a/a_sync/a_sync/abstract.pyx b/a_sync/a_sync/abstract.pyx index 95e38df6..dbd70338 100644 --- a/a_sync/a_sync/abstract.pyx +++ b/a_sync/a_sync/abstract.pyx @@ -13,6 +13,7 @@ import abc import logging from typing import Dict, Any, Tuple +from a_sync import exceptions from a_sync._typing import * from a_sync.a_sync cimport _kwargs from a_sync.a_sync._flags cimport negate_if_necessary @@ -101,9 +102,16 @@ class ASyncABC(metaclass=ASyncMeta): ) if not cache.is_cached: - cache.value = negate_if_necessary( - self.__a_sync_flag_name__, self.__a_sync_flag_value__ - ) + try: + cache.value = negate_if_necessary( + self.__a_sync_flag_name__, self.__a_sync_flag_value__ + ) + except TypeError as e: + raise exceptions.InvalidFlagValue( + self.__a_sync_flag_name__, + self.__a_sync_flag_value__, + ) from e.__cause__ + cache.is_cached = True self.__a_sync_should_await_cache__ = cache return cache.value @@ -133,9 +141,16 @@ class ASyncABC(metaclass=ASyncMeta): ) if not cache.is_cached: - cache.value = negate_if_necessary( - self.__a_sync_flag_name__, self.__a_sync_flag_value__ - ) + try: + cache.value = negate_if_necessary( + self.__a_sync_flag_name__, self.__a_sync_flag_value__ + ) + except TypeError as e: + raise exceptions.InvalidFlagValue( + self.__a_sync_flag_name__, + self.__a_sync_flag_value__, + ) from e.__cause__ + cache.is_cached = True self.__a_sync_should_await_cache__ = cache return cache.value diff --git a/a_sync/a_sync/base.pyx b/a_sync/a_sync/base.pyx index a50ab12a..0cb24be4 100644 --- a/a_sync/a_sync/base.pyx +++ b/a_sync/a_sync/base.pyx @@ -118,7 +118,10 @@ class ASyncGenericBase(ASyncABC): except exceptions.NoFlagsFound: flag = _get_a_sync_flag_name_from_class_def(cls) flag_value = _get_a_sync_flag_value_from_class_def(cls, flag) - return negate_if_necessary(flag, flag_value) # type: ignore [arg-type] + try: + return negate_if_necessary(flag, flag_value) # type: ignore [arg-type] + except TypeError as e: + raise exceptions.InvalidFlagValue(flag, flag_value) from e.__cause__ # we need an extra var so we can log it cdef bint sync @@ -130,7 +133,11 @@ class ASyncGenericBase(ASyncABC): flag = _get_a_sync_flag_name_from_class_def(cls) flag_value = _get_a_sync_flag_value_from_class_def(cls, flag) - sync = negate_if_necessary(flag, flag_value) # type: ignore [arg-type] + try: + sync = negate_if_necessary(flag, flag_value) # type: ignore [arg-type] + except TypeError as e: + raise exceptions.InvalidFlagValue(flag, flag_value) from e.__cause__ + logger._log( logging.DEBUG, "`%s.%s` indicates default mode is %ssynchronous",