Skip to content

Commit

Permalink
Support track order on group.get()
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjala committed May 9, 2024
1 parent bcecf25 commit 380e1d7
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 24 deletions.
40 changes: 39 additions & 1 deletion h5pyd/_hl/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,16 @@ def __len__(self):
def __iter__(self):
""" Iterate over the names of attributes. """
if self._objdb_attributes is not None:
if self._parent._track_order:
attrs = sorted(self._objdb_attributes.items(), key=lambda x: x[1]['created'])
else:
attrs = sorted(self._objdb_attributes.items())

ordered_attrs = {}
for a in attrs:
ordered_attrs[a[0]] = a[1]

for name in self._objdb_attributes:
for name in ordered_attrs:
yield name

else:
Expand Down Expand Up @@ -384,3 +392,33 @@ def __repr__(self):
if not self._parent.id.id:
return "<Attributes of closed HDF5 object>"
return "<Attributes of HDF5 object at %s>" % id(self._parent.id)

def __reversed__(self):
    """ Iterate over the names of attributes in reverse order.

    When the local attribute cache (``self._objdb_attributes``) is set, the
    names are ordered locally: by each attribute's 'created' value if the
    parent's ``_track_order`` flag is set, otherwise sorted by name.  The
    resulting sequence is then yielded back to front.

    When there is no cache, the attribute list is requested from the server
    (the parent's ``_track_order`` flag is sent as the "CreateOrder"
    parameter) and the names are yielded in the reverse of the order in
    which the response lists them.
    """
    cached = self._objdb_attributes
    if cached is not None:
        if self._parent._track_order:
            # sorted() is stable, so attributes with equal 'created' values
            # keep their relative cache order before being reversed
            names = sorted(cached, key=lambda name: cached[name]['created'])
        else:
            names = sorted(cached)
        yield from reversed(names)
        return

    # make server request, backing up over the trailing slash in the prefix
    req = self._req_prefix[:-1]
    params = {"CreateOrder": "1" if self._parent._track_order else "0"}
    rsp = self._parent.GET(req, params=params)

    yield from reversed([attr['name'] for attr in rsp['attributes']])
3 changes: 2 additions & 1 deletion h5pyd/_hl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ def allocated_size(self):
self._getVerboseInfo()
return self._allocated_size

def __init__(self, bind):
def __init__(self, bind, track_order=False):
"""Create a new Dataset object by binding to a low-level DatasetID."""

if not isinstance(bind, DatasetID):
Expand All @@ -732,6 +732,7 @@ def __init__(self, bind):
# make a numpy dtype out of the type json
self._dtype = createDataType(self.id.type_json)
self._item_size = getItemSize(self.id.type_json)
self._track_order = track_order

self._shape = self.get_shape()

Expand Down
54 changes: 36 additions & 18 deletions h5pyd/_hl/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def require_group(self, name):
raise TypeError("Incompatible object (%s) already exists" % grp.__class__.__name__)
return grp

def getObjByUuid(self, uuid, collection_type=None):
def getObjByUuid(self, uuid, collection_type=None, track_order=False):
""" Utility method to get an obj based on collection type and uuid """
self.log.debug(f"getObjByUuid({uuid})")
obj_json = None
Expand Down Expand Up @@ -585,10 +585,10 @@ def getObjByUuid(self, uuid, collection_type=None):
# will need to get JSON from server
req = f"/{collection_type}/{uuid}"
# make server request
obj_json = self.GET(req, params={"CreateOrder": "1" if self._track_order else "0"})
obj_json = self.GET(req, params={"CreateOrder": "1" if track_order else "0"})

if collection_type == 'groups':
tgt = Group(GroupID(self, obj_json))
tgt = Group(GroupID(self, obj_json), track_order=track_order)
elif collection_type == 'datatypes':
tgt = Datatype(TypeID(self, obj_json))
elif collection_type == 'datasets':
Expand All @@ -598,13 +598,13 @@ def getObjByUuid(self, uuid, collection_type=None):
if "dims" in shape_json and len(shape_json["dims"]) == 1 and dtype_json["class"] == 'H5T_COMPOUND':
tgt = Table(DatasetID(self, obj_json))
else:
tgt = Dataset(DatasetID(self, obj_json))
tgt = Dataset(DatasetID(self, obj_json), track_order=track_order)
else:
raise IOError(f"Unexpected collection_type: {collection_type}")

return tgt

def __getitem__(self, name):
def __getitem__(self, name, track_order=False):
""" Open an object in the file """
# convert bytes to str for PY3
if isinstance(name, bytes):
Expand All @@ -617,11 +617,11 @@ def __getitem__(self, name):
if tgt is not None:
return tgt # ref'd object has not been deleted
if isinstance(name.id, GroupID):
tgt = self.getObjByUuid(name.id.uuid, collection_type="groups")
tgt = self.getObjByUuid(name.id.uuid, collection_type="groups", track_order=track_order)
elif isinstance(name.id, DatasetID):
tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets")
tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets", track_order=track_order)
elif isinstance(name.id, TypeID):
tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets")
tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets", track_order=track_order)
else:
raise IOError("Unexpected Error - ObjectID type: " + name.__class__.__name__)
return tgt
Expand All @@ -634,11 +634,11 @@ def __getitem__(self, name):
link_class = link_json['class']

if link_class == 'H5L_TYPE_HARD':
tgt = self.getObjByUuid(link_json['id'], collection_type=link_json['collection'])
tgt = self.getObjByUuid(link_json['id'], collection_type=link_json['collection'], track_order=track_order)
elif link_class == 'H5L_TYPE_SOFT':
h5path = link_json['h5path']
soft_parent_uuid, soft_json = self._get_link_json(h5path)
tgt = self.getObjByUuid(soft_json['id'], collection_type=soft_json['collection'])
tgt = self.getObjByUuid(soft_json['id'], collection_type=soft_json['collection'], track_order=track_order)

elif link_class == 'H5L_TYPE_EXTERNAL':
# try to get a handle to the file and return the linked object...
Expand All @@ -654,7 +654,8 @@ def __getitem__(self, name):
endpoint = self.id.http_conn.endpoint
username = self.id.http_conn.username
password = self.id.http_conn.password
f = File(external_domain, endpoint=endpoint, username=username, password=password, mode='r')
f = File(external_domain, endpoint=endpoint, username=username, password=password, mode='r',
track_order=track_order)
except IOError:
# unable to find external link
raise KeyError("Unable to open file: " + link_json['h5domain'])
Expand All @@ -678,7 +679,7 @@ def __getitem__(self, name):
tgt._name = name
return tgt

def get(self, name, default=None, getclass=False, getlink=False):
def get(self, name, default=None, getclass=False, getlink=False, track_order=False):
""" Retrieve an item or other information.
"name" given only:
Expand All @@ -702,18 +703,17 @@ def get(self, name, default=None, getclass=False, getlink=False):
>>> if cls == SoftLink:
... print '"foo" is a soft link!'
"""

if not (getclass or getlink):
try:
return self[name]
return self.__getitem__(name, track_order)
except KeyError:
return default

if name not in self:
return default

elif getclass and not getlink:
obj = self.__getitem__(name)
obj = self.__getitem__(name, track_order)
if obj is None:
return None
if obj.id.__class__ is GroupID:
Expand Down Expand Up @@ -891,7 +891,16 @@ def __iter__(self):
for x in links:
yield x['title']
else:
for name in links:
if self._track_order:
links = sorted(links.items(), key=lambda x: x[1]['created'])
else:
links = sorted(links.items())

ordered_links = {}
for link in links:
ordered_links[link[0]] = link[1]

for name in ordered_links:
yield name

def __contains__(self, name):
Expand Down Expand Up @@ -1151,14 +1160,23 @@ def __reversed__(self):

# reset the link cache
self._link_db = {}
for link in reversed(links):
for link in links:
name = link["title"]
self._link_db[name] = link

for x in reversed(links):
yield x['title']
else:
for name in links:
if self._track_order:
links = sorted(links.items(), key=lambda x: x[1]['created'])
else:
links = sorted(links.items())

ordered_links = {}
for link in links:
ordered_links[link[0]] = link[1]

for name in reversed(ordered_links):
yield name


Expand Down
49 changes: 45 additions & 4 deletions test/hl/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def test_create(self):
self.assertTrue(False) # shouldn't get here'
except RuntimeError:
pass # expected
except OSError:
pass # also acceptable

del r['tmp']
self.assertEqual(len(r), 4)
Expand Down Expand Up @@ -338,19 +340,27 @@ def test_no_track_order(self):
self.assertEqual(list(reversed(g)), list(reversed(ref)))

def test_get_dataset_track_order(self):

# h5py does not support track_order on group.get()
if config.get("use_h5py"):
return

filename = self.getFileName("test_get_dataset_track_order")
print(f"filename: {filename}")
self.f = h5py.File(filename, 'w')
g = self.f.create_group('order')
# create dataset, close file, re-open and get dataset

dset = g.create_dataset('dset', (10,), dtype='i4')
dset2 = g.create_dataset('dset2', (10,), dtype='i4')

self.populate_attrs(dset)
self.populate_attrs(dset2)

self.f.close()
f = h5py.File(filename, 'r')
g = f['order']
d = g.get('dset', track_order=True)
self.f = h5py.File(filename, 'r')
g = self.f['order']

d = g.get('dset', track_order=True)
ref = [str(i) for i in range(100)]
self.assertEqual(list(d.attrs), ref)
self.assertEqual(list(reversed(d.attrs)), list(reversed(ref)))
Expand All @@ -360,6 +370,37 @@ def test_get_dataset_track_order(self):
self.assertEqual(list(d2.attrs), ref)
self.assertEqual(list(reversed(d2.attrs)), list(reversed(ref)))

def test_get_group_track_order(self):
    """ Check link iteration order of a subgroup fetched with group.get().

    A subgroup is created and filled via self.populate(), then the file is
    re-opened read-only twice: once fetching the subgroup with
    track_order=True and once with track_order=False.
    NOTE(review): the expected names assume self.populate() creates links
    named '0'..'99' in that order - confirm against the helper.
    """
    # h5py does not support track_order on group.get()
    if config.get("use_h5py"):
        return

    filename = self.getFileName("test_get_group_track_order")
    print(f"filename: {filename}")

    # create the subgroup and populate it with links
    self.f = h5py.File(filename, 'w')
    parent = self.f.create_group('order')
    parent.create_group('subgroup')
    self.populate(parent['subgroup'])

    creation_order = list(map(str, range(100)))
    cases = (
        (True, creation_order),           # tracked: names in creation order
        (False, sorted(creation_order)),  # untracked: names in sorted order
    )

    for track_order, expected in cases:
        # close whatever handle is open and start from a fresh read-only one
        self.f.close()
        self.f = h5py.File(filename, 'r')
        parent = self.f['order']

        subgroup = parent.get('subgroup', track_order=track_order)
        self.assertEqual(list(subgroup), expected)
        self.assertEqual(list(reversed(subgroup)), list(reversed(expected)))



if __name__ == '__main__':
loglevel = logging.ERROR
logging.basicConfig(format='%(asctime)s %(message)s', level=loglevel)
Expand Down

0 comments on commit 380e1d7

Please sign in to comment.