From 2d522c8781f8de9d06c3a6969482ac7d8c5261e3 Mon Sep 17 00:00:00 2001 From: Boris Sekachev Date: Mon, 30 May 2022 10:24:19 +0300 Subject: [PATCH] Prepare UI for attributes configuration (#4) * Prepare UI for attributes configuration * Add padding for label attributes * Update attributes inference logic Check the attributes returned by nuclio function call and reject those that have either incompatible types or values. * Update cvat-ui version, CHANGELOG.md * Enhance automatic annotation BE logic The code in lambda_manager didn't account for attributes mappings that had different names thus returning an empty set of attributes because it couldn't find the correct match. Fix this by getting proper mapping from `attrMapping` property of the input data. * Updated CHANGELOG * Updated changelog * Adjusted code & feature * A bit adjusted layout * Minor refactoring * Fixed bug when run auto annotation without 'attributes' key * Fixed a couple of minor issues * Increased access key id length * Fixed unit tests * Merged develop * Rejected unnecessary change Co-authored-by: Artem Zhivoderov --- CHANGELOG.md | 2 +- cvat-core/src/ml-model.js | 49 ++- cvat-core/src/object-state.js | 5 +- cvat-ui/package-lock.json | 4 +- cvat-ui/package.json | 2 +- .../controls-side-bar/tools-control.tsx | 116 +++++-- .../cloud-storage-form.tsx | 4 +- .../model-runner-modal/detector-runner.tsx | 291 ++++++++++++++---- .../components/model-runner-modal/styles.scss | 6 +- .../models-page/deployed-model-item.tsx | 17 +- .../models-page/deployed-models-list.tsx | 8 +- cvat-ui/src/reducers/interfaces.ts | 7 + cvat/apps/engine/media_extractors.py | 3 +- cvat/apps/lambda_manager/tests/test_lambda.py | 58 ++-- cvat/apps/lambda_manager/views.py | 87 ++++-- 15 files changed, 473 insertions(+), 186 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdeee72bd99c..4525b1108ebd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[2.2.0] - Unreleased ### Added -- TDB +- Support of attributes returned by serverless functions () based on () ### Changed - TDB diff --git a/cvat-core/src/ml-model.js b/cvat-core/src/ml-model.js index b2089a134436..68e29bc4a3ad 100644 --- a/cvat-core/src/ml-model.js +++ b/cvat-core/src/ml-model.js @@ -1,9 +1,9 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT /** - * Class representing a machine learning model + * Class representing a serverless function * @memberof module:API.cvat.classes */ class MLModel { @@ -11,6 +11,7 @@ class MLModel { this._id = data.id; this._name = data.name; this._labels = data.labels; + this._attributes = data.attributes || []; this._framework = data.framework; this._description = data.description; this._type = data.type; @@ -28,7 +29,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get id() { @@ -36,7 +37,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get name() { @@ -44,7 +45,8 @@ class MLModel { } /** - * @returns {string[]} + * @description labels supported by the model + * @type {string[]} * @readonly */ get labels() { @@ -56,7 +58,21 @@ class MLModel { } /** - * @returns {string} + * @typedef ModelAttribute + * @property {string} name + * @property {string[]} values + * @property {'select'|'number'|'checkbox'|'radio'|'text'} input_type + */ + /** + * @type {Object} + * @readonly + */ + get attributes() { + return { ...this._attributes }; + } + + /** + * @type {string} * @readonly */ get framework() { @@ -64,7 +80,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get description() { @@ -72,7 +88,7 @@ class MLModel { } /** - * @returns {module:API.cvat.enums.ModelType} + * @type {module:API.cvat.enums.ModelType} * @readonly */ get type() { @@ -80,7 +96,7 @@ class MLModel { } /** - * @returns {object} + * @type {object} * @readonly */ get params() { @@ -90,10 +106,9 @@ class MLModel { } /** - * @typedef {Object} MlModelTip + * @type {MlModelTip} * @property {string} message A short message for a user about the model - * @property {string} gif A gif URL to be shawn to a user as an example - * @returns {MlModelTip} + * @property {string} gif A gif URL to be shown to a user as an example * @readonly */ get tip() { @@ -101,14 +116,16 @@ class MLModel { } /** - * @callback onRequestStatusChange + * @typedef onRequestStatusChange * @param {string} event * @global - */ + */ /** - * @param {onRequestStatusChange} onRequestStatusChange Set canvas onChangeToolsBlockerState callback + * @param {onRequestStatusChange} onRequestStatusChange + * @instance + * @description Used to set a callback when the tool is blocked in UI * @returns {void} - */ + */ set onChangeToolsBlockerState(onChangeToolsBlockerState) { this._params.canvas.onChangeToolsBlockerState = onChangeToolsBlockerState; } diff --git a/cvat-core/src/object-state.js b/cvat-core/src/object-state.js index d1fc8784908d..5de80ded930a 100644 --- a/cvat-core/src/object-state.js +++ b/cvat-core/src/object-state.js @@ -1,4 +1,4 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -208,7 +208,8 @@ const { Source } = require('./enums'); rotation: { /** * @name rotation - * @type {number} angle measured by degrees + * @description angle measured by degrees + * @type {number} * @memberof module:API.cvat.classes.ObjectState * @throws {module:API.cvat.exceptions.ArgumentError} * @instance diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index 841a7f66de61..66beba8c0311 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "license": "MIT", "dependencies": { "@ant-design/icons": "^4.6.3", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 02caaaceccb9..12c769433a84 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index e2f8563846a9..21541bd8aed1 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -28,7 +28,7 @@ import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper'; import getCore from 'cvat-core-wrapper'; import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper'; import { - CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, + CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute, } from 'reducers/interfaces'; import { interactWithCanvas, @@ -37,9 +37,10 @@ import { updateAnnotationsAsync, createAnnotationsAsync, } from 'actions/annotation-actions'; -import DetectorRunner from 'components/model-runner-modal/detector-runner'; +import DetectorRunner, { DetectorRequestBody } from 'components/model-runner-modal/detector-runner'; import LabelSelector from 'components/label-selector/label-selector'; import CVATTooltip from 'components/common/cvat-tooltip'; +import { Attribute, Label } from 'components/labels-editor/common'; import ApproximationAccuracy, { thresholdFromAccuracy, @@ -374,7 +375,7 @@ export class ToolsControlComponent extends React.PureComponent { } setTimeout(() => this.runInteractionRequest(interactionId)); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Interaction error occured', @@ -466,7 +467,7 @@ export class ToolsControlComponent extends React.PureComponent { // update annotations on a canvas fetchAnnotations(); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Tracking error occured', @@ -706,7 +707,7 @@ export class ToolsControlComponent extends React.PureComponent { Array.prototype.push.apply(statefullContainer.states, serverlessStates); trackingData.statefull[trackerID] = statefullContainer; delete trackingData.stateless[trackerID]; - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracker initialization error', description: error.toString(), @@ -757,7 +758,7 @@ export class ToolsControlComponent extends React.PureComponent { trackedShape.shapePoints = shape; }); } - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracking error', description: error.toString(), @@ -1022,41 +1023,106 @@ export class ToolsControlComponent extends React.PureComponent { }); }); + function checkAttributesCompatibility( + functionAttribute: ModelAttribute | undefined, + dbAttribute: Attribute | undefined, + value: string, + ): boolean { + if (!dbAttribute || !functionAttribute) { + return false; + } + + const { inputType } = (dbAttribute as any as { inputType: string }); + if (functionAttribute.input_type === inputType) { + if (functionAttribute.input_type === 'number') { + const [min, max, step] = dbAttribute.values; + return !Number.isNaN(+value) && +value >= +min && +value <= +max && !(+value % +step); + } + + if (functionAttribute.input_type === 'checkbox') { + return ['true', 'false'].includes(value.toLowerCase()); + } + + if (['select', 'radio'].includes(functionAttribute.input_type)) { + return dbAttribute.values.includes(value); + } + + return true; + } + + switch (functionAttribute.input_type) { + case 'number': + return dbAttribute.values.includes(value) || inputType === 'text'; + case 'text': + return ['select', 'radio'].includes(dbAttribute.input_type) && dbAttribute.values.includes(value); + case 'select': + return (inputType === 'radio' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'radio': + return (inputType === 'select' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'checkbox': + return dbAttribute.values.includes(value) || inputType === 'text'; + default: + return false; + } + } + return ( { + runInference={async (model: Model, body: DetectorRequestBody) => { try { this.setState({ mode: 'detection', fetching: true }); const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame }); const states = result.map( - (data: any): any => new core.classes.ObjectState({ - shapeType: data.type, - label: jobInstance.labels.filter((label: any): boolean => label.name === data.label)[0], - points: data.points, - objectType: ObjectType.SHAPE, - frame, - occluded: false, - source: 'auto', - attributes: (data.attributes as { name: string, value: string }[]) - .reduce((mapping, attr) => { - mapping[attrsMap[data.label][attr.name]] = attr.value; - return mapping; - }, {} as Record), - zOrder: curZOrder, - }), - ); + (data: any): any => { + const jobLabel = (jobInstance.labels as Label[]) + .find((jLabel: Label): boolean => jLabel.name === data.label); + const [modelLabel] = Object.entries(body.mapping) + .find(([, { name }]) => name === data.label) || []; + + if (!jobLabel || !modelLabel) return null; + + return new core.classes.ObjectState({ + shapeType: data.type, + label: jobLabel, + points: data.points, + objectType: ObjectType.SHAPE, + frame, + occluded: false, + source: 'auto', + attributes: (data.attributes as { name: string, value: string }[]) + .reduce((acc, attr) => { + const [modelAttr] = Object.entries(body.mapping[modelLabel].attributes) + .find((value: string[]) => value[1] === attr.name) || []; + const areCompatible = checkAttributesCompatibility( + model.attributes[modelLabel].find((mAttr) => mAttr.name === modelAttr), + jobLabel.attributes.find((jobAttr: Attribute) => ( + jobAttr.name === attr.name + )), + attr.value, + ); + + if (areCompatible) { + acc[attrsMap[data.label][attr.name]] = attr.value; + } + + return acc; + }, {} as Record), + zOrder: curZOrder, + }); + }, + ).filter((state: any) => state); createAnnotations(jobInstance, frame, states); const { onSwitchToolsBlockerState } = this.props; onSwitchToolsBlockerState({ buttonVisible: false }); - } catch (error) { + } catch (error: any) { notification.error({ description: error.toString(), - message: 'Detection error occured', + message: 'Detection error occurred', }); } finally { this.setState({ fetching: false }); diff --git a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx index d4724a238e02..06d2719b7b68 100644 --- a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx +++ b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx @@ -74,7 +74,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { const fakeCredentialsData = { accountName: 'X'.repeat(24), sessionToken: 'X'.repeat(300), - key: 'X'.repeat(20), + key: 'X'.repeat(128), secretKey: 'X'.repeat(40), keyFile: new File([], 'fakeKey.json'), }; @@ -332,7 +332,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { {...internalCommonProps} > setKeyVisibility(true)} onFocus={() => onFocusCredentialsItem('key', 'key')} diff --git a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx index 829328100d37..5f24b9fcba76 100644 --- a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx +++ b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx @@ -14,8 +14,10 @@ import InputNumber from 'antd/lib/input-number'; import Button from 'antd/lib/button'; import notification from 'antd/lib/notification'; -import { Model, StringObject } from 'reducers/interfaces'; +import { Model, ModelAttribute, StringObject } from 'reducers/interfaces'; + import CVATTooltip from 'components/common/cvat-tooltip'; +import { Label as LabelInterface } from 'components/labels-editor/common'; import { clamp } from 'utils/math'; import consts from 'consts'; import { DimensionType } from '../../reducers/interfaces'; @@ -23,28 +25,40 @@ import { DimensionType } from '../../reducers/interfaces'; interface Props { withCleanup: boolean; models: Model[]; - labels: any[]; + labels: LabelInterface[]; dimension: DimensionType; runInference(model: Model, body: object): void; } +interface MappedLabel { + name: string; + attributes: StringObject; +} + +type MappedLabelsList = Record; + +export interface DetectorRequestBody { + mapping: MappedLabelsList; + cleanup: boolean; +} + +interface Match { + model: string | null; + task: string | null; +} + function DetectorRunner(props: Props): JSX.Element { const { models, withCleanup, labels, dimension, runInference, } = props; const [modelID, setModelID] = useState(null); - const [mapping, setMapping] = useState({}); + const [mapping, setMapping] = useState({}); const [threshold, setThreshold] = useState(0.5); const [distance, setDistance] = useState(50); const [cleanup, setCleanup] = useState(false); - const [match, setMatch] = useState<{ - model: string | null; - task: string | null; - }>({ - model: null, - task: null, - }); + const [match, setMatch] = useState({ model: null, task: null }); + const [attrMatches, setAttrMatch] = useState>({}); const model = models.filter((_model): boolean => _model.id === modelID)[0]; const isDetector = model && model.type === 'detector'; @@ -57,24 +71,47 @@ function DetectorRunner(props: Props): JSX.Element { if (model && model.type !== 'reid' && !model.labels.length) { notification.warning({ - message: 'The selected model does not include any lables', + message: 'The selected model does not include any labels', }); } + function matchAttributes( + labelAttributes: LabelInterface['attributes'], + modelAttributes: ModelAttribute[], + ): StringObject { + if (Array.isArray(labelAttributes) && Array.isArray(modelAttributes)) { + return labelAttributes + .reduce((attrAcc: StringObject, attr: any): StringObject => { + if (modelAttributes.some((mAttr) => mAttr.name === attr.name)) { + attrAcc[attr.name] = attr.name; + } + + return attrAcc; + }, {}); + } + + return {}; + } + function updateMatch(modelLabel: string | null, taskLabel: string | null): void { - if (match.model && taskLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[match.model] = taskLabel; - setMapping({ ...mapping, ...newmatch }); + function addMatch(modelLbl: string, taskLbl: string): void { + const newMatch: MappedLabelsList = {}; + const label = labels.find((l) => l.name === taskLbl) as LabelInterface; + const currentModel = models.filter((_model): boolean => _model.id === modelID)[0]; + const attributes = matchAttributes(label.attributes, currentModel.attributes[modelLbl]); + + newMatch[modelLbl] = { name: taskLbl, attributes }; + setMapping({ ...mapping, ...newMatch }); setMatch({ model: null, task: null }); + } + + if (match.model && taskLabel) { + addMatch(match.model, taskLabel); return; } if (match.task && modelLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[modelLabel] = match.task; - setMapping({ ...mapping, ...newmatch }); - setMatch({ model: null, task: null }); + addMatch(modelLabel, match.task); return; } @@ -84,14 +121,72 @@ function DetectorRunner(props: Props): JSX.Element { }); } + function updateAttrMatch(modelLabel: string, modelAttrLabel: string | null, taskAttrLabel: string | null): void { + function addAttributeMatch(modelAttr: string, attrLabel: string): void { + const newMatch: StringObject = {}; + newMatch[modelAttr] = attrLabel; + mapping[modelLabel].attributes = { ...mapping[modelLabel].attributes, ...newMatch }; + + delete attrMatches[modelLabel]; + setAttrMatch({ ...attrMatches }); + } + + const modelAttr = attrMatches[modelLabel]?.model; + if (modelAttr && taskAttrLabel) { + addAttributeMatch(modelAttr, taskAttrLabel); + return; + } + + const taskAttrModel = attrMatches[modelLabel]?.task; + if (taskAttrModel && modelAttrLabel) { + addAttributeMatch(modelAttrLabel, taskAttrModel); + return; + } + + attrMatches[modelLabel] = { + model: modelAttrLabel, + task: taskAttrLabel, + }; + setAttrMatch({ ...attrMatches }); + } + + function renderMappingRow( + color: string, + leftLabel: string, + rightLabel: string, + removalTitle: string, + onClick: () => void, + className = '', + ): JSX.Element { + return ( + + + {leftLabel} + + + {rightLabel} + + + + + + + + ); + } + function renderSelector( value: string, tooltip: string, labelsToRender: string[], onChange: (label: string) => void, + className = '', ): JSX.Element { return ( - + {model.labels.map( (label): JSX.Element => ( diff --git a/cvat-ui/src/components/models-page/deployed-models-list.tsx b/cvat-ui/src/components/models-page/deployed-models-list.tsx index 6db6b881e002..8b49cd66b09b 100644 --- a/cvat-ui/src/components/models-page/deployed-models-list.tsx +++ b/cvat-ui/src/components/models-page/deployed-models-list.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -29,13 +29,13 @@ export default function DeployedModelsListComponent(props: Props): JSX.Element { Name - + Type - + Description - + Labels diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts index 64d2d4661ea0..c760c78a84f1 100644 --- a/cvat-ui/src/reducers/interfaces.ts +++ b/cvat-ui/src/reducers/interfaces.ts @@ -255,10 +255,17 @@ export interface ShareState { root: ShareItem; } +export interface ModelAttribute { + name: string; + values: string[]; + input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; +} + export interface Model { id: string; name: string; labels: string[]; + attributes: Record; framework: string; description: string; type: string; diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 716b9b66dd97..7cf11a323c71 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -95,6 +95,7 @@ def rotate_within_exif(img: Image): ORIENTATION.MIRROR_HORIZONTAL_270_ROTATED ,ORIENTATION.MIRROR_HORIZONTAL_90_ROTATED, ]: img = img.transpose(Image.FLIP_LEFT_RIGHT) + return img class IMediaReader(ABC): @@ -125,8 +126,8 @@ def _get_preview(obj): preview = Image.open(obj) else: preview = obj - preview.thumbnail(PREVIEW_SIZE) preview = rotate_within_exif(preview) + preview.thumbnail(PREVIEW_SIZE) return preview.convert('RGB') diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 4a8699ea32c1..831ff6821e77 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -324,7 +324,7 @@ def test_api_v2_lambda_requests_read(self): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task) @@ -364,7 +364,7 @@ def test_api_v2_lambda_requests_delete_finished_request(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f'{LAMBDA_REQUESTS_PATH}', self.admin, data) @@ -404,7 +404,7 @@ def test_api_v2_lambda_requests_create(self): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -414,7 +414,7 @@ def test_api_v2_lambda_requests_create(self): "quality": "compressed", "max_distance": 70, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -442,7 +442,7 @@ def test_api_v2_lambda_requests_create_non_unique_labels(self, mock_http): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -461,7 +461,7 @@ def test_api_v2_lambda_requests_create_without_function(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -474,7 +474,7 @@ def test_api_v2_lambda_requests_create_wrong_id_function(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -488,7 +488,7 @@ def test_api_v2_lambda_requests_create_two_requests(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -514,7 +514,7 @@ def test_api_v2_lambda_requests_create_without_cleanup(self): "function": id_function_detector, "task": self.main_task["id"], "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -540,7 +540,7 @@ def test_api_v2_lambda_requests_create_without_task(self): "function": id_function_detector, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -553,7 +553,7 @@ def test_api_v2_lambda_requests_create_wrong_id_task(self): "task": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -569,7 +569,7 @@ def test_api_v2_lambda_requests_create_is_not_ready(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -584,7 +584,7 @@ def test_api_v2_lambda_functions_create_detector(self): "cleanup": True, "threshold": 0.55, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -592,7 +592,7 @@ def test_api_v2_lambda_functions_create_detector(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -612,7 +612,7 @@ def test_api_v2_lambda_functions_create_user_assigned_to_no_user(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data) @@ -753,7 +753,7 @@ def test_api_v2_lambda_functions_create_non_type(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -767,7 +767,7 @@ def test_api_v2_lambda_functions_create_wrong_type(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -781,7 +781,7 @@ def test_api_v2_lambda_functions_create_unknown_type(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -796,7 +796,7 @@ def test_api_v2_lambda_functions_create_non_unique_labels(self, mock_http): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -814,7 +814,7 @@ def test_api_v2_lambda_functions_create_quality(self): "cleanup": True, "quality": quality, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -827,7 +827,7 @@ def test_api_v2_lambda_functions_create_quality(self): "cleanup": True, "quality": "test-error-quality", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -857,7 +857,7 @@ def test_api_v2_lambda_functions_create_detector_without_cleanup(self): "task": self.main_task["id"], "frame": 0, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -879,7 +879,7 @@ def test_api_v2_lambda_functions_create_detector_without_task(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -891,7 +891,7 @@ def test_api_v2_lambda_functions_create_detector_without_id_frame(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -904,7 +904,7 @@ def test_api_v2_lambda_functions_create_wrong_id_function(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data) @@ -917,7 +917,7 @@ def test_api_v2_lambda_functions_create_wrong_id_task(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -931,7 +931,7 @@ def test_api_v2_lambda_functions_create_detector_wrong_id_frame(self): "frame": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -945,7 +945,7 @@ def test_api_v2_lambda_functions_create_two_functions(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -959,7 +959,7 @@ def test_api_v2_lambda_functions_create_function_is_not_ready(self): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index acd402a1d8da..47b80e1fb3e9 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -163,21 +163,24 @@ def to_dict(self): def invoke(self, db_task, data): try: payload = {} + data = {k: v for k,v in data.items() if v is not None} threshold = data.get("threshold") if threshold: - payload.update({ - "threshold": threshold, - }) + payload.update({ "threshold": threshold }) quality = data.get("quality") mapping = data.get("mapping", {}) - mapping_by_default = {} + task_attributes = {} + mapping_by_default = {} for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all(): - mapping_by_default[db_label.name] = db_label.name + mapping_by_default[db_label.name] = { + 'name': db_label.name, + 'attributes': {} + } task_attributes[db_label.name] = {} for attribute in db_label.attributespec_set.all(): task_attributes[db_label.name][attribute.name] = { - 'input_rype': attribute.input_type, + 'input_type': attribute.input_type, 'values': attribute.values.split('\n') } if not mapping: @@ -186,15 +189,27 @@ def invoke(self, db_task, data): mapping = mapping_by_default else: # filter labels in mapping which don't exist in the task - mapping = {k:v for k,v in mapping.items() if v in mapping_by_default} + mapping = {k:v for k,v in mapping.items() if v['name'] in mapping_by_default} + + attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping } + mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping } + supported_attrs = {} for func_label, func_attrs in self.func_attributes.items(): - if func_label in mapping: - supported_attrs[func_label] = {} - task_attr_names = [task_attr for task_attr in task_attributes[mapping[func_label]]] + if func_label not in mapping: + continue + + mapped_label = mapping[func_label] + mapped_attributes = attr_mapping.get(func_label, {}) + supported_attrs[func_label] = {} + + if mapped_attributes: + task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]] for attr in func_attrs: - if attr['name'] in task_attr_names: - supported_attrs[func_label].update({attr["name"] : attr}) + mapped_attr = mapped_attributes.get(attr["name"]) + if mapped_attr in task_attr_names: + supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) + if self.kind == LambdaType.DETECTOR: payload.update({ "image": self._get_image(db_task, data["frame"], quality) @@ -259,29 +274,43 @@ def check_attr_value(value, func_attr, db_attr): return db_attr_type == "text" or \ (db_attr_type in ["select", "radio"] and len(value.split(" ")) == 1) elif func_attr_type == "select": - return db_attr["input_type"] in ["radio", "text"] + return db_attr_type in ["radio", "text"] elif func_attr_type == "radio": - return db_attr["input_type"] in ["select", "text"] + return db_attr_type in ["select", "text"] elif func_attr_type == "checkbox": return value in ["true", "false"] else: return False if self.kind == LambdaType.DETECTOR: for item in response: - if item['label'] in mapping: - attributes = deepcopy(item.get("attributes", [])) - item["attributes"] = [] - for attr in attributes: - db_attr = supported_attrs.get(item['label'], {}).get(attr["name"]) - func_attr = [func_attr for func_attr in self.func_attributes.get(item['label'], []) if func_attr['name'] == attr["name"]] - # Skip current attribute if it was not declared as supportd in function config - if not func_attr: - continue - if attr["name"] in supported_attrs.get(item['label'], {}) and check_attr_value(attr["value"], func_attr[0], db_attr): - item["attributes"].append(attr) - item['label'] = mapping[item['label']] - response_filtered.append(item) - response = response_filtered + item_label = item['label'] + + if item_label not in mapping: + continue + + attributes = deepcopy(item.get("attributes", [])) + item["attributes"] = [] + mapped_attributes = attr_mapping[item_label] + + for attr in attributes: + if attr['name'] not in mapped_attributes: + continue + + func_attr = [func_attr for func_attr in self.func_attributes.get(item_label, []) if func_attr['name'] == attr["name"]] + # Skip current attribute if it was not declared as supported in function config + if not func_attr: + continue + + db_attr = supported_attrs.get(item_label, {}).get(attr["name"]) + + if check_attr_value(attr["value"], func_attr[0], db_attr): + attr["name"] = mapped_attributes[attr['name']] + item["attributes"].append(attr) + + item['label'] = mapping[item['label']] + response_filtered.append(item) + response = response_filtered + return response def _get_image(self, db_task, frame, quality): @@ -444,7 +473,7 @@ def reset(self): for frame in range(db_task.data.size): annotations = function.invoke(db_task, data={ "frame": frame, "quality": quality, "mapping": mapping, - "threshold": threshold}) + "threshold": threshold }) progress = (frame + 1) / db_task.data.size if not LambdaJob._update_progress(progress): break