Skip to content

Commit

Permalink
Merge pull request #47 from swansonk14/complex-types
Browse files Browse the repository at this point in the history
Complex types
  • Loading branch information
swansonk14 authored Mar 27, 2021
2 parents 0dc69b3 + ef08776 commit b51d06c
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 54 deletions.
2 changes: 1 addition & 1 deletion tap/_version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ['__version__']

# major, minor, patch
version_info = 1, 6, 1
version_info = 1, 6, 2

# Nice string for the version
__version__ = '.'.join(map(str, version_info))
73 changes: 36 additions & 37 deletions tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from warnings import warn
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
from typing_inspect import is_literal_type, get_args, get_origin, is_union_type
from typing_inspect import is_literal_type, get_args

from tap.utils import (
get_class_variables,
get_argument_name,
get_git_root,
get_dest,
get_git_url,
get_origin,
has_git,
has_uncommitted_changes,
is_option_arg,
Expand All @@ -32,16 +33,10 @@

# Constants
EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple()
BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple}
OPTIONAL_TYPES = {Optional, Union}
BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES

SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool}
SUPPORTED_DEFAULT_OPTIONAL_TYPES = {Optional, Optional[str], Optional[int], Optional[float], Optional[bool]}
SUPPORTED_DEFAULT_LIST_TYPES = {List, List[str], List[int], List[float], List[bool]}
SUPPORTED_DEFAULT_SET_TYPES = {Set, Set[str], Set[int], Set[float], Set[bool]}
SUPPORTED_DEFAULT_COLLECTION_TYPES = SUPPORTED_DEFAULT_LIST_TYPES | SUPPORTED_DEFAULT_SET_TYPES | {Tuple}
SUPPORTED_DEFAULT_BOXED_TYPES = SUPPORTED_DEFAULT_OPTIONAL_TYPES | SUPPORTED_DEFAULT_COLLECTION_TYPES
SUPPORTED_DEFAULT_TYPES = set.union(SUPPORTED_DEFAULT_BASE_TYPES,
SUPPORTED_DEFAULT_OPTIONAL_TYPES,
SUPPORTED_DEFAULT_COLLECTION_TYPES)

TapType = TypeVar('TapType', bound='Tap')

Expand Down Expand Up @@ -125,6 +120,9 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
:param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo.
:param kwargs: Keyword arguments.
"""
# Set explicit bool
explicit_bool = self._explicit_bool

# Get variable name
variable = get_argument_name(*name_or_flags)

Expand Down Expand Up @@ -168,6 +166,21 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:

# If type is not explicitly provided, set it if it's one of our supported default types
if 'type' not in kwargs:

# Unbox Optional[type] and set var_type = type
if get_origin(var_type) in OPTIONAL_TYPES:
var_args = get_args(var_type)

if len(var_args) > 0:
var_type = get_args(var_type)[0]

# If var_type is tuple as in Python 3.6, change to a typing type
# (e.g., (typing.List, <class 'bool'>) ==> typing.List[bool])
if isinstance(var_type, tuple):
var_type = var_type[0][var_type[1:]]

explicit_bool = True

# First check whether it is a literal type or a boxed literal type
if is_literal_type(var_type):
var_type, kwargs['choices'] = get_literals(var_type, variable)
Expand Down Expand Up @@ -195,27 +208,10 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
kwargs['nargs'] = len(types)

var_type = TupleTypeEnforcer(types=types, loop=loop)
# To identify an Optional type, check if it's a union of a None and something else
elif (
is_union_type(var_type)
and len(get_args(var_type)) == 2
and isinstance(None, get_args(var_type)[1])
and is_literal_type(get_args(var_type)[0])
):
var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)
elif var_type not in SUPPORTED_DEFAULT_TYPES:
is_required = kwargs.get('required', False)
arg_params = 'required=True' if is_required else f'default={getattr(self, variable)}'
raise ValueError(
f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n'
f'Please explicitly add the argument to the parser by writing:\n\n'
f'def configure(self) -> None:\n'
f' self.add_argument("--{variable}", type=func, {arg_params})\n\n'
f'where "func" maps from str to {var_type}.')

if var_type in SUPPORTED_DEFAULT_BOXED_TYPES:

if get_origin(var_type) in BOXED_TYPES:
# If List or Set type, set nargs
if (var_type in SUPPORTED_DEFAULT_COLLECTION_TYPES
if (get_origin(var_type) in BOXED_COLLECTION_TYPES
and kwargs.get('action') not in {'append', 'append_const'}):
kwargs['nargs'] = kwargs.get('nargs', '*')

Expand All @@ -228,13 +224,12 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
else:
var_type = arg_types[0]

# Handle the cases of Optional[bool], List[bool], Set[bool]
# Handle the cases of List[bool], Set[bool], Tuple[bool]
if var_type == bool:
var_type = boolean_type

# If bool then set action, otherwise set type
if var_type == bool:
if self._explicit_bool:
if explicit_bool:
kwargs['type'] = boolean_type
kwargs['choices'] = [True, False] # this makes the help message more helpful
else:
Expand Down Expand Up @@ -404,10 +399,14 @@ def parse_args(self: TapType,
if type(value) == list:
var_type = get_origin(self._annotations[variable])

# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9
# https://github.com/ilevkivskyi/typing_inspect/issues/64
# https://github.com/ilevkivskyi/typing_inspect/issues/65
var_type = var_type if var_type is not None else self._annotations[variable]
# Unpack nested boxed types such as Optional[List[int]]
if var_type is Union:
var_type = get_origin(get_args(self._annotations[variable])[0])

# If var_type is tuple as in Python 3.6, change to a typing type
# (e.g., (typing.Tuple, <class 'bool'>) ==> typing.Tuple)
if isinstance(var_type, tuple):
var_type = var_type[0]

if var_type in (Set, set):
value = set(value)
Expand Down
15 changes: 14 additions & 1 deletion tap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Union,
)
from typing_extensions import Literal
from typing_inspect import get_args
from typing_inspect import get_args, get_origin as typing_inspect_get_origin


NO_CHANGES_STATUS = """nothing to commit, working tree clean"""
Expand Down Expand Up @@ -467,3 +467,16 @@ def enforce_reproducibility(saved_reproducibility_data: Optional[Dict[str, str]]
if current_reproducibility_data['git_has_uncommitted_changes']:
raise ValueError(f'{no_reproducibility_message}: Uncommitted changes '
f'in current args.')


# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8 and 3.9
# https://github.com/ilevkivskyi/typing_inspect/issues/64
# https://github.com/ilevkivskyi/typing_inspect/issues/65
def get_origin(tp: Any) -> Any:
"""Same as typing_inspect.get_origin but fixes unparameterized generic types like Set."""
origin = typing_inspect_get_origin(tp)

if origin is None:
origin = tp

return origin
172 changes: 157 additions & 15 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
from copy import deepcopy
import os
from pathlib import Path
import sys
from tempfile import TemporaryDirectory
from typing import Any, List, Optional, Set, Tuple
from typing import Any, Iterable, List, Optional, Set, Tuple
from typing_extensions import Literal
import unittest
from unittest import TestCase

from tap import Tap


def stringify(arg_list: Iterable[Any]) -> List[str]:
"""Converts an iterable of arguments of any type to a list of strings.
:param arg_list: An iterable of arguments of any type.
:return: A list of the arguments as strings.
"""
return [str(arg) for arg in arg_list]


class EdgeCaseTests(TestCase):
def test_empty(self) -> None:
class EmptyTap(Tap):
Expand Down Expand Up @@ -112,17 +122,144 @@ def test_both_assigned_okay(self):
self.assertEqual(args.arg_list_str_required, ['hi', 'there'])


class CrashesOnUnsupportedTypesTests(TestCase):
# TODO: need to implement list[str] etc.
# class ParameterizedStandardCollectionTap(Tap):
# arg_list_str: list[str]
# arg_list_int: list[int]
# arg_list_int_default: list[int] = [1, 2, 5]
# arg_set_float: set[float]
# arg_set_str_default: set[str] = ['one', 'two', 'five']
# arg_tuple_int: tuple[int, ...]
# arg_tuple_float_default: tuple[float, float, float] = (1.0, 2.0, 5.0)
# arg_tuple_str_override: tuple[str, str] = ('hi', 'there')
# arg_optional_list_int: Optional[list[int]] = None


# class ParameterizedStandardCollectionTests(TestCase):
# @unittest.skipIf(sys.version_info < (3, 9), 'Parameterized standard collections (e.g., list[int]) introduced in Python 3.9')
# def test_parameterized_standard_collection(self):
# arg_list_str = ['a', 'b', 'pi']
# arg_list_int = [-2, -5, 10]
# arg_set_float = {3.54, 2.235}
# arg_tuple_int = (-4, 5, 9, 103)
# arg_tuple_str_override = ('why', 'so', 'many', 'tests?')
# arg_optional_list_int = [5, 4, 3]

# args = ParameterizedStandardCollectionTap().parse_args([
# '--arg_list_str', *arg_list_str,
# '--arg_list_int', *[str(var) for var in arg_list_int],
# '--arg_set_float', *[str(var) for var in arg_set_float],
# '--arg_tuple_int', *[str(var) for var in arg_tuple_int],
# '--arg_tuple_str_override', *arg_tuple_str_override,
# '--arg_optional_list_int', *[str(var) for var in arg_optional_list_int]
# ])

# self.assertEqual(args.arg_list_str, arg_list_str)
# self.assertEqual(args.arg_list_int, arg_list_int)
# self.assertEqual(args.arg_list_int_default, ParameterizedStandardCollectionTap.arg_list_int_default)
# self.assertEqual(args.arg_set_float, arg_set_float)
# self.assertEqual(args.arg_set_str_default, ParameterizedStandardCollectionTap.arg_set_str_default)
# self.assertEqual(args.arg_tuple_int, arg_tuple_int)
# self.assertEqual(args.arg_tuple_float_default, ParameterizedStandardCollectionTap.arg_tuple_float_default)
# self.assertEqual(args.arg_tuple_str_override, arg_tuple_str_override)
# self.assertEqual(args.arg_optional_list_int, arg_optional_list_int)


class NestedOptionalTypesTap(Tap):
list_bool: Optional[List[bool]]
list_int: Optional[List[int]]
list_str: Optional[List[str]]
set_bool: Optional[Set[bool]]
set_int: Optional[Set[int]]
set_str: Optional[Set[str]]
tuple_bool: Optional[Tuple[bool]]
tuple_int: Optional[Tuple[int]]
tuple_str: Optional[Tuple[str]]
tuple_pair: Optional[Tuple[bool, str, int]]
tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
tuple_arbitrary_len_str: Optional[Tuple[str, ...]]


class NestedOptionalTypeTests(TestCase):

def test_nested_optional_types(self):
list_bool = [True, False]
list_int = [0, 1, 2]
list_str = ['a', 'bee', 'cd', 'ee']
set_bool = {True, False, True}
set_int = {0, 1}
set_str = {'a', 'bee', 'cd'}
tuple_bool = (False,)
tuple_int = (0,)
tuple_str = ('a',)
tuple_pair = (False, 'a', 1)
tuple_arbitrary_len_bool = (True, False, False)
tuple_arbitrary_len_int = (1, 2, 3, 4)
tuple_arbitrary_len_str = ('a', 'b')

args = NestedOptionalTypesTap().parse_args([
'--list_bool', *stringify(list_bool),
'--list_int', *stringify(list_int),
'--list_str', *stringify(list_str),
'--set_bool', *stringify(set_bool),
'--set_int', *stringify(set_int),
'--set_str', *stringify(set_str),
'--tuple_bool', *stringify(tuple_bool),
'--tuple_int', *stringify(tuple_int),
'--tuple_str', *stringify(tuple_str),
'--tuple_pair', *stringify(tuple_pair),
'--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
'--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
'--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
])

def test_crashes_on_unsupported(self):
# From PiDelport: https://github.com/swansonk14/typed-argument-parser/issues/27
from pathlib import Path
self.assertEqual(args.list_bool, list_bool)
self.assertEqual(args.list_int, list_int)
self.assertEqual(args.list_str, list_str)

self.assertEqual(args.set_bool, set_bool)
self.assertEqual(args.set_int, set_int)
self.assertEqual(args.set_str, set_str)

self.assertEqual(args.tuple_bool, tuple_bool)
self.assertEqual(args.tuple_int, tuple_int)
self.assertEqual(args.tuple_str, tuple_str)
self.assertEqual(args.tuple_pair, tuple_pair)
self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)


class ComplexTypeTap(Tap):
path: Path
optional_path: Optional[Path]
list_path: List[Path]
set_path: Set[Path]
tuple_path: Tuple[Path, Path]


class ComplexTypeTests(TestCase):
def test_complex_types(self):
path = Path('/path/to/file.txt')
optional_path = Path('/path/to/optional/file.txt')
list_path = [Path('/path/to/list/file1.txt'), Path('/path/to/list/file2.txt')]
set_path = {Path('/path/to/set/file1.txt'), Path('/path/to/set/file2.txt')}
tuple_path = (Path('/path/to/tuple/file1.txt'), Path('/path/to/tuple/file2.txt'))

args = ComplexTypeTap().parse_args([
'--path', str(path),
'--optional_path', str(optional_path),
'--list_path', *[str(path) for path in list_path],
'--set_path', *[str(path) for path in set_path],
'--tuple_path', *[str(path) for path in tuple_path]
])

class CrashingArgumentParser(Tap):
some_path: Path = 'some_path'

with self.assertRaises(ValueError):
CrashingArgumentParser().parse_args([])
self.assertEqual(args.path, path)
self.assertEqual(args.optional_path, optional_path)
self.assertEqual(args.list_path, list_path)
self.assertEqual(args.set_path, set_path)
self.assertEqual(args.tuple_path, tuple_path)


class Person:
Expand Down Expand Up @@ -312,7 +449,6 @@ def test_set_default_args(self) -> None:
'--arg_list_bool', *arg_list_bool,
'--arg_list_str_empty', *arg_list_str_empty,
'--arg_list_literal', *arg_list_literal,

'--arg_set', *arg_set,
'--arg_set_str', *arg_set_str,
'--arg_set_int', *arg_set_int,
Expand Down Expand Up @@ -496,26 +632,32 @@ def configure(self) -> None:
def test_complex_type(self) -> None:
class AddArgumentComplexTypeTap(IntegrationDefaultTap):
arg_person: Person = Person('tap')
# arg_person_required: Person # TODO
arg_person_required: Person
arg_person_untyped = Person('tap untyped')

# TODO: assert a crash if any complex types are not explicitly added in add_argument
def configure(self) -> None:
self.add_argument('--arg_person', type=Person)
# self.add_argument('--arg_person_required', type=Person) # TODO
self.add_argument('--arg_person_required', type=Person)
self.add_argument('--arg_person_untyped', type=Person)

args = AddArgumentComplexTypeTap().parse_args([])
arg_person_required = Person("hello, it's me")

args = AddArgumentComplexTypeTap().parse_args([
'--arg_person_required', arg_person_required.name,
])
self.assertEqual(args.arg_person, Person('tap'))
self.assertEqual(args.arg_person_required, arg_person_required)
self.assertEqual(args.arg_person_untyped, Person('tap untyped'))

arg_person = Person('hi there')
arg_person_untyped = Person('heyyyy')
args = AddArgumentComplexTypeTap().parse_args([
'--arg_person', arg_person.name,
'--arg_person_required', arg_person_required.name,
'--arg_person_untyped', arg_person_untyped.name
])
self.assertEqual(args.arg_person, arg_person)
self.assertEqual(args.arg_person_required, arg_person_required)
self.assertEqual(args.arg_person_untyped, arg_person_untyped)

def test_repeat_default(self) -> None:
Expand Down

0 comments on commit b51d06c

Please sign in to comment.