Skip to content

Commit

Permalink
refactor: migrate person repository to kysely (immich-app#15242)
Browse files Browse the repository at this point in the history
* refactor: migrate person repository to kysely

* `asVector` begone

* linting

* fix metadata faces

* update test

---------

Co-authored-by: Alex <[email protected]>
Co-authored-by: mertalev <[email protected]>
  • Loading branch information
3 people authored and ExceptionsOccur committed Jan 22, 2025
1 parent ccf703c commit b82630c
Show file tree
Hide file tree
Showing 29 changed files with 707 additions and 739 deletions.
4 changes: 2 additions & 2 deletions e2e/src/api/specs/person.e2e-spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ describe('/people', () => {
expect(body).toMatchObject({
id: expect.any(String),
name: 'New Person',
birthDate: '1990-01-01',
birthDate: '1990-01-01T00:00:00.000Z',
});
});
});
Expand Down Expand Up @@ -244,7 +244,7 @@ describe('/people', () => {
.set('Authorization', `Bearer ${admin.accessToken}`)
.send({ birthDate: '1990-01-01' });
expect(status).toBe(200);
expect(body).toMatchObject({ birthDate: '1990-01-01' });
expect(body).toMatchObject({ birthDate: '1990-01-01T00:00:00.000Z' });
});

it('should clear a date of birth', async () => {
Expand Down
6 changes: 3 additions & 3 deletions machine-learning/app/models/clip/textual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@

from app.config import log
from app.models.base import InferenceModel
from app.models.transforms import clean_text
from app.models.transforms import clean_text, serialize_np_array
from app.schemas import ModelSession, ModelTask, ModelType


class BaseCLIPTextualEncoder(InferenceModel):
depends = []
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

def _predict(self, inputs: str, **kwargs: Any) -> NDArray[np.float32]:
def _predict(self, inputs: str, **kwargs: Any) -> str:
res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
return res
return serialize_np_array(res)

def _load(self) -> ModelSession:
session = super()._load()
Expand Down
14 changes: 11 additions & 3 deletions machine-learning/app/models/clip/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,26 @@

from app.config import log
from app.models.base import InferenceModel
from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
from app.models.transforms import (
crop_pil,
decode_pil,
get_pil_resampling,
normalize,
resize_pil,
serialize_np_array,
to_numpy,
)
from app.schemas import ModelSession, ModelTask, ModelType


class BaseCLIPVisualEncoder(InferenceModel):
depends = []
identity = (ModelType.VISUAL, ModelTask.SEARCH)

def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> str:
image = decode_pil(inputs)
res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
return res
return serialize_np_array(res)

@abstractmethod
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
Expand Down
4 changes: 2 additions & 2 deletions machine-learning/app/models/facial_recognition/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from app.config import log, settings
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.models.transforms import decode_cv2, serialize_np_array
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType


Expand Down Expand Up @@ -61,7 +61,7 @@ def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32
return [
{
"boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
"embedding": embedding,
"embedding": serialize_np_array(embedding),
"score": score,
}
for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
Expand Down
7 changes: 7 additions & 0 deletions machine-learning/app/models/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import cv2
import numpy as np
import orjson
from numpy.typing import NDArray
from PIL import Image

Expand Down Expand Up @@ -69,3 +70,9 @@ def clean_text(text: str, canonicalize: bool = False) -> str:
if canonicalize:
text = text.translate(_PUNCTUATION_TRANS).lower()
return text


# this allows the client to use the array as a string without deserializing only to serialize back to a string
# TODO: use this in a less invasive way
def serialize_np_array(arr: NDArray[np.float32]) -> str:
return orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY).decode()
2 changes: 1 addition & 1 deletion machine-learning/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FaceDetectionOutput(TypedDict):

class DetectedFace(TypedDict):
boundingBox: BoundingBox
embedding: npt.NDArray[np.float32]
embedding: str
score: float


Expand Down
45 changes: 28 additions & 17 deletions machine-learning/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import cv2
import numpy as np
import onnxruntime as ort
import orjson
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -346,11 +347,11 @@ def test_basic_image(
mocked.run.return_value = [[self.embedding]]

clip_encoder = OpenClipVisualEncoder("ViT-B-32__openai", cache_dir="test_cache")
embedding = clip_encoder.predict(pil_image)

assert isinstance(embedding, np.ndarray)
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert embedding.dtype == np.float32
embedding_str = clip_encoder.predict(pil_image)
assert isinstance(embedding_str, str)
embedding = orjson.loads(embedding_str)
assert isinstance(embedding, list)
assert len(embedding) == clip_model_cfg["embed_dim"]
mocked.run.assert_called_once()

def test_basic_text(
Expand All @@ -368,11 +369,11 @@ def test_basic_text(
mocker.patch("app.models.clip.textual.Tokenizer.from_file", autospec=True)

clip_encoder = OpenClipTextualEncoder("ViT-B-32__openai", cache_dir="test_cache")
embedding = clip_encoder.predict("test search query")

assert isinstance(embedding, np.ndarray)
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert embedding.dtype == np.float32
embedding_str = clip_encoder.predict("test search query")
assert isinstance(embedding_str, str)
embedding = orjson.loads(embedding_str)
assert isinstance(embedding, list)
assert len(embedding) == clip_model_cfg["embed_dim"]
mocked.run.assert_called_once()

def test_openclip_tokenizer(
Expand Down Expand Up @@ -508,8 +509,11 @@ def test_recognition(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
assert isinstance(face.get("boundingBox"), dict)
assert set(face["boundingBox"]) == {"x1", "y1", "x2", "y2"}
assert all(isinstance(val, np.float32) for val in face["boundingBox"].values())
assert isinstance(face.get("embedding"), np.ndarray)
assert face["embedding"].shape[0] == 512
embedding_str = face.get("embedding")
assert isinstance(embedding_str, str)
embedding = orjson.loads(embedding_str)
assert isinstance(embedding, list)
assert len(embedding) == 512
assert isinstance(face.get("score", None), np.float32)

rec_model.get_feat.assert_called_once()
Expand Down Expand Up @@ -880,8 +884,10 @@ def test_clip_image_endpoint(
actual = response.json()
assert response.status_code == 200
assert isinstance(actual, dict)
assert isinstance(actual.get("clip", None), list)
assert np.allclose(expected, actual["clip"])
embedding = actual.get("clip", None)
assert isinstance(embedding, str)
parsed_embedding = orjson.loads(embedding)
assert np.allclose(expected, parsed_embedding)

def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
expected = responses["clip"]["text"]
Expand All @@ -901,8 +907,10 @@ def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestC
actual = response.json()
assert response.status_code == 200
assert isinstance(actual, dict)
assert isinstance(actual.get("clip", None), list)
assert np.allclose(expected, actual["clip"])
embedding = actual.get("clip", None)
assert isinstance(embedding, str)
parsed_embedding = orjson.loads(embedding)
assert np.allclose(expected, parsed_embedding)

def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
byte_image = BytesIO()
Expand Down Expand Up @@ -933,5 +941,8 @@ def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any],

for expected_face, actual_face in zip(responses["facial-recognition"], actual["facial-recognition"]):
assert expected_face["boundingBox"] == actual_face["boundingBox"]
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
embedding = actual_face.get("embedding", None)
assert isinstance(embedding, str)
parsed_embedding = orjson.loads(embedding)
assert np.allclose(expected_face["embedding"], parsed_embedding)
assert np.allclose(expected_face["score"], actual_face["score"])
1 change: 1 addition & 0 deletions server/src/decorators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ export const DummyValue = {
DATE: new Date(),
TIME_BUCKET: '2024-01-01T00:00:00.000Z',
BOOLEAN: true,
VECTOR: '[1, 2, 3]',
};

export const GENERATE_SQL_KEY = 'generate-sql-key';
Expand Down
8 changes: 2 additions & 6 deletions server/src/entities/face-search.entity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ export class FaceSearchEntity {
faceId!: string;

@Index('face_index', { synchronize: false })
@Column({
type: 'float4',
array: true,
transformer: { from: JSON.parse, to: (v) => `[${v}]` },
})
embedding!: number[];
@Column({ type: 'float4', array: true })
embedding!: string;
}
4 changes: 2 additions & 2 deletions server/src/entities/smart-search.entity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ export class SmartSearchEntity {
assetId!: string;

@Index('clip_index', { synchronize: false })
@Column({ type: 'float4', array: true, transformer: { from: JSON.parse, to: (v) => v } })
embedding!: number[];
@Column({ type: 'float4', array: true })
embedding!: string;
}
10 changes: 5 additions & 5 deletions server/src/interfaces/machine-learning.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ export type FaceDetectionOptions = ModelOptions & { minScore: number };

type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;

export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
export type ClipTextualResponse = { [ModelTask.SEARCH]: string };

export type FacialRecognitionRequest = {
[ModelTask.FACIAL_RECOGNITION]: {
Expand All @@ -42,7 +42,7 @@ export type FacialRecognitionRequest = {

export interface Face {
boundingBox: BoundingBox;
embedding: number[];
embedding: string;
score: number;
}

Expand All @@ -51,7 +51,7 @@ export type DetectedFaces = { faces: Face[] } & VisualResponse;
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;

export interface IMachineLearningRepository {
encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(urls: string[], text: string, config: ModelOptions): Promise<number[]>;
encodeImage(urls: string[], imagePath: string, config: ModelOptions): Promise<string>;
encodeText(urls: string[], text: string, config: ModelOptions): Promise<string>;
detectFaces(urls: string[], imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
}
25 changes: 14 additions & 11 deletions server/src/interfaces/person.interface.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Insertable, Updateable } from 'kysely';
import { AssetFaces, FaceSearch, Person } from 'src/db';
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { FaceSearchEntity } from 'src/entities/face-search.entity';
import { PersonEntity } from 'src/entities/person.entity';
import { SourceType } from 'src/enum';
import { Paginated, PaginationOptions } from 'src/utils/pagination';
import { FindManyOptions, FindOptionsRelations, FindOptionsSelect } from 'typeorm';
import { FindOptionsRelations } from 'typeorm';

export const IPersonRepository = 'IPersonRepository';

Expand Down Expand Up @@ -48,29 +49,31 @@ export interface DeleteFacesOptions {

export type UnassignFacesOptions = DeleteFacesOptions;

export type SelectFaceOptions = Partial<{ [K in keyof AssetFaceEntity]: boolean }>;

export interface IPersonRepository {
getAll(pagination: PaginationOptions, options?: FindManyOptions<PersonEntity>): Paginated<PersonEntity>;
getAll(options?: Partial<PersonEntity>): AsyncIterableIterator<PersonEntity>;
getAllForUser(pagination: PaginationOptions, userId: string, options: PersonSearchOptions): Paginated<PersonEntity>;
getAllWithoutFaces(): Promise<PersonEntity[]>;
getById(personId: string): Promise<PersonEntity | null>;
getByName(userId: string, personName: string, options: PersonNameSearchOptions): Promise<PersonEntity[]>;
getDistinctNames(userId: string, options: PersonNameSearchOptions): Promise<PersonNameResponse[]>;

create(person: Partial<PersonEntity>): Promise<PersonEntity>;
createAll(people: Partial<PersonEntity>[]): Promise<string[]>;
create(person: Insertable<Person>): Promise<PersonEntity>;
createAll(people: Insertable<Person>[]): Promise<string[]>;
delete(entities: PersonEntity[]): Promise<void>;
deleteFaces(options: DeleteFacesOptions): Promise<void>;
refreshFaces(
facesToAdd: Partial<AssetFaceEntity>[],
facesToAdd: Insertable<AssetFaces>[],
faceIdsToRemove: string[],
embeddingsToAdd?: FaceSearchEntity[],
embeddingsToAdd?: Insertable<FaceSearch>[],
): Promise<void>;
getAllFaces(pagination: PaginationOptions, options?: FindManyOptions<AssetFaceEntity>): Paginated<AssetFaceEntity>;
getAllFaces(options?: Partial<AssetFaceEntity>): AsyncIterableIterator<AssetFaceEntity>;
getFaceById(id: string): Promise<AssetFaceEntity>;
getFaceByIdWithAssets(
id: string,
relations?: FindOptionsRelations<AssetFaceEntity>,
select?: FindOptionsSelect<AssetFaceEntity>,
select?: SelectFaceOptions,
): Promise<AssetFaceEntity | null>;
getFaces(assetId: string): Promise<AssetFaceEntity[]>;
getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
Expand All @@ -80,7 +83,7 @@ export interface IPersonRepository {
getNumberOfPeople(userId: string): Promise<PeopleStatistics>;
reassignFaces(data: UpdateFacesData): Promise<number>;
unassignFaces(options: UnassignFacesOptions): Promise<void>;
update(person: Partial<PersonEntity>): Promise<PersonEntity>;
updateAll(people: Partial<PersonEntity>[]): Promise<void>;
update(person: Updateable<Person> & { id: string }): Promise<PersonEntity>;
updateAll(people: Insertable<Person>[]): Promise<void>;
getLatestFaceDate(): Promise<string | undefined>;
}
6 changes: 3 additions & 3 deletions server/src/interfaces/search.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ export interface SearchExifOptions {
}

export interface SearchEmbeddingOptions {
embedding: number[];
embedding: string;
userIds: string[];
}

Expand Down Expand Up @@ -152,7 +152,7 @@ export interface FaceEmbeddingSearch extends SearchEmbeddingOptions {

export interface AssetDuplicateSearch {
assetId: string;
embedding: number[];
embedding: string;
maxDistance: number;
type: AssetType;
userIds: string[];
Expand Down Expand Up @@ -192,7 +192,7 @@ export interface ISearchRepository {
searchDuplicates(options: AssetDuplicateSearch): Promise<AssetDuplicateResult[]>;
searchFaces(search: FaceEmbeddingSearch): Promise<FaceSearchResult[]>;
searchRandom(size: number, options: AssetSearchOptions): Promise<AssetEntity[]>;
upsert(assetId: string, embedding: number[]): Promise<void>;
upsert(assetId: string, embedding: string): Promise<void>;
searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]>;
getAssetsByCity(userIds: string[]): Promise<AssetEntity[]>;
deleteAllSearchEmbeddings(): Promise<void>;
Expand Down
Loading

0 comments on commit b82630c

Please sign in to comment.