Skip to content

Commit

Permalink
spacing
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristosT committed Jan 16, 2025
1 parent b9b798f commit 834ad77
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions hexrd/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import fields
import json
import copy
from typing import get_origin,get_args,Union
from typing import get_origin, get_args, Union

import argparse

Expand Down Expand Up @@ -70,7 +70,7 @@ def add_profile_subparser(
default_value = field.default_factory()
else:
default_value = field.default
tp,default_value = _get_supported_type(field.type,default_value)
tp, default_value = _get_supported_type(field.type, default_value)
# fields with default value = None are treated as posiitonal
if default_value is not None:
switches = [f"--{field.name}"]
Expand Down Expand Up @@ -113,9 +113,9 @@ def _remove_non_dataclass_args(args_dict: dict) -> tuple[dict, dict]:

# these are defined in main.py
# if we ever add more we will need to update this list
hexrd_app_args = ['debug', 'inst_profile', 'cmd', 'func']
hexrd_app_args = ['debug', 'inst_profile', 'cmd', 'func']
for key in hexrd_app_args:
del args[key]
del args[key]

# extra are added by the preprocess subparser
extra = {}
Expand All @@ -125,21 +125,23 @@ def _remove_non_dataclass_args(args_dict: dict) -> tuple[dict, dict]:
del args[key]
return args, extra

def _get_supported_type(tp,default_value=None):
""" Replace any type not supported by argparse in the command line with an
alternative. Also, return the new default value in the appropriate format.
For now we just replace dictionaries with json strings this
allows to pass a dict as '{"key1":value1, "key2":value2}'
"""
# second condition is required in case the dataclass field is defined using
# members of the typing module.
if tp is dict or get_origin(tp) is dict:
return json.loads,f"'{json.dumps(default_value)}'"
elif is_optional(tp):
return get_args(tp)[0],None
else:
return tp,default_value

def _get_supported_type(tp, default_value=None):
"""Replace any type not supported by argparse in the command line with an
alternative. Also, return the new default value in the appropriate format.
For now we just replace dictionaries with json strings this
allows to pass a dict as '{"key1":value1, "key2":value2}'
"""
# second condition is required in case the dataclass field is defined using
# members of the typing module.
if tp is dict or get_origin(tp) is dict:
return json.loads, f"'{json.dumps(default_value)}'"
elif is_optional(tp):
return get_args(tp)[0], None
else:
return tp, default_value


def is_optional(field):
return get_origin(field) is Union and type(None) in get_args(field)

0 comments on commit 834ad77

Please sign in to comment.