diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index 6706968..b018bc9 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -105,6 +105,9 @@ public function register(IRegistrationContext $context): void { if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextChatWithTools')) { $context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\TextToTextChatWithToolsProvider::class); } + if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextProofread')) { + $context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\ProofreadProvider::class); + } } if ($this->appConfig->getValueString(Application::APP_ID, 't2i_provider_enabled', '1') === '1') { $context->registerTaskProcessingProvider(TextToImageProvider::class); diff --git a/lib/TaskProcessing/ProofreadProvider.php b/lib/TaskProcessing/ProofreadProvider.php new file mode 100644 index 0000000..1fd4ede --- /dev/null +++ b/lib/TaskProcessing/ProofreadProvider.php @@ -0,0 +1,136 @@ +openAiAPIService->getServiceName(); + } + + public function getTaskTypeId(): string { + return TextToTextProofread::ID; + } + + public function getExpectedRuntime(): int { + return $this->openAiAPIService->getExpTextProcessingTime(); + } + + public function getInputShapeEnumValues(): array { + return []; + } + + public function getInputShapeDefaults(): array { + return []; + } + + public function getOptionalInputShape(): array { + return [ + 'max_tokens' => new ShapeDescriptor( + $this->l->t('Maximum output words'), + $this->l->t('The maximum number of words/tokens that can be generated in the completion.'), + EShapeType::Number + ), + 'model' => new ShapeDescriptor( + $this->l->t('Model'), + $this->l->t('The model used to generate the completion'), + EShapeType::Enum + ), + ]; + } + + public function getOptionalInputShapeEnumValues(): array { + return [ + 'model' => $this->openAiAPIService->getModelEnumValues($this->userId), + ]; + } + + public function getOptionalInputShapeDefaults(): array { + $adminModel = $this->openAiAPIService->isUsingOpenAi() + ? ($this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID) + : $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id'); + return [ + 'max_tokens' => 1000, + 'model' => $adminModel, + ]; + } + + public function getOutputShapeEnumValues(): array { + return []; + } + + public function getOptionalOutputShape(): array { + return []; + } + + public function getOptionalOutputShapeEnumValues(): array { + return []; + } + + public function process(?string $userId, array $input, callable $reportProgress): array { + $startTime = time(); + + if (!isset($input['input']) || !is_string($input['input'])) { + throw new RuntimeException('Invalid prompt'); + } + $textInput = $input['input']; + $systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.'; + + $maxTokens = null; + if (isset($input['max_tokens']) && is_int($input['max_tokens'])) { + $maxTokens = $input['max_tokens']; + } + + if (isset($input['model']) && is_string($input['model'])) { + $model = $input['model']; + } else { + $model = $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID; + } + + try { + if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { + $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $textInput, $systemPrompt, null, 1, $maxTokens); + $completion = $completion['messages']; + } else { + $prompt = $systemPrompt . ' Here is the text:' . "\n\n" . $textInput; + $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); + } + } catch (Exception $e) { + throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage()); + } + if (count($completion) > 0) { + $endTime = time(); + $this->openAiAPIService->updateExpTextProcessingTime($endTime - $startTime); + return ['output' => array_pop($completion)]; + } + + throw new RuntimeException('No result in OpenAI/LocalAI response.'); + } +} diff --git a/tests/unit/Providers/OpenAiProviderTest.php b/tests/unit/Providers/OpenAiProviderTest.php index 857a857..afd6d84 100644 --- a/tests/unit/Providers/OpenAiProviderTest.php +++ b/tests/unit/Providers/OpenAiProviderTest.php @@ -17,6 +17,7 @@ use OCA\OpenAi\Service\OpenAiSettingsService; use OCA\OpenAi\TaskProcessing\ChangeToneProvider; use OCA\OpenAi\TaskProcessing\HeadlineProvider; +use OCA\OpenAi\TaskProcessing\ProofreadProvider; use OCA\OpenAi\TaskProcessing\SummaryProvider; use OCA\OpenAi\TaskProcessing\TextToTextProvider; use OCA\OpenAi\TaskProcessing\TranslateProvider; @@ -326,6 +327,69 @@ public function testSummaryProvider(): void { $this->quotaUsageMapper->deleteUserQuotaUsages(self::TEST_USER1); } + public function testProofreadProvider(): void { + $proofreadProvider = new ProofreadProvider( + $this->openAiApiService, + \OC::$server->get(IAppConfig::class), + $this->openAiSettingsService, + $this->createMock(\OCP\IL10N::class), + self::TEST_USER1, + ); + + $prompt = 'This is a test prompt'; + $n = 1; + + $response = '{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "This is a test response." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + }'; + + $url = self::OPENAI_API_BASE . 'chat/completions'; + + $options = ['timeout' => Application::OPENAI_DEFAULT_REQUEST_TIMEOUT, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => self::AUTHORIZATION_HEADER, 'Content-Type' => 'application/json']]; + $systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.'; + $options['body'] = json_encode([ + 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, + 'messages' => [['role' => 'system', 'content' => $systemPrompt],['role' => 'user', 'content' => $prompt]], + 'max_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, + 'n' => $n, + 'user' => self::TEST_USER1, + ]); + + $iResponse = $this->createMock(\OCP\Http\Client\IResponse::class); + $iResponse->method('getBody')->willReturn($response); + $iResponse->method('getStatusCode')->willReturn(200); + + $this->iClient->expects($this->once())->method('post')->with($url, $options)->willReturn($iResponse); + + $result = $proofreadProvider->process(self::TEST_USER1, ['input' => $prompt], fn () => null); + $this->assertEquals('This is a test response.', $result['output']); + + // Check that token usage is logged properly + $usage = $this->quotaUsageMapper->getQuotaUnitsOfUser(self::TEST_USER1, Application::QUOTA_TYPE_TEXT); + $this->assertEquals(21, $usage); + // Clear quota usage + $this->quotaUsageMapper->deleteUserQuotaUsages(self::TEST_USER1); + } + public function testTranslationProvider(): void { $translationProvider = new TranslateProvider( $this->openAiApiService,