Skip to content

Commit

Permalink
Merge pull request #169 from nextcloud/feat/add-proofread-provider
Browse files Browse the repository at this point in the history
Feat: add proofread provider
  • Loading branch information
julien-nc authored Jan 7, 2025
2 parents 50d91bc + d554856 commit e635b55
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/AppInfo/Application.php
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public function register(IRegistrationContext $context): void {
if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextChatWithTools')) {
$context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\TextToTextChatWithToolsProvider::class);
}
if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextProofread')) {
$context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\ProofreadProvider::class);
}
}
if ($this->appConfig->getValueString(Application::APP_ID, 't2i_provider_enabled', '1') === '1') {
$context->registerTaskProcessingProvider(TextToImageProvider::class);
Expand Down
136 changes: 136 additions & 0 deletions lib/TaskProcessing/ProofreadProvider.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
<?php

declare(strict_types=1);

namespace OCA\OpenAi\TaskProcessing;

use Exception;
use OCA\OpenAi\AppInfo\Application;
use OCA\OpenAi\Service\OpenAiAPIService;
use OCA\OpenAi\Service\OpenAiSettingsService;
use OCP\IAppConfig;
use OCP\IL10N;
use OCP\TaskProcessing\EShapeType;
use OCP\TaskProcessing\ISynchronousProvider;
use OCP\TaskProcessing\ShapeDescriptor;
use OCP\TaskProcessing\TaskTypes\TextToTextProofread;
use RuntimeException;

class ProofreadProvider implements ISynchronousProvider {

public function __construct(
private OpenAiAPIService $openAiAPIService,
private IAppConfig $appConfig,
private OpenAiSettingsService $openAiSettingsService,
private IL10N $l,
private ?string $userId,
) {
}

public function getId(): string {
return Application::APP_ID . '-text2text:proofread';
}

public function getName(): string {
return $this->openAiAPIService->getServiceName();
}

public function getTaskTypeId(): string {
return TextToTextProofread::ID;
}

public function getExpectedRuntime(): int {
return $this->openAiAPIService->getExpTextProcessingTime();
}

public function getInputShapeEnumValues(): array {
return [];
}

public function getInputShapeDefaults(): array {
return [];
}

public function getOptionalInputShape(): array {
return [
'max_tokens' => new ShapeDescriptor(
$this->l->t('Maximum output words'),
$this->l->t('The maximum number of words/tokens that can be generated in the completion.'),
EShapeType::Number
),
'model' => new ShapeDescriptor(
$this->l->t('Model'),
$this->l->t('The model used to generate the completion'),
EShapeType::Enum
),
];
}

public function getOptionalInputShapeEnumValues(): array {
return [
'model' => $this->openAiAPIService->getModelEnumValues($this->userId),
];
}

public function getOptionalInputShapeDefaults(): array {
$adminModel = $this->openAiAPIService->isUsingOpenAi()
? ($this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID)
: $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id');
return [
'max_tokens' => 1000,
'model' => $adminModel,
];
}

public function getOutputShapeEnumValues(): array {
return [];
}

public function getOptionalOutputShape(): array {
return [];
}

public function getOptionalOutputShapeEnumValues(): array {
return [];
}

public function process(?string $userId, array $input, callable $reportProgress): array {
$startTime = time();

if (!isset($input['input']) || !is_string($input['input'])) {
throw new RuntimeException('Invalid prompt');
}
$textInput = $input['input'];
$systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.';

$maxTokens = null;
if (isset($input['max_tokens']) && is_int($input['max_tokens'])) {
$maxTokens = $input['max_tokens'];
}

if (isset($input['model']) && is_string($input['model'])) {
$model = $input['model'];
} else {
$model = $this->appConfig->getValueString(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_MODEL_ID) ?: Application::DEFAULT_MODEL_ID;
}

try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $textInput, $systemPrompt, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$prompt = $systemPrompt . ' Here is the text:' . "\n\n" . $textInput;
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
} catch (Exception $e) {
throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage());
}
if (count($completion) > 0) {
$endTime = time();
$this->openAiAPIService->updateExpTextProcessingTime($endTime - $startTime);
return ['output' => array_pop($completion)];
}

throw new RuntimeException('No result in OpenAI/LocalAI response.');
}
}
64 changes: 64 additions & 0 deletions tests/unit/Providers/OpenAiProviderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use OCA\OpenAi\Service\OpenAiSettingsService;
use OCA\OpenAi\TaskProcessing\ChangeToneProvider;
use OCA\OpenAi\TaskProcessing\HeadlineProvider;
use OCA\OpenAi\TaskProcessing\ProofreadProvider;
use OCA\OpenAi\TaskProcessing\SummaryProvider;
use OCA\OpenAi\TaskProcessing\TextToTextProvider;
use OCA\OpenAi\TaskProcessing\TranslateProvider;
Expand Down Expand Up @@ -326,6 +327,69 @@ public function testSummaryProvider(): void {
$this->quotaUsageMapper->deleteUserQuotaUsages(self::TEST_USER1);
}

public function testProofreadProvider(): void {
$proofreadProvider = new ProofreadProvider(
$this->openAiApiService,
\OC::$server->get(IAppConfig::class),
$this->openAiSettingsService,
$this->createMock(\OCP\IL10N::class),
self::TEST_USER1,
);

$prompt = 'This is a test prompt';
$n = 1;

$response = '{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0613",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a test response."
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}';

$url = self::OPENAI_API_BASE . 'chat/completions';

$options = ['timeout' => Application::OPENAI_DEFAULT_REQUEST_TIMEOUT, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => self::AUTHORIZATION_HEADER, 'Content-Type' => 'application/json']];
$systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.';
$options['body'] = json_encode([
'model' => Application::DEFAULT_COMPLETION_MODEL_ID,
'messages' => [['role' => 'system', 'content' => $systemPrompt],['role' => 'user', 'content' => $prompt]],
'max_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS,
'n' => $n,
'user' => self::TEST_USER1,
]);

$iResponse = $this->createMock(\OCP\Http\Client\IResponse::class);
$iResponse->method('getBody')->willReturn($response);
$iResponse->method('getStatusCode')->willReturn(200);

$this->iClient->expects($this->once())->method('post')->with($url, $options)->willReturn($iResponse);

$result = $proofreadProvider->process(self::TEST_USER1, ['input' => $prompt], fn () => null);
$this->assertEquals('This is a test response.', $result['output']);

// Check that token usage is logged properly
$usage = $this->quotaUsageMapper->getQuotaUnitsOfUser(self::TEST_USER1, Application::QUOTA_TYPE_TEXT);
$this->assertEquals(21, $usage);
// Clear quota usage
$this->quotaUsageMapper->deleteUserQuotaUsages(self::TEST_USER1);
}

public function testTranslationProvider(): void {
$translationProvider = new TranslateProvider(
$this->openAiApiService,
Expand Down

0 comments on commit e635b55

Please sign in to comment.