refactor: migrate person repository to kysely #15242

Merged (6 commits, Jan 21, 2025)
4 changes: 2 additions & 2 deletions e2e/src/api/specs/person.e2e-spec.ts
```diff
@@ -200,7 +200,7 @@ describe('/people', () => {
       expect(body).toMatchObject({
         id: expect.any(String),
         name: 'New Person',
-        birthDate: '1990-01-01',
+        birthDate: '1990-01-01T00:00:00.000Z',
       });
     });
   });
@@ -244,7 +244,7 @@ describe('/people', () => {
         .set('Authorization', `Bearer ${admin.accessToken}`)
         .send({ birthDate: '1990-01-01' });
       expect(status).toBe(200);
-      expect(body).toMatchObject({ birthDate: '1990-01-01' });
+      expect(body).toMatchObject({ birthDate: '1990-01-01T00:00:00.000Z' });
     });

     it('should clear a date of birth', async () => {
```
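The updated expectation is consistent with the repository now returning `birthDate` as a JavaScript `Date` rather than a pre-formatted date string: `JSON.stringify` renders a `Date` via `toISOString`, which always emits the full timestamp. A minimal sketch of that serialization behavior:

```typescript
// Why '1990-01-01' became '1990-01-01T00:00:00.000Z': a date-only ISO string
// parses to UTC midnight, and JSON.stringify emits the full ISO-8601 form.
const body = { birthDate: new Date('1990-01-01') };
console.log(JSON.stringify(body)); // {"birthDate":"1990-01-01T00:00:00.000Z"}
```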
6 changes: 3 additions & 3 deletions machine-learning/app/models/clip/textual.py
```diff
@@ -10,17 +10,17 @@

 from app.config import log
 from app.models.base import InferenceModel
-from app.models.transforms import clean_text
+from app.models.transforms import clean_text, serialize_np_array
 from app.schemas import ModelSession, ModelTask, ModelType


 class BaseCLIPTextualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

-    def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
+    def _predict(self, inputs: str, **kwargs: Any) -> str:
         res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
-        return res
+        return serialize_np_array(res)

     def _load(self) -> ModelSession:
         session = super()._load()
```
14 changes: 11 additions & 3 deletions machine-learning/app/models/clip/visual.py
```diff
@@ -10,18 +10,26 @@

 from app.config import log
 from app.models.base import InferenceModel
-from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
+from app.models.transforms import (
+    crop_pil,
+    decode_pil,
+    get_pil_resampling,
+    normalize,
+    resize_pil,
+    serialize_np_array,
+    to_numpy,
+)
 from app.schemas import ModelSession, ModelTask, ModelType


 class BaseCLIPVisualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.VISUAL, ModelTask.SEARCH)

-    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
+    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> str:
         image = decode_pil(inputs)
         res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
-        return res
+        return serialize_np_array(res)

     @abstractmethod
     def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
```
machine-learning/app/models/facial_recognition/recognition.py
```diff
@@ -12,7 +12,7 @@

 from app.config import log, settings
 from app.models.base import InferenceModel
-from app.models.transforms import decode_cv2
+from app.models.transforms import decode_cv2, serialize_np_array
 from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType

@@ -61,7 +61,7 @@ def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32
         return [
             {
                 "boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
-                "embedding": embedding,
+                "embedding": serialize_np_array(embedding),
                 "score": score,
             }
             for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
```
7 changes: 7 additions & 0 deletions machine-learning/app/models/transforms.py
```diff
@@ -4,6 +4,7 @@

 import cv2
 import numpy as np
+import orjson
 from numpy.typing import NDArray
 from PIL import Image

@@ -69,3 +70,9 @@ def clean_text(text: str, canonicalize: bool = False) -> str:
     if canonicalize:
         text = text.translate(_PUNCTUATION_TRANS).lower()
     return text
+
+
+# this allows the client to use the array as a string without deserializing only to serialize back to a string
+# TODO: use this in a less invasive way
+def serialize_np_array(arr: NDArray[np.float32]) -> str:
+    return orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY).decode()
```
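The comment above captures the point of the change: the embedding leaves the model server already in the bracketed text form that pgvector accepts, so the consuming service can pass it through without a decode/re-encode round trip. A hedged sketch of that pass-through on the Node side (the table and column names follow the entities in this PR, but the query itself is illustrative, not code from this diff):

```typescript
import { Kysely, sql } from 'kysely';

// Illustrative pass-through: `embedding` is the JSON string produced by
// serialize_np_array, e.g. '[0.12,0.34,...]', which is also valid pgvector
// input text, so no JSON.parse/re-stringify is needed along the way.
async function upsertClipEmbedding(db: Kysely<any>, assetId: string, embedding: string) {
  await sql`
    insert into smart_search ("assetId", embedding)
    values (${assetId}, ${embedding})
    on conflict ("assetId") do update set embedding = ${embedding}
  `.execute(db);
}
```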
2 changes: 1 addition & 1 deletion machine-learning/app/schemas.py
```diff
@@ -79,7 +79,7 @@ class FaceDetectionOutput(TypedDict):

 class DetectedFace(TypedDict):
     boundingBox: BoundingBox
-    embedding: npt.NDArray[np.float32]
+    embedding: str
     score: float

```
45 changes: 28 additions & 17 deletions machine-learning/app/test_main.py
```diff
@@ -10,6 +10,7 @@

 import cv2
 import numpy as np
 import onnxruntime as ort
+import orjson
 import pytest
 from fastapi import HTTPException
 from fastapi.testclient import TestClient
@@ -346,11 +347,11 @@ def test_basic_image(
         mocked.run.return_value = [[self.embedding]]

         clip_encoder = OpenClipVisualEncoder("ViT-B-32__openai", cache_dir="test_cache")
-        embedding = clip_encoder.predict(pil_image)
-
-        assert isinstance(embedding, np.ndarray)
-        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
-        assert embedding.dtype == np.float32
+        embedding_str = clip_encoder.predict(pil_image)
+        assert isinstance(embedding_str, str)
+        embedding = orjson.loads(embedding_str)
+        assert isinstance(embedding, list)
+        assert len(embedding) == clip_model_cfg["embed_dim"]
         mocked.run.assert_called_once()

     def test_basic_text(
@@ -368,11 +369,11 @@ def test_basic_text(
         mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True)

         clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
-        embedding = clip_encoder.predict("test search query")
-
-        assert isinstance(embedding, np.ndarray)
-        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
-        assert embedding.dtype == np.float32
+        embedding_str = clip_encoder.predict("test search query")
+        assert isinstance(embedding_str, str)
+        embedding = orjson.loads(embedding_str)
+        assert isinstance(embedding, list)
+        assert len(embedding) == clip_model_cfg["embed_dim"]
         mocked.run.assert_called_once()

     def test_openclip_tokenizer(
@@ -508,8 +509,11 @@ def test_recognition(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
             assert isinstance(face.get("boundingBox"), dict)
             assert set(face["boundingBox"]) == {"x1", "y1", "x2", "y2"}
             assert all(isinstance(val, np.float32) for val in face["boundingBox"].values())
-            assert isinstance(face.get("embedding"), np.ndarray)
-            assert face["embedding"].shape[0] == 512
+            embedding_str = face.get("embedding")
+            assert isinstance(embedding_str, str)
+            embedding = orjson.loads(embedding_str)
+            assert isinstance(embedding, list)
+            assert len(embedding) == 512
             assert isinstance(face.get("score", None), np.float32)

         rec_model.get_feat.assert_called_once()
@@ -880,8 +884,10 @@ def test_clip_image_endpoint(
         actual = response.json()
         assert response.status_code == 200
         assert isinstance(actual, dict)
-        assert isinstance(actual.get("clip", None), list)
-        assert np.allclose(expected, actual["clip"])
+        embedding = actual.get("clip", None)
+        assert isinstance(embedding, str)
+        parsed_embedding = orjson.loads(embedding)
+        assert np.allclose(expected, parsed_embedding)

     def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
         expected = responses["clip"]["text"]
@@ -901,8 +907,10 @@ def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
         actual = response.json()
         assert response.status_code == 200
         assert isinstance(actual, dict)
-        assert isinstance(actual.get("clip", None), list)
-        assert np.allclose(expected, actual["clip"])
+        embedding = actual.get("clip", None)
+        assert isinstance(embedding, str)
+        parsed_embedding = orjson.loads(embedding)
+        assert np.allclose(expected, parsed_embedding)

     def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
         byte_image = BytesIO()
@@ -933,5 +941,8 @@ def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any],

         for expected_face, actual_face in zip(responses["facial-recognition"], actual["facial-recognition"]):
             assert expected_face["boundingBox"] == actual_face["boundingBox"]
-            assert np.allclose(expected_face["embedding"], actual_face["embedding"])
+            embedding = actual_face.get("embedding", None)
+            assert isinstance(embedding, str)
+            parsed_embedding = orjson.loads(embedding)
+            assert np.allclose(expected_face["embedding"], parsed_embedding)
             assert np.allclose(expected_face["score"], actual_face["score"])
```
1 change: 1 addition & 0 deletions server/src/decorators.ts
```diff
@@ -100,6 +100,7 @@ export const DummyValue = {
   DATE: new Date(),
   TIME_BUCKET: '2024-01-01T00:00:00.000Z',
   BOOLEAN: true,
+  VECTOR: '[1, 2, 3]',
 };

 export const GENERATE_SQL_KEY = 'generate-sql-key';
```
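`DummyValue` entries feed the SQL-generation tooling, so the new `VECTOR` value gives string-typed embedding parameters a plausible literal in the generated example SQL. A hedged sketch of how such a value is typically consumed, assuming the `GenerateSql` decorator exported alongside `DummyValue` in this file; the repository method itself is hypothetical:

```typescript
// Hypothetical repository method: the sql-generation tooling invokes it with
// DummyValue.VECTOR so the emitted example SQL shows '[1, 2, 3]' in the
// embedding position instead of an opaque placeholder.
class ExampleSearchRepository {
  @GenerateSql({ params: [DummyValue.VECTOR] })
  async searchByEmbedding(embedding: string): Promise<void> {
    // ... run the pgvector query with `embedding` ...
  }
}
```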
8 changes: 2 additions & 6 deletions server/src/entities/face-search.entity.ts
```diff
@@ -11,10 +11,6 @@ export class FaceSearchEntity {
   faceId!: string;

   @Index('face_index', { synchronize: false })
-  @Column({
-    type: 'float4',
-    array: true,
-    transformer: { from: JSON.parse, to: (v) => `[${v}]` },
-  })
-  embedding!: number[];
+  @Column({ type: 'float4', array: true })
+  embedding!: string;
 }
```
4 changes: 2 additions & 2 deletions server/src/entities/smart-search.entity.ts
```diff
@@ -11,6 +11,6 @@ export class SmartSearchEntity {
   assetId!: string;

   @Index('clip_index', { synchronize: false })
-  @Column({ type: 'float4', array: true, transformer: { from: JSON.parse, to: (v) => v } })
-  embedding!: number[];
+  @Column({ type: 'float4', array: true })
+  embedding!: string;
 }
```
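With the transformers gone, both entities hold the embedding in the same bracketed text form it had on arrival from the model server; nothing parses it until a caller actually needs numbers. A minimal sketch of that on-demand parse (the helper name is illustrative):

```typescript
// The stored form '[0.1,0.2,0.3]' is valid JSON as well as valid pgvector
// input text, so parsing is a one-liner at whichever edge needs number[].
function embeddingToNumbers(embedding: string): number[] {
  return JSON.parse(embedding) as number[];
}

embeddingToNumbers('[0.1, 0.2, 0.3]'); // => [0.1, 0.2, 0.3]
```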
10 changes: 5 additions & 5 deletions server/src/interfaces/machine-learning.interface.ts
```diff
@@ -28,10 +28,10 @@ export type FaceDetectionOptions = ModelOptions & { minScore: number };

 type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
-export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
+export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;

 export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
-export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
+export type ClipTextualResponse = { [ModelTask.SEARCH]: string };

 export type FacialRecognitionRequest = {
   [ModelTask.FACIAL_RECOGNITION]: {
@@ -42,7 +42,7 @@ export type FacialRecognitionRequest = {

 export interface Face {
   boundingBox: BoundingBox;
-  embedding: number[];
+  embedding: string;
   score: number;
 }

@@ -51,7 +51,7 @@ export type DetectedFaces = { faces: Face[] } & VisualResponse;

 export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;

 export interface IMachineLearningRepository {
-  encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<number[]>;
-  encodeText(urls: string[], text: string, config: ModelOptions): Promise<number[]>;
+  encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<string>;
+  encodeText(urls: string[], text: string, config: ModelOptions): Promise<string>;
   detectFaces(urls: string[], imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
 }
```
25 changes: 14 additions & 11 deletions server/src/interfaces/person.interface.ts
```diff
@@ -1,9 +1,10 @@
+import { Insertable, Updateable } from 'kysely';
+import { AssetFaces, FaceSearch, Person } from 'src/db';
 import { AssetFaceEntity } from 'src/entities/asset-face.entity';
 import { FaceSearchEntity } from 'src/entities/face-search.entity';
 import { PersonEntity } from 'src/entities/person.entity';
 import { SourceType } from 'src/enum';
 import { Paginated, PaginationOptions } from 'src/utils/pagination';
-import { FindManyOptions, FindOptionsRelations, FindOptionsSelect } from 'typeorm';
+import { FindOptionsRelations } from 'typeorm';

 export const IPersonRepository = 'IPersonRepository';

@@ -48,29 +49,31 @@ export interface DeleteFacesOptions {

 export type UnassignFacesOptions = DeleteFacesOptions;

+export type SelectFaceOptions = Partial<{ [K in keyof AssetFaceEntity]: boolean }>;
+
 export interface IPersonRepository {
-  getAll(pagination: PaginationOptions, options?: FindManyOptions<PersonEntity>): Paginated<PersonEntity>;
+  getAll(options?: Partial<PersonEntity>): AsyncIterableIterator<PersonEntity>;
   getAllForUser(pagination: PaginationOptions, userId: string, options: PersonSearchOptions): Paginated<PersonEntity>;
   getAllWithoutFaces(): Promise<PersonEntity[]>;
   getById(personId: string): Promise<PersonEntity | null>;
   getByName(userId: string, personName: string, options: PersonNameSearchOptions): Promise<PersonEntity[]>;
   getDistinctNames(userId: string, options: PersonNameSearchOptions): Promise<PersonNameResponse[]>;

-  create(person: Partial<PersonEntity>): Promise<PersonEntity>;
-  createAll(people: Partial<PersonEntity>[]): Promise<string[]>;
+  create(person: Insertable<Person>): Promise<PersonEntity>;
+  createAll(people: Insertable<Person>[]): Promise<string[]>;
   delete(entities: PersonEntity[]): Promise<void>;
   deleteFaces(options: DeleteFacesOptions): Promise<void>;
   refreshFaces(
-    facesToAdd: Partial<AssetFaceEntity>[],
+    facesToAdd: Insertable<AssetFaces>[],
     faceIdsToRemove: string[],
-    embeddingsToAdd?: FaceSearchEntity[],
+    embeddingsToAdd?: Insertable<FaceSearch>[],
   ): Promise<void>;
-  getAllFaces(pagination: PaginationOptions, options?: FindManyOptions<AssetFaceEntity>): Paginated<AssetFaceEntity>;
+  getAllFaces(options?: Partial<AssetFaceEntity>): AsyncIterableIterator<AssetFaceEntity>;
   getFaceById(id: string): Promise<AssetFaceEntity>;
   getFaceByIdWithAssets(
     id: string,
     relations?: FindOptionsRelations<AssetFaceEntity>,
-    select?: FindOptionsSelect<AssetFaceEntity>,
+    select?: SelectFaceOptions,
   ): Promise<AssetFaceEntity | null>;
   getFaces(assetId: string): Promise<AssetFaceEntity[]>;
   getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
@@ -80,7 +83,7 @@ export interface IPersonRepository {
   getNumberOfPeople(userId: string): Promise<PeopleStatistics>;
   reassignFaces(data: UpdateFacesData): Promise<number>;
   unassignFaces(options: UnassignFacesOptions): Promise<void>;
-  update(person: Partial<PersonEntity>): Promise<PersonEntity>;
-  updateAll(people: Partial<PersonEntity>[]): Promise<void>;
+  update(person: Updateable<Person> & { id: string }): Promise<PersonEntity>;
+  updateAll(people: Insertable<Person>[]): Promise<void>;
   getLatestFaceDate(): Promise<string | undefined>;
 }
```
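For readers unfamiliar with kysely's utility types: `Insertable<T>` and `Updateable<T>` derive the column sets allowed on insert and update from the table definition itself (generated columns become optional on insert, everything becomes optional on update), which is what replaces the loose `Partial<PersonEntity>` typings here. A minimal sketch with a hypothetical table type in the style of the generated types in `src/db`:

```typescript
import { Generated, Insertable, Updateable } from 'kysely';

// Hypothetical table definition; the real Person type lives in src/db.
interface PersonTable {
  id: Generated<string>;
  ownerId: string;
  name: string;
  birthDate: Date | null;
}

// id may be omitted on insert (it is Generated); the other columns are required.
const newPerson: Insertable<PersonTable> = { ownerId: 'user-1', name: 'New Person', birthDate: null };

// On update every column is optional, which is why the interface adds
// `& { id: string }` to force callers to identify the row being updated.
const patch: Updateable<PersonTable> & { id: string } = { id: 'person-1', name: 'Renamed' };
```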
6 changes: 3 additions & 3 deletions server/src/interfaces/search.interface.ts
```diff
@@ -104,7 +104,7 @@ export interface SearchExifOptions {
 }

 export interface SearchEmbeddingOptions {
-  embedding: number[];
+  embedding: string;
   userIds: string[];
 }

@@ -152,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {

 export interface AssetDuplicateSearch {
   assetId: string;
-  embedding: number[];
+  embedding: string;
   maxDistance: number;
   type: AssetType;
   userIds: string[];
@@ -192,7 +192,7 @@ export interface ISearchRepository {
   searchDuplicates(options: AssetDuplicateSearch): Promise<AssetDuplicateResult[]>;
   searchFaces(search: FaceEmbeddingSearch): Promise<FaceSearchResult[]>;
   searchRandom(size: number, options: AssetSearchOptions): Promise<AssetEntity[]>;
-  upsert(assetId: string, embedding: number[]): Promise<void>;
+  upsert(assetId: string, embedding: string): Promise<void>;
   searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]>;
   getAssetsByCity(userIds: string[]): Promise<AssetEntity[]>;
   deleteAllSearchEmbeddings(): Promise<void>;
```
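End to end, the string now flows through both updated interfaces unchanged: the ML repository resolves to the serialized embedding, and the search repository accepts it directly. A hedged sketch of the smart-search indexing path (the wiring, URL, and config values are assumptions for illustration, not code from this diff):

```typescript
// Hypothetical composition of the two interfaces above: note there is no
// JSON.parse between the ML call and the vector upsert.
async function indexAsset(
  machineLearning: IMachineLearningRepository,
  search: ISearchRepository,
  assetId: string,
  imagePath: string,
) {
  const embedding = await machineLearning.encodeImage(
    ['http://immich-machine-learning:3003'],
    imagePath,
    { modelName: 'ViT-B-32__openai' },
  );
  await search.upsert(assetId, embedding); // embedding is already '[...]' text
}
```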