From d5548560baf182bde49134bada08ab792e525bad Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Tue, 7 Jan 2025 16:39:21 +0100 Subject: [PATCH] feat(proofread): decouple the system prompt for the chat endpoint Signed-off-by: Julien Veyssier --- lib/TaskProcessing/ProofreadProvider.php | 5 +++-- tests/unit/Providers/OpenAiProviderTest.php | 10 ++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/lib/TaskProcessing/ProofreadProvider.php b/lib/TaskProcessing/ProofreadProvider.php index ad63f14..1fd4ede 100644 --- a/lib/TaskProcessing/ProofreadProvider.php +++ b/lib/TaskProcessing/ProofreadProvider.php @@ -101,7 +101,7 @@ public function process(?string $userId, array $input, callable $reportProgress) throw new RuntimeException('Invalid prompt'); } $textInput = $input['input']; - $prompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list. Here is the text:' . "\n\n" . $textInput; + $systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.'; $maxTokens = null; if (isset($input['max_tokens']) && is_int($input['max_tokens'])) { @@ -116,9 +116,10 @@ public function process(?string $userId, array $input, callable $reportProgress) try { if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { - $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens); + $completion = $this->openAiAPIService->createChatCompletion($userId, $model, $textInput, $systemPrompt, null, 1, $maxTokens); $completion = $completion['messages']; } else { + $prompt = $systemPrompt . ' Here is the text:' . "\n\n" . $textInput; $completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens); } } catch (Exception $e) { diff --git a/tests/unit/Providers/OpenAiProviderTest.php b/tests/unit/Providers/OpenAiProviderTest.php index 6ef3bcb..afd6d84 100644 --- a/tests/unit/Providers/OpenAiProviderTest.php +++ b/tests/unit/Providers/OpenAiProviderTest.php @@ -365,8 +365,14 @@ public function testProofreadProvider(): void { $url = self::OPENAI_API_BASE . 'chat/completions'; $options = ['timeout' => Application::OPENAI_DEFAULT_REQUEST_TIMEOUT, 'headers' => ['User-Agent' => Application::USER_AGENT, 'Authorization' => self::AUTHORIZATION_HEADER, 'Content-Type' => 'application/json']]; - $message = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list. Here is the text:' . "\n\n" . $prompt; - $options['body'] = json_encode(['model' => Application::DEFAULT_COMPLETION_MODEL_ID, 'messages' => [['role' => 'user', 'content' => $message]], 'max_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, 'n' => $n, 'user' => self::TEST_USER1]); + $systemPrompt = 'Proofread the following text. List all spelling and grammar mistakes and how to correct them. Output only the list.'; + $options['body'] = json_encode([ + 'model' => Application::DEFAULT_COMPLETION_MODEL_ID, + 'messages' => [['role' => 'system', 'content' => $systemPrompt],['role' => 'user', 'content' => $prompt]], + 'max_tokens' => Application::DEFAULT_MAX_NUM_OF_TOKENS, + 'n' => $n, + 'user' => self::TEST_USER1, + ]); $iResponse = $this->createMock(\OCP\Http\Client\IResponse::class); $iResponse->method('getBody')->willReturn($response);