[WIP] Add fo.SampleReference as a way to reference other samples in other datasets #5277

Draft · wants to merge 22 commits into develop
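
A rough usage sketch of the API this draft appears to introduce, pieced together from the diff below. The `reference` kwarg to `Dataset` and the `fo.SampleReference` export are in the diff; the exact `SampleReference` constructor and the `add_sample()` flow are assumptions, since the draft ships no docs or tests:

import fiftyone as fo

# The dataset whose samples we want to reference (assumed to already exist)
src = fo.load_dataset("source-dataset")

# Create a dataset backed by `src`; this PR adds the `reference` kwarg to
# `Dataset.__init__()` and merges the referenced dataset's field schema
# (read-only) into the new dataset's schema
dataset = fo.Dataset("references", reference=src)

# Wrap an existing sample in a reference; per the ODM code below, the sample
# must already belong to a dataset
ref = fo.SampleReference(src.first())  # hypothetical constructor signature
dataset.add_sample(ref)  # hypothetical; add_sample() is untouched by this PR

# Referenced fields are readable through the reference, but setting one
# raises, since fields surfaced from `src` are read-only
print(ref.get_field("filepath"))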
2 changes: 1 addition & 1 deletion fiftyone/__public__.py
@@ -163,7 +163,7 @@
    Run,
    RunResults,
)
-from .core.sample import Sample
+from .core.sample import Sample, SampleReference
from .core.threed import (
    BoxGeometry,
    CylinderGeometry,
89 changes: 29 additions & 60 deletions fiftyone/core/dataset.py
@@ -37,6 +37,7 @@
import fiftyone as fo
import fiftyone.constants as focn
import fiftyone.core.collections as foc
+from fiftyone.core.dataset_helpers import _create_frame_document_cls, _create_sample_document_cls, _set_field_read_only
import fiftyone.core.expressions as foe
from fiftyone.core.expressions import ViewField as F
import fiftyone.core.fields as fof
@@ -306,12 +307,14 @@ class Dataset(foc.SampleCollection, metaclass=DatasetSingleton):
"_evaluation_cache",
"_run_cache",
"_deleted",
"_reference",
)

    def __init__(
        self,
        name=None,
        persistent=False,
+        reference=None,
        overwrite=False,
        _create=True,
        _virtual=False,
@@ -325,17 +328,19 @@ def __init__(

        if _create:
            doc, sample_doc_cls, frame_doc_cls = _create_dataset(
-                self, name, persistent=persistent, **kwargs
+                self, name, reference, persistent=persistent, **kwargs
            )
        else:
            doc, sample_doc_cls, frame_doc_cls = _load_dataset(
-                self, name, virtual=_virtual
+                self, name, reference, virtual=_virtual
            )

        self._doc = doc
        self._sample_doc_cls = sample_doc_cls
        self._frame_doc_cls = frame_doc_cls

+        self._reference = reference

        self._group_slice = doc.default_group_slice

        self._annotation_cache = cachetools.LRUCache(5)
@@ -1384,7 +1389,7 @@ def get_field_schema(
            a dict mapping field names to :class:`fiftyone.core.fields.Field`
            instances
        """
-        return self._sample_doc_cls.get_field_schema(
+        base_schema = self._sample_doc_cls.get_field_schema(
            ftype=ftype,
            embedded_doc_type=embedded_doc_type,
            read_only=read_only,
@@ -1395,6 +1400,18 @@
            mode=mode,
        )

+        if self._reference:
+            reference_schema = self._reference.get_field_schema()

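+            # These fields describe the reference document itself, not the
+            # referenced samples, so they are not merged from the referenced
+            # dataset's schema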
del reference_schema["id"]
del reference_schema["created_at"]
del reference_schema["last_modified_at"]

base_schema.update(reference_schema)

return base_schema


def get_frame_field_schema(
self,
ftype=None,
@@ -7888,6 +7905,8 @@ def _expand_frame_schema(self, frames, dynamic):

    def _make_sample(self, d):
        doc = self._sample_dict_to_doc(d)
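+        # Datasets created with a `reference` wrap their documents as
+        # SampleReference rather than Sample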
+        if self._reference:
+            return fos.SampleReference.from_doc(doc, dataset=self)
        return fos.Sample.from_doc(doc, dataset=self)

    def _sample_dict_to_doc(self, d):
@@ -8020,7 +8039,7 @@ def _reload(self, hard=False):
            return

        doc, sample_doc_cls, frame_doc_cls = _load_dataset(
-            self, self.name, virtual=True
+            self, self.name, self._reference, virtual=True
        )

        new_media_type = doc.media_type != self.media_type
@@ -8116,6 +8135,7 @@ def _list_datasets_query(include_private=False, glob_patt=None, tags=None):
def _create_dataset(
    obj,
    name,
+    reference,
    persistent=False,
    _patches=False,
    _frames=False,
@@ -8130,7 +8150,7 @@ def _create_dataset(
    sample_collection_name = _make_sample_collection_name(
        _id, patches=_patches, frames=_frames, clips=_clips
    )
-    sample_doc_cls = _create_sample_document_cls(obj, sample_collection_name)
+    sample_doc_cls = _create_sample_document_cls(obj, sample_collection_name, reference)

    # pylint: disable=no-member
    sample_fields = [
@@ -8308,50 +8328,6 @@ def _make_frame_collection_name(sample_collection_name):
return "frames." + sample_collection_name


-def _create_sample_document_cls(
-    dataset, sample_collection_name, field_docs=None
-):
-    cls = type(sample_collection_name, (foo.DatasetSampleDocument,), {})
-    cls._dataset = dataset
-
-    _declare_fields(dataset, cls, field_docs=field_docs)
-    return cls
-
-
-def _create_frame_document_cls(
-    dataset, frame_collection_name, field_docs=None
-):
-    cls = type(frame_collection_name, (foo.DatasetFrameDocument,), {})
-    cls._dataset = dataset
-
-    _declare_fields(dataset, cls, field_docs=field_docs)
-    return cls
-
-
-def _declare_fields(dataset, doc_cls, field_docs=None):
-    default_fields = set(doc_cls._fields.keys())
-    if field_docs is not None:
-        default_fields -= {field_doc.name for field_doc in field_docs}
-
-    # Declare default fields that don't already exist
-    now = datetime.utcnow()
-    for field_name in default_fields:
-        field = doc_cls._fields[field_name]
-
-        if isinstance(field, fof.EmbeddedDocumentField):
-            field = foo.create_field(field_name, **foo.get_field_kwargs(field))
-        else:
-            field = field.copy()
-
-        field._set_created_at(now)
-        doc_cls._declare_field(dataset, field_name, field)
-
-    # Declare existing fields
-    if field_docs is not None:
-        for field_doc in field_docs:
-            doc_cls._declare_field(dataset, field_doc.name, field_doc)


def _load_clips_source_dataset(frame_collection_name):
    # All clips datasets have a source dataset with the same frame collection
    query = {
@@ -8369,12 +8345,12 @@
return load_dataset(doc["name"])


-def _load_dataset(obj, name, virtual=False):
+def _load_dataset(obj, name, reference, virtual=False):
    if not virtual:
        fomi.migrate_dataset_if_necessary(name)

    try:
-        return _do_load_dataset(obj, name)
+        return _do_load_dataset(obj, name, reference)
    except Exception as e:
        try:
            version = fomi.get_dataset_revision(name)
@@ -8391,7 +8367,7 @@ def _load_dataset(obj, name, virtual=False):
        raise e


-def _do_load_dataset(obj, name):
+def _do_load_dataset(obj, name, reference):
    # pylint: disable=no-member
    db = foo.get_db_conn()
    res = db.datasets.find_one({"name": name})
@@ -8403,7 +8379,7 @@
    frame_collection_name = dataset_doc.frame_collection_name

    sample_doc_cls = _create_sample_document_cls(
-        obj, sample_collection_name, field_docs=dataset_doc.sample_fields
+        obj, sample_collection_name, reference=reference, field_docs=dataset_doc.sample_fields
    )

    if sample_collection_name.startswith("clips."):
@@ -10167,13 +10143,6 @@ def _handle_nested_fields(schema):
    return safe_schemas


-def _set_field_read_only(field_doc, read_only):
-    field_doc.read_only = read_only
-    if hasattr(field_doc, "fields"):
-        for _field_doc in field_doc.fields:
-            _set_field_read_only(_field_doc, read_only)


def _extract_archive_if_necessary(archive_path, cleanup):
    dataset_dir = etau.split_archive(archive_path)[0]

75 changes: 75 additions & 0 deletions fiftyone/core/dataset_helpers.py
@@ -0,0 +1,75 @@
from datetime import datetime

import fiftyone.core.fields as fof
import fiftyone.core.odm as foo


def _set_field_read_only(field_doc, read_only):
    field_doc.read_only = read_only
    if hasattr(field_doc, "fields"):
        for _field_doc in field_doc.fields:
            _set_field_read_only(_field_doc, read_only)


def _create_sample_document_cls(
    dataset, sample_collection_name, reference, field_docs=None
):
    if reference:
        cls = type(
            sample_collection_name, (foo.DatasetSampleReferenceDocument,), {}
        )
        cls._dataset = dataset

        name = reference.name

        db = foo.get_db_conn()
        res = db.datasets.find_one({"name": name})
        if not res:
            raise ValueError("Dataset '%s' not found" % name)

        dataset_doc = foo.DatasetDocument.from_dict(res)

        sample_collection_name = dataset_doc.sample_collection_name

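        # Fields surfaced from the referenced dataset are read-only through
        # the reference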
        for d in dataset_doc.sample_fields:
            _set_field_read_only(d, True)

        cls._sample_id.document_type_obj = _create_sample_document_cls(
            reference,
            reference._sample_collection_name,
            None,
            field_docs=dataset_doc.sample_fields,
        )
    else:
        cls = type(sample_collection_name, (foo.DatasetSampleDocument,), {})
        cls._dataset = dataset
        _declare_fields(dataset, cls, field_docs=field_docs)

    return cls


def _create_frame_document_cls(
    dataset, frame_collection_name, field_docs=None
):
    cls = type(frame_collection_name, (foo.DatasetFrameDocument,), {})
    cls._dataset = dataset

    _declare_fields(dataset, cls, field_docs=field_docs)
    return cls


def _declare_fields(dataset, doc_cls, field_docs=None):
    default_fields = set(doc_cls._fields.keys())
    if field_docs is not None:
        default_fields -= {field_doc.name for field_doc in field_docs}

    # Declare default fields that don't already exist
    now = datetime.utcnow()
    for field_name in default_fields:
        field = doc_cls._fields[field_name]

        if isinstance(field, fof.EmbeddedDocumentField):
            field = foo.create_field(field_name, **foo.get_field_kwargs(field))
        else:
            field = field.copy()

        field._set_created_at(now)
        doc_cls._declare_field(dataset, field_name, field)

    # Declare existing fields
    if field_docs is not None:
        for field_doc in field_docs:
            doc_cls._declare_field(dataset, field_doc.name, field_doc)
4 changes: 4 additions & 0 deletions fiftyone/core/odm/__init__.py
@@ -75,6 +75,10 @@
    DatasetSampleDocument,
    NoDatasetSampleDocument,
)
+from .sample_reference import (
+    DatasetSampleReferenceDocument,
+    NoDatasetSampleReferenceDocument,
+)
from .utils import (
    serialize_value,
    deserialize_value,
1 change: 1 addition & 0 deletions fiftyone/core/odm/sample.py
@@ -49,6 +49,7 @@
from collections import OrderedDict
import random


import fiftyone.core.fields as fof
import fiftyone.core.metadata as fom
import fiftyone.core.media as fomm
84 changes: 84 additions & 0 deletions fiftyone/core/odm/sample_reference.py
@@ -0,0 +1,84 @@
from collections import OrderedDict
from bson import DBRef, ObjectId

from .document import Document, SerializableDocument
from .mixins import DatasetMixin, get_default_fields, NoDatasetMixin

import fiftyone.core.fields as fof
from fiftyone.core.odm.sample import DatasetSampleDocument


class DatasetSampleReferenceDocument(DatasetMixin, Document):
    meta = {"abstract": True}

    _is_frames_doc = False

    id = fof.ObjectIdField(required=True, primary_key=True, db_field="_id")
    _sample_id = fof.ReferenceField(DatasetSampleDocument, required=True)

    created_at = fof.DateTimeField(read_only=True)
    last_modified_at = fof.DateTimeField(read_only=True)
    _dataset_id = fof.ObjectIdField()

    @property
    def _sample_reference(self):
        self._sample_id.reload()
        return self._sample_id

    def get_field(self, field_name):
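        # Delegate lookups for everything except the reference's own
        # bookkeeping fields to the referenced sample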
        if field_name not in ("id", "created_at", "last_modified_at"):
            try:
                return self._sample_reference.get_field(field_name)
            except AttributeError:
                pass

        return super().get_field(field_name)

    def set_field(
        self, field_name, value, create=True, validate=True, dynamic=False
    ):
        if field_name in self._sample_reference.field_names:
            raise ValueError(
                "Cannot edit read-only field '%s' of a referenced sample"
                % field_name
            )

        return super().set_field(field_name, value, create, validate, dynamic)


class NoDatasetSampleReferenceDocument(NoDatasetMixin, SerializableDocument):
    _is_frames_doc = False

    # pylint: disable=no-member
    default_fields = DatasetSampleReferenceDocument._fields
    default_fields_ordered = get_default_fields(
        DatasetSampleReferenceDocument, include_private=True
    )

    _sample_reference = None

    def get_field(self, field_name):
        try:
            return self._sample_reference.get_field(field_name)
        except AttributeError:
            pass

        return super().get_field(field_name)

    def set_field(
        self, field_name, value, create=True, validate=True, dynamic=False
    ):
        if field_name in self._sample_reference.field_names:
            raise ValueError(
                "Cannot edit read-only field '%s' of a referenced sample"
                % field_name
            )

        return super().set_field(field_name, value, create, validate, dynamic)

    def __init__(self, sample, **kwargs):
        assert (
            sample.in_dataset
        ), "Sample must already be in a dataset before creating a reference"
        kwargs["id"] = kwargs.get("id", None)
        kwargs["media_type"] = sample.media_type
kwargs["_sample_id"] = DBRef(sample._doc.collection_name, ObjectId(sample.id))

self._sample_reference = sample

self._data = OrderedDict()

for field_name in self.default_fields_ordered:
value = kwargs.pop(field_name, None)

if value is None and field_name not in ("id", "_dataset_id", "_sample_id"):
value = self._get_default(self.default_fields[field_name])

self._data[field_name] = value

self._data.update(kwargs)