From 04c248b2884d0a8a3e2997437f90428a62ee3bfa Mon Sep 17 00:00:00 2001 From: Matthew Larson Date: Wed, 8 May 2024 17:07:36 -0500 Subject: [PATCH] Support track order on group.get() --- h5pyd/_hl/attrs.py | 40 +++++++++++++++++++++++++++++++- h5pyd/_hl/dataset.py | 3 ++- h5pyd/_hl/group.py | 54 ++++++++++++++++++++++++++++--------------- test/hl/test_group.py | 50 +++++++++++++++++++++++++++++++++++---- 4 files changed, 123 insertions(+), 24 deletions(-) diff --git a/h5pyd/_hl/attrs.py b/h5pyd/_hl/attrs.py index 35327e0..e6c5905 100644 --- a/h5pyd/_hl/attrs.py +++ b/h5pyd/_hl/attrs.py @@ -343,8 +343,16 @@ def __len__(self): def __iter__(self): """ Iterate over the names of attributes. """ if self._objdb_attributes is not None: + if self._parent._track_order: + attrs = sorted(self._objdb_attributes.items(), key=lambda x: x[1]['created']) + else: + attrs = sorted(self._objdb_attributes.items()) + + ordered_attrs = {} + for a in attrs: + ordered_attrs[a[0]] = a[1] - for name in self._objdb_attributes: + for name in ordered_attrs: yield name else: @@ -384,3 +392,33 @@ def __repr__(self): if not self._parent.id.id: return "" return "" % id(self._parent.id) + + def __reversed__(self): + """ Iterate over the names of attributes in reverse order. """ + if self._objdb_attributes is not None: + if self._parent._track_order: + attrs = sorted(self._objdb_attributes.items(), key=lambda x: x[1]['created']) + else: + attrs = sorted(self._objdb_attributes.items()) + + ordered_attrs = {} + for a in attrs: + ordered_attrs[a[0]] = a[1] + + for name in reversed(ordered_attrs): + yield name + + else: + # make server request + req = self._req_prefix + # backup over the trailing slash in req + req = req[:-1] + rsp = self._parent.GET(req, params={"CreateOrder": "1" if self._parent._track_order else "0"}) + attributes = rsp['attributes'] + + attrlist = [] + for attr in attributes: + attrlist.append(attr['name']) + + for name in reversed(attrlist): + yield name diff --git a/h5pyd/_hl/dataset.py b/h5pyd/_hl/dataset.py index 992203c..9d792d4 100644 --- a/h5pyd/_hl/dataset.py +++ b/h5pyd/_hl/dataset.py @@ -717,7 +717,7 @@ def allocated_size(self): self._getVerboseInfo() return self._allocated_size - def __init__(self, bind): + def __init__(self, bind, track_order=False): """Create a new Dataset object by binding to a low-level DatasetID.""" if not isinstance(bind, DatasetID): @@ -732,6 +732,7 @@ def __init__(self, bind): # make a numpy dtype out of the type json self._dtype = createDataType(self.id.type_json) self._item_size = getItemSize(self.id.type_json) + self._track_order = track_order self._shape = self.get_shape() diff --git a/h5pyd/_hl/group.py b/h5pyd/_hl/group.py index e0c5ba8..e999ac6 100644 --- a/h5pyd/_hl/group.py +++ b/h5pyd/_hl/group.py @@ -550,7 +550,7 @@ def require_group(self, name): raise TypeError("Incompatible object (%s) already exists" % grp.__class__.__name__) return grp - def getObjByUuid(self, uuid, collection_type=None): + def getObjByUuid(self, uuid, collection_type=None, track_order=False): """ Utility method to get an obj based on collection type and uuid """ self.log.debug(f"getObjByUuid({uuid})") obj_json = None @@ -585,10 +585,10 @@ def getObjByUuid(self, uuid, collection_type=None): # will need to get JSON from server req = f"/{collection_type}/{uuid}" # make server request - obj_json = self.GET(req, params={"CreateOrder": "1" if self._track_order else "0"}) + obj_json = self.GET(req, params={"CreateOrder": "1" if track_order else "0"}) if collection_type == 'groups': - tgt = Group(GroupID(self, obj_json)) + tgt = Group(GroupID(self, obj_json), track_order=track_order) elif collection_type == 'datatypes': tgt = Datatype(TypeID(self, obj_json)) elif collection_type == 'datasets': @@ -598,13 +598,13 @@ def getObjByUuid(self, uuid, collection_type=None): if "dims" in shape_json and len(shape_json["dims"]) == 1 and dtype_json["class"] == 'H5T_COMPOUND': tgt = Table(DatasetID(self, obj_json)) else: - tgt = Dataset(DatasetID(self, obj_json)) + tgt = Dataset(DatasetID(self, obj_json), track_order=track_order) else: raise IOError(f"Unexpected collection_type: {collection_type}") return tgt - def __getitem__(self, name): + def __getitem__(self, name, track_order=False): """ Open an object in the file """ # convert bytes to str for PY3 if isinstance(name, bytes): @@ -617,11 +617,11 @@ def __getitem__(self, name): if tgt is not None: return tgt # ref'd object has not been deleted if isinstance(name.id, GroupID): - tgt = self.getObjByUuid(name.id.uuid, collection_type="groups") + tgt = self.getObjByUuid(name.id.uuid, collection_type="groups", track_order=track_order) elif isinstance(name.id, DatasetID): - tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets") + tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets", track_order=track_order) elif isinstance(name.id, TypeID): - tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets") + tgt = self.getObjByUuid(name.id.uuid, collection_type="datasets", track_order=track_order) else: raise IOError("Unexpected Error - ObjectID type: " + name.__class__.__name__) return tgt @@ -634,11 +634,11 @@ def __getitem__(self, name): link_class = link_json['class'] if link_class == 'H5L_TYPE_HARD': - tgt = self.getObjByUuid(link_json['id'], collection_type=link_json['collection']) + tgt = self.getObjByUuid(link_json['id'], collection_type=link_json['collection'], track_order=track_order) elif link_class == 'H5L_TYPE_SOFT': h5path = link_json['h5path'] soft_parent_uuid, soft_json = self._get_link_json(h5path) - tgt = self.getObjByUuid(soft_json['id'], collection_type=soft_json['collection']) + tgt = self.getObjByUuid(soft_json['id'], collection_type=soft_json['collection'], track_order=track_order) elif link_class == 'H5L_TYPE_EXTERNAL': # try to get a handle to the file and return the linked object... @@ -654,7 +654,8 @@ def __getitem__(self, name): endpoint = self.id.http_conn.endpoint username = self.id.http_conn.username password = self.id.http_conn.password - f = File(external_domain, endpoint=endpoint, username=username, password=password, mode='r') + f = File(external_domain, endpoint=endpoint, username=username, password=password, mode='r', + track_order=track_order) except IOError: # unable to find external link raise KeyError("Unable to open file: " + link_json['h5domain']) @@ -678,7 +679,7 @@ def __getitem__(self, name): tgt._name = name return tgt - def get(self, name, default=None, getclass=False, getlink=False): + def get(self, name, default=None, getclass=False, getlink=False, track_order=False): """ Retrieve an item or other information. "name" given only: @@ -702,10 +703,9 @@ def get(self, name, default=None, getclass=False, getlink=False): >>> if cls == SoftLink: ... print '"foo" is a soft link!' """ - if not (getclass or getlink): try: - return self[name] + return self.__getitem__(name, track_order) except KeyError: return default @@ -713,7 +713,7 @@ def get(self, name, default=None, getclass=False, getlink=False): return default elif getclass and not getlink: - obj = self.__getitem__(name) + obj = self.__getitem__(name, track_order) if obj is None: return None if obj.id.__class__ is GroupID: @@ -891,7 +891,16 @@ def __iter__(self): for x in links: yield x['title'] else: - for name in links: + if self._track_order: + links = sorted(links.items(), key=lambda x: x[1]['created']) + else: + links = sorted(links.items()) + + ordered_links = {} + for link in links: + ordered_links[link[0]] = link[1] + + for name in ordered_links: yield name def __contains__(self, name): @@ -1151,14 +1160,23 @@ def __reversed__(self): # reset the link cache self._link_db = {} - for link in reversed(links): + for link in links: name = link["title"] self._link_db[name] = link for x in reversed(links): yield x['title'] else: - for name in links: + if self._track_order: + links = sorted(links.items(), key=lambda x: x[1]['created']) + else: + links = sorted(links.items()) + + ordered_links = {} + for link in links: + ordered_links[link[0]] = link[1] + + for name in reversed(ordered_links): yield name diff --git a/test/hl/test_group.py b/test/hl/test_group.py index e2a3be8..6adebc6 100644 --- a/test/hl/test_group.py +++ b/test/hl/test_group.py @@ -101,6 +101,8 @@ def test_create(self): self.assertTrue(False) # shouldn't get here' except RuntimeError: pass # expected + except OSError: + pass # also acceptable del r['tmp'] self.assertEqual(len(r), 4) @@ -338,19 +340,27 @@ def test_no_track_order(self): self.assertEqual(list(reversed(g)), list(reversed(ref))) def test_get_dataset_track_order(self): + + # h5py does not support track_order on group.get() + if config.get("use_h5py"): + return + filename = self.getFileName("test_get_dataset_track_order") print(f"filename: {filename}") self.f = h5py.File(filename, 'w') g = self.f.create_group('order') - # create dataset, close file, re-open and get dataset + dset = g.create_dataset('dset', (10,), dtype='i4') dset2 = g.create_dataset('dset2', (10,), dtype='i4') + self.populate_attrs(dset) + self.populate_attrs(dset2) + self.f.close() - f = h5py.File(filename, 'r') - g = f['order'] - d = g.get('dset', track_order=True) + self.f = h5py.File(filename, 'r') + g = self.f['order'] + d = g.get('dset', track_order=True) ref = [str(i) for i in range(100)] self.assertEqual(list(d.attrs), ref) self.assertEqual(list(reversed(d.attrs)), list(reversed(ref))) @@ -360,6 +370,38 @@ def test_get_dataset_track_order(self): self.assertEqual(list(d2.attrs), ref) self.assertEqual(list(reversed(d2.attrs)), list(reversed(ref))) + def test_get_group_track_order(self): + + # h5py does not support track_order on group.get() + if config.get("use_h5py"): + return + filename = self.getFileName("test_get_group_track_order") + print(f"filename: {filename}") + self.f = h5py.File(filename, 'w') + g = self.f.create_group('order') + + # create subgroup and populate it with links + g.create_group('subgroup') + self.populate(g['subgroup']) + + self.f.close() + self.f = h5py.File(filename, 'r') + g = self.f['order'] + + subg = g.get('subgroup', track_order=True) + ref = [str(i) for i in range(100)] + self.assertEqual(list(subg), ref) + self.assertEqual(list(reversed(subg)), list(reversed(ref))) + + self.f.close() + self.f = h5py.File(filename, 'r') + g = self.f['order'] + subg2 = g.get('subgroup', track_order=False) + ref = sorted([str(i) for i in range(100)]) + self.assertEqual(list(subg2), ref) + self.assertEqual(list(reversed(subg2)), list(reversed(ref))) + + if __name__ == '__main__': loglevel = logging.ERROR logging.basicConfig(format='%(asctime)s %(message)s', level=loglevel)