Skip to content

Commit

Permalink
enable dataclass functions
Browse files Browse the repository at this point in the history
  • Loading branch information
xyluo25 committed Aug 23, 2024
1 parent b0e74cc commit b0c9bf9
Show file tree
Hide file tree
Showing 2 changed files with 477 additions and 229 deletions.
61 changes: 52 additions & 9 deletions pyufunc/util_data_processing/_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
##############################################################
'''

from dataclasses import dataclass, field, fields, make_dataclass, MISSING, is_dataclass
from dataclasses import dataclass, field, fields, make_dataclass, MISSING, is_dataclass, asdict
from typing import Any, List, Tuple, Type, Union, Dict
import copy


def create_dataclass_from_dict(name: str, data: Dict[str, Any]) -> Type:
Expand Down Expand Up @@ -37,16 +38,27 @@ def __setitem__(self, key, value):
else:
raise KeyError(f"Key {key} not found in {self.__class__.__name__}")

# Define a method to convert the dataclass to a dictionary
def as_dict(self):
return asdict(self)

# Extract fields and their types from the dictionary
fields = [(key, type(value), field(default=value))
for key, value in data.items()]
dataclass_fields = []
for key, value in data.items():
if isinstance(value, (list, dict, set)): # For mutable types
dataclass_fields.append(
(key, type(value), field(default_factory=lambda v=value: v)))
else: # For immutable types
dataclass_fields.append((key, type(value), field(default=value)))

# Create the dataclass dynamically
DataClass = make_dataclass(
cls_name=name,
fields=fields,
fields=dataclass_fields,
bases=(),
namespace={'__getitem__': __getitem__, '__setitem__': __setitem__}
namespace={'__getitem__': __getitem__,
'__setitem__': __setitem__,
'as_dict': as_dict}
)

# Instantiate the dataclass with the values from the dictionary
Expand Down Expand Up @@ -95,6 +107,10 @@ def __setitem__(self, key, value):
setattr(self, key, value)
else:
raise KeyError(f"Key {key} not found in {self.__class__.__name__}")

def as_dict(self):
return asdict(self)

processed_attributes = []

for attr in attributes:
Expand All @@ -106,7 +122,9 @@ def __setitem__(self, key, value):
processed_attributes.append((attr[0], attr[1], attr[2]))

return make_dataclass(class_name, processed_attributes,
namespace={'__getitem__': __getitem__, '__setitem__': __setitem__})
namespace={'__getitem__': __getitem__,
'__setitem__': __setitem__,
'as_dict': as_dict})


def merge_dataclass(dataclass_one: Type[Any], dataclass_two: Type[Any],
Expand Down Expand Up @@ -185,6 +203,21 @@ def extend_dataclass(
additional_attributes (list): A list of tuples in the form (name, type, default_value).
or (name, default_value) to add to the base dataclass.
Example:
>>> from dataclasses import dataclass
>>> from typing import List
>>> from pyufunc import extend_dataclass
>>> @dataclass
... class BaseDataclass:
... name: str = 'base'
>>> ExtendedDataclass = extend_dataclass(
... base_dataclass=BaseDataclass,
... additional_attributes=[('new_attr', List[int], [1, 2, 3])])
>>> ExtendedDataclass
Returns:
dataclass: A new dataclass that includes fields from base_dataclass and additional_attributes.
"""
Expand All @@ -198,9 +231,13 @@ def extend_dataclass(
raise ValueError('additional_attributes must be a list of tuples'
' in the form (name, default_value) or (name, type, default_value)')

# deepcopy the base dataclass
base_dataclass_ = copy.deepcopy(base_dataclass)
# base_dataclass_ = base_dataclass

# Extract existing fields from the base dataclass
base_fields = []
for f in fields(base_dataclass):
for f in fields(base_dataclass_):
if f.default is not MISSING:
base_fields.append((f.name, f.type, f.default))
elif f.default_factory is not MISSING:
Expand All @@ -220,12 +257,18 @@ def extend_dataclass(
# Combine base fields with additional attributes
all_fields = base_fields + additional_attributes

return make_dataclass(
cls_name=f'{base_dataclass.__name__}',
new_dataclass = make_dataclass(
cls_name=f'{base_dataclass_.__name__}',
fields=all_fields,
bases=(base_dataclass,),
)

# Register the new dataclass in the global scope to allow pickling
globals()[new_dataclass.__name__] = new_dataclass

# new_dataclass.__module__ = base_dataclass_.__module__
return new_dataclass


def dataclass_dict_access(dataclass_instance: Any) -> Any:
"""Wrap a dataclass instance to provide dictionary-like access.
Expand Down
Loading

0 comments on commit b0c9bf9

Please sign in to comment.