From 340925e26b1fa4da9d00003790634f63a89ed558 Mon Sep 17 00:00:00 2001 From: Julien Veyssier Date: Thu, 20 Jun 2024 12:23:08 +0200 Subject: [PATCH] add ContextWrite task processing provider Signed-off-by: Julien Veyssier --- lib/AppInfo/Application.php | 2 + lib/TaskProcessing/ContextWriteProvider.php | 120 ++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 lib/TaskProcessing/ContextWriteProvider.php diff --git a/lib/AppInfo/Application.php b/lib/AppInfo/Application.php index eb6d7c7d..9a785cc5 100644 --- a/lib/AppInfo/Application.php +++ b/lib/AppInfo/Application.php @@ -10,6 +10,7 @@ namespace OCA\OpenAi\AppInfo; use OCA\OpenAi\Capabilities; +use OCA\OpenAi\TaskProcessing\ContextWriteProvider; use OCA\OpenAi\TaskProcessing\HeadlineProvider; use OCA\OpenAi\TaskProcessing\STTProvider; use OCA\OpenAi\TaskProcessing\SummaryProvider; @@ -81,6 +82,7 @@ public function register(IRegistrationContext $context): void { $context->registerTaskProcessingProvider(SummaryProvider::class); $context->registerTaskProcessingProvider(HeadlineProvider::class); $context->registerTaskProcessingProvider(TopicsProvider::class); + $context->registerTaskProcessingProvider(ContextWriteProvider::class); } if ($this->config->getAppValue(Application::APP_ID, 't2i_provider_enabled', '1') === '1') { $context->registerTaskProcessingProvider(TextToImageProvider::class); diff --git a/lib/TaskProcessing/ContextWriteProvider.php b/lib/TaskProcessing/ContextWriteProvider.php new file mode 100644 index 00000000..5528f4c5 --- /dev/null +++ b/lib/TaskProcessing/ContextWriteProvider.php @@ -0,0 +1,120 @@ +openAiAPIService->getServiceName(); + } + + public function getTaskTypeId(): string { + return ContextWrite::ID; + } + + public function getExpectedRuntime(): int { + return $this->openAiAPIService->getExpTextProcessingTime(); + } + + public function getOptionalInputShape(): array { + return [ + 'temperature' => new ShapeDescriptor( + $this->l->t('Temperature'), + $this->l->t('What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.'), + EShapeType::Number + ), + 'max_tokens' => new ShapeDescriptor( + $this->l->t('Maximum tokens'), + $this->l->t('The maximum number of tokens that can be generated in the completion.'), + EShapeType::Number + ), + ]; + } + + public function getOptionalOutputShape(): array { + return []; + } + + public function process(?string $userId, array $input, callable $reportProgress): array { + $startTime = time(); + $adminModel = $this->config->getAppValue(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_COMPLETION_MODEL_ID) ?: Application::DEFAULT_COMPLETION_MODEL_ID; + + if ( + !isset($input['style_input']) || !is_string($input['style_input']) + || !isset($input['source_input']) || !is_string($input['source_input']) + ) { + throw new RuntimeException('Invalid inputs'); + } + $writingStyle = $input['style_input']; + $sourceMaterial = $input['source_input']; + + $prompt = 'You\'re a professional copywriter tasked with copying an instructed or demonstrated *WRITING STYLE*' + . ' and writing a text on the provided *SOURCE MATERIAL*.' + . " \n*WRITING STYLE*:\n$writingStyle\n\n*SOURCE MATERIAL*:\n\n$sourceMaterial\n\n" + . 'Now write a text in the same style detailed or demonstrated under *WRITING STYLE* using the *SOURCE MATERIAL*' + . ' as source of facts and instruction on what to write about.' + . ' Do not invent any facts or events yourself.' + . ' Also, use the *WRITING STYLE* as a guide for how to write the text ONLY and not as a source of facts or events.' + . ' Detect the language used in the *SOURCE_MATERIAL*. Make sure to use the same language in your response. Do not mention the language explicitly.'; + + $temperature = null; + if (isset($input['temperature']) + && (is_float($input['temperature']) || is_int($input['temperature']))) { + $temperature = $input['temperature']; + } + + $maxTokens = null; + if (isset($input['max_tokens']) && is_int($input['max_tokens'])) { + $maxTokens = $input['max_tokens']; + } + + $extraParams = $temperature === null + ? null + : ['temperature' => $temperature]; + + try { + if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { + $completion = $this->openAiAPIService->createChatCompletion($this->userId, $prompt, 1, $adminModel, $maxTokens, $extraParams); + } else { + $completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, $maxTokens, $extraParams); + } + } catch (Exception $e) { + throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage()); + } + if (count($completion) > 0) { + $endTime = time(); + $this->openAiAPIService->updateExpTextProcessingTime($endTime - $startTime); + return ['output' => array_pop($completion)]; + } + + throw new RuntimeException('No result in OpenAI/LocalAI response.'); + } +}