Skip to content

Commit

Permalink
Merge pull request #42 from nextcloud/enh/move-stt-n-igen-to-assistan…
Browse files Browse the repository at this point in the history
…t-ui

Move STT and image generation to the assistant UI
  • Loading branch information
julien-nc authored Feb 12, 2024
2 parents 43ee520 + 9ae92f7 commit 734a051
Show file tree
Hide file tree
Showing 33 changed files with 1,304 additions and 843 deletions.
6 changes: 3 additions & 3 deletions appinfo/routes.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
['name' => 'config#setConfig', 'url' => '/config', 'verb' => 'PUT'],
['name' => 'config#setAdminConfig', 'url' => '/admin-config', 'verb' => 'PUT'],

['name' => 'assistant#getTextProcessingTaskResultPage', 'url' => '/task/view/{taskId}', 'verb' => 'GET'],
['name' => 'assistant#getAssistantTaskResultPage', 'url' => '/task/view/{metaTaskId}', 'verb' => 'GET'],
['name' => 'assistant#getAssistantTask', 'url' => '/task/{metaTaskId}', 'verb' => 'GET'],
['name' => 'assistant#runTextProcessingTask', 'url' => '/task/run', 'verb' => 'POST'],
['name' => 'assistant#scheduleTextProcessingTask', 'url' => '/task/schedule', 'verb' => 'POST'],
['name' => 'assistant#runOrScheduleTextProcessingTask', 'url' => '/task/run-or-schedule', 'verb' => 'POST'],
['name' => 'assistant#getTextProcessingResult', 'url' => '/task/{taskId}', 'verb' => 'GET'],
['name' => 'assistant#parseTextFromFile', 'url' => '/parse-file', 'verb' => 'POST'],

['name' => 'Text2Image#processPrompt', 'url' => '/i/process_prompt', 'verb' => 'POST'],
Expand All @@ -27,7 +27,7 @@
['name' => 'FreePrompt#getOutputs', 'url' => '/f/get_outputs', 'verb' => 'GET'],
['name' => 'FreePrompt#cancelGeneration', 'url' => '/f/cancel_generation', 'verb' => 'POST'],

['name' => 'SpeechToText#getResultPage', 'url' => '/stt/resultPage', 'verb' => 'GET'],
['name' => 'SpeechToText#getResultPage', 'url' => '/stt/result-page/{metaTaskId}', 'verb' => 'GET'],
['name' => 'SpeechToText#transcribeAudio', 'url' => '/stt/transcribeAudio', 'verb' => 'POST'],
['name' => 'SpeechToText#transcribeFile', 'url' => '/stt/transcribeFile', 'verb' => 'POST'],
],
Expand Down
14 changes: 6 additions & 8 deletions lib/Controller/AssistantController.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ public function __construct(
}

/**
* @param int $taskId
* @param int $metaTaskId
* @return TemplateResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function getTextProcessingTaskResultPage(int $taskId): TemplateResponse {

public function getAssistantTaskResultPage(int $metaTaskId): TemplateResponse {
if ($this->userId !== null) {
$task = $this->assistantService->getTextProcessingTask($this->userId, $taskId);
$task = $this->assistantService->getAssistantTask($this->userId, $metaTaskId);
if ($task !== null) {
$this->initialStateService->provideInitialState('task', $task->jsonSerializeCc());
return new TemplateResponse(Application::APP_ID, 'taskResultPage');
Expand All @@ -44,14 +43,13 @@ public function getTextProcessingTaskResultPage(int $taskId): TemplateResponse {
}

/**
* @param int $taskId
* @param int $metaTaskId
* @return DataResponse
*/
#[NoAdminRequired]
public function getTextProcessingResult(int $taskId): DataResponse {

public function getAssistantTask(int $metaTaskId): DataResponse {
if ($this->userId !== null) {
$task = $this->assistantService->getTextProcessingTask($this->userId, $taskId);
$task = $this->assistantService->getAssistantTask($this->userId, $metaTaskId);
if ($task !== null) {
return new DataResponse([
'task' => $task->jsonSerializeCc(),
Expand Down
4 changes: 1 addition & 3 deletions lib/Controller/FreePromptController.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
use OCP\AppFramework\Http\Attribute\NoAdminRequired;
use OCP\AppFramework\Http\Attribute\NoCSRFRequired;
use OCP\AppFramework\Http\DataResponse;
use OCP\AppFramework\Services\IInitialState;

use OCP\IL10N;
use OCP\IRequest;
Expand All @@ -23,7 +22,6 @@ public function __construct(
IRequest $request,
private FreePromptService $freePromptService,
private ?string $userId,
private IInitialState $initialStateService,
private IL10N $l10n,
) {
parent::__construct($appName, $request);
Expand All @@ -46,7 +44,7 @@ public function processPrompt(string $prompt): DataResponse {
} catch (Exception $e) {
return new DataResponse(['error' => $e->getMessage()], (int)$e->getCode());
}

return new DataResponse($result);
}

Expand Down
8 changes: 5 additions & 3 deletions lib/Controller/SpeechToTextController.php
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ public function __construct(
}

/**
* @param int $id
* @param int $metaTaskId
* @return TemplateResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function getResultPage(int $id): TemplateResponse {
public function getResultPage(int $metaTaskId): TemplateResponse {
$response = new TemplateResponse(Application::APP_ID, 'speechToTextResultPage');
try {
$initData = [
'task' => $this->internalGetTask($id),
'task' => $this->internalGetTask($metaTaskId),
];
} catch (Exception $e) {
$initData = [
Expand Down Expand Up @@ -102,6 +102,7 @@ public function getTranscript(int $id): DataResponse {
*
* @param integer $id
* @return MetaTask
* @throws Exception
*/
private function internalGetTask(int $id): MetaTask {
try {
Expand All @@ -128,6 +129,7 @@ private function internalGetTask(int $id): MetaTask {

/**
* @return DataResponse
* @throws NotPermittedException
*/
#[NoAdminRequired]
public function transcribeAudio(): DataResponse {
Expand Down
29 changes: 20 additions & 9 deletions lib/Controller/Text2ImageController.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use OCP\AppFramework\Services\IInitialState;
use OCP\Db\Exception as DbException;

use OCP\Files\NotPermittedException;
use OCP\IL10N;
use OCP\IRequest;
use OCP\TextToImage\Exception\TaskFailureException;
Expand All @@ -38,17 +39,26 @@ public function __construct(
}

/**
* @param string $appId
* @param string $identifier
* @param string $prompt
* @param int $nResults
* @param bool $displayPrompt
* @param bool $notifyReadyIfScheduled
* @param bool $schedule
* @return DataResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
public function processPrompt(string $prompt, int $nResults = 1, bool $displayPrompt = false): DataResponse {
public function processPrompt(
string $appId, string $identifier, string $prompt, int $nResults = 1, bool $displayPrompt = false,
bool $notifyReadyIfScheduled = false, bool $schedule = false
): DataResponse {
$nResults = min(10, max(1, $nResults));
try {
$result = $this->text2ImageHelperService->processPrompt($prompt, $nResults, $displayPrompt, $this->userId);
$result = $this->text2ImageHelperService->processPrompt(
$appId, $identifier, $prompt, $nResults, $displayPrompt, $this->userId, $notifyReadyIfScheduled, $schedule
);
} catch (Exception | TaskFailureException $e) {
return new DataResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}
Expand Down Expand Up @@ -92,7 +102,7 @@ public function getImage(string $imageGenId, int $fileNameId): DataDisplayRespon
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if ($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'fileId' => $fileNameId, 'status' => $e->getCode()]);
}
return $response;
}
Expand Down Expand Up @@ -125,7 +135,7 @@ public function getGenerationInfo(string $imageGenId): DataResponse {
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if ($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'status' => $e->getCode()]);
}
return $response;
}
Expand All @@ -136,12 +146,12 @@ public function getGenerationInfo(string $imageGenId): DataResponse {
/**
* @param string $imageGenId
* @param array $fileVisStatusArray
* @return DataResponse
*/
#[NoAdminRequired]
#[NoCSRFRequired]
#[BruteForceProtection(action: 'imageGenId')]
public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStatusArray): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to set visibility of image files; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -156,7 +166,7 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
$response = new DataResponse(['error' => $e->getMessage()], (int) $e->getCode());
if($e->getCode() === Http::STATUS_BAD_REQUEST || $e->getCode() === Http::STATUS_UNAUTHORIZED) {
// Throttle brute force attempts
$response->throttle(['action' => 'imageGenId']);
$response->throttle(['imageGenId' => $imageGenId, 'status' => $e->getCode()]);
}
return $response;
}
Expand All @@ -175,7 +185,6 @@ public function setVisibilityOfImageFiles(string $imageGenId, array $fileVisStat
#[NoCSRFRequired]
#[AnonRateLimit(limit: 10, period: 60)]
public function notifyWhenReady(string $imageGenId): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to notify when ready; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -187,6 +196,7 @@ public function notifyWhenReady(string $imageGenId): DataResponse {
}
return new DataResponse('success', Http::STATUS_OK);
}

/**
* Cancel image generation
*
Expand All @@ -196,12 +206,12 @@ public function notifyWhenReady(string $imageGenId): DataResponse {
*
* @param string $imageGenId
* @return DataResponse
* @throws NotPermittedException
*/
#[NoAdminRequired]
#[NoCSRFRequired]
#[AnonRateLimit(limit: 10, period: 60)]
public function cancelGeneration(string $imageGenId): DataResponse {

if ($this->userId === null) {
return new DataResponse(['error' => $this->l10n->t('Failed to cancel generation; unknown user')], Http::STATUS_INTERNAL_SERVER_ERROR);
}
Expand All @@ -216,6 +226,7 @@ public function cancelGeneration(string $imageGenId): DataResponse {
* Does not need bruteforce protection
*
* @param string|null $imageGenId
* @param bool|null $forceEditMode
* @return TemplateResponse
*/
#[NoAdminRequired]
Expand All @@ -226,7 +237,7 @@ public function showGenerationPage(?string $imageGenId, ?bool $forceEditMode = f
$forceEditMode = false;
}
$this->initialStateService->provideInitialState('generation-page-inputs', ['image_gen_id' => $imageGenId, 'force_edit_mode' => $forceEditMode]);

return new TemplateResponse(Application::APP_ID, 'imageGenerationPage');
}
}
16 changes: 5 additions & 11 deletions lib/Db/Text2Image/ImageGeneration.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
* @method \void setImageGenId(string $imageGenId)
* @method \string getPrompt()
* @method \void setPrompt(string $prompt)
* @method \void setUserId(string $userId)
* @method \string getUserId()
* @method \void setTimestamp(int $timestamp)
* @method \void setUserId(string $userId)
* @method \int getTimestamp()
* @method \void setExpGenTime(int $expGenTime)
* @method \void setTimestamp(int $timestamp)
* @method \boolean getNotifyReady()
* @method \void setNotifyReady(bool $notifyReady)
* @method \int getExpGenTime()
* @method \void setExpGenTime(int $expGenTime)
*
*/
class ImageGeneration extends Entity implements \JsonSerializable {
Expand Down Expand Up @@ -80,12 +82,4 @@ public function setFailed(?bool $failed): void {
public function getFailed(): bool {
return $this->failed === true;
}

public function setNotifyReady(?bool $notifyReady): void {
$this->notifyReady = $notifyReady === true;
}

public function getNotifyReady(): bool {
return $this->notifyReady === true;
}
}
8 changes: 6 additions & 2 deletions lib/Db/Text2Image/ImageGenerationMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,22 @@ public function getImageGenerationOfImageGenId(string $imageGenId): ImageGenerat
* @param string $prompt
* @param string $userId
* @param int|null $expCompletionTime
* @param bool $notifyReady
* @return ImageGeneration
* @throws Exception
*/
public function createImageGeneration(string $imageGenId, string $prompt = '', string $userId = '', ?int $expCompletionTime = null): ImageGeneration {
public function createImageGeneration(
string $imageGenId, string $prompt = '', string $userId = '', ?int $expCompletionTime = null,
bool $notifyReady = false
): ImageGeneration {
$imageGeneration = new ImageGeneration();
$imageGeneration->setImageGenId($imageGenId);
$imageGeneration->setTimestamp((new DateTime())->getTimestamp());
$imageGeneration->setPrompt($prompt);
$imageGeneration->setUserId($userId);
$imageGeneration->setIsGenerated(false);
$imageGeneration->setFailed(false);
$imageGeneration->setNotifyReady(false);
$imageGeneration->setNotifyReady($notifyReady);
$imageGeneration->setExpGenTime($expCompletionTime ?? (new DateTime())->getTimestamp());
return $this->insert($imageGeneration);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Notification/Notifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public function prepare(INotification $notification, string $languageCode): INot
}


$link = $params['target'] ?? $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getTextProcessingTaskResultPage', ['taskId' => $params['id']]);
$link = $params['target'] ?? $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getAssistantTaskResultPage', ['metaTaskId' => $params['id']]);
$iconUrl = $this->url->getAbsoluteURL($this->url->imagePath('core', 'actions/error.svg'));

$notification
Expand Down
12 changes: 6 additions & 6 deletions lib/Reference/Text2ImageReferenceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

use Exception;
use OCA\TpAssistant\AppInfo\Application;
use OCA\TpAssistant\Db\Text2Image\ImageGeneration;
use OCA\TpAssistant\Db\Text2Image\ImageGenerationMapper;
use OCP\Collaboration\Reference\ADiscoverableReferenceProvider;
use OCP\Collaboration\Reference\IReference;
Expand Down Expand Up @@ -74,7 +73,6 @@ public function resolveReference(string $referenceText): ?IReference {
}

try {
/** @var ImageGeneration $imageGeneration */
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($imageGenId);
} catch (Exception $e) {
$imageGeneration = null;
Expand All @@ -89,14 +87,16 @@ public function resolveReference(string $referenceText): ?IReference {
$reference = new Reference($referenceText);
$imageUrl = $this->urlGenerator->linkToRouteAbsolute(
Application::APP_ID . '.Text2Image.getGenerationInfo',
[
'imageGenId' => $imageGenId,
]
['imageGenId' => $imageGenId]
);

$reference->setImageUrl($imageUrl);

$richObjectInfo = ['prompt' => $prompt, 'proxied_url' => $imageUrl];
$richObjectInfo = [
'prompt' => $prompt,
'proxied_url' => $imageUrl,
'imageGenId' => $imageGenId,
];
$reference->setRichObject(
self::RICH_OBJECT_TYPE,
$richObjectInfo,
Expand Down
28 changes: 9 additions & 19 deletions lib/Service/AssistantService.php
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,7 @@ public function sendNotification(MetaTask $task, ?string $customTarget = null, ?
}

private function getDefaultTarget(MetaTask $task): string {
$category = $task->getCategory();
if ($category === Application::TASK_CATEGORY_TEXT_GEN) {
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getTextProcessingTaskResultPage', ['taskId' => $task->getId()]);
} elseif ($category === Application::TASK_CATEGORY_SPEECH_TO_TEXT) {
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.SpeechToText.getResultPage', ['id' => $task->getId()]);
} elseif ($category === Application::TASK_CATEGORY_TEXT_TO_IMAGE) {
$imageGeneration = $this->imageGenerationMapper->getImageGenerationOfImageGenId($task->getIdentifier());
return $this->url->linkToRouteAbsolute(
Application::APP_ID . '.Text2Image.showGenerationPage',
[
'imageGenId' => $imageGeneration->getImageGenId(),
]
);
}
return '';
return $this->url->linkToRouteAbsolute(Application::APP_ID . '.assistant.getAssistantTaskResultPage', ['metaTaskId' => $task->getId()]);
}

/**
Expand Down Expand Up @@ -168,19 +154,23 @@ private function sanitizeInputs(string $type, array $inputs): array {

/**
* @param string $userId
* @param int $taskId
* @param int $metaTaskId
* @return MetaTask|null
*/
public function getTextProcessingTask(string $userId, int $taskId): ?MetaTask {
public function getAssistantTask(string $userId, int $metaTaskId): ?MetaTask {
try {
$metaTask = $this->metaTaskMapper->getMetaTask($taskId);
$metaTask = $this->metaTaskMapper->getMetaTask($metaTaskId);
} catch (DoesNotExistException | MultipleObjectsReturnedException | \OCP\Db\Exception $e) {
return null;
}
if ($metaTask->getUserId() !== $userId) {
return null;
}
// Check if the task status is up-to-date (if not, update status and output)
// only try to update meta task status for text processing ones
if ($metaTask->getCategory() !== Application::TASK_CATEGORY_TEXT_GEN) {
return $metaTask;
}
// Check if the text processing task status is up-to-date (if not, update status and output)
try {
$ocpTask = $this->textProcessingManager->getTask($metaTask->getOcpTaskId());

Expand Down
Loading

0 comments on commit 734a051

Please sign in to comment.