-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ContextWrite task processing provider
Signed-off-by: Julien Veyssier <[email protected]>
- Loading branch information
Showing
2 changed files
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace OCA\OpenAi\TaskProcessing; | ||
|
||
use Exception; | ||
use OCA\OpenAi\AppInfo\Application; | ||
use OCA\OpenAi\Service\OpenAiAPIService; | ||
use OCA\OpenAi\Service\OpenAiSettingsService; | ||
use OCP\IConfig; | ||
use OCP\IL10N; | ||
use OCP\TaskProcessing\EShapeType; | ||
use OCP\TaskProcessing\ISynchronousProvider; | ||
use OCP\TaskProcessing\ShapeDescriptor; | ||
use OCP\TaskProcessing\TaskTypes\ContextWrite; | ||
use OCP\TaskProcessing\TaskTypes\TextToText; | ||
use RuntimeException; | ||
|
||
class ContextWriteProvider implements ISynchronousProvider { | ||
|
||
public function __construct( | ||
private OpenAiAPIService $openAiAPIService, | ||
private IConfig $config, | ||
private OpenAiSettingsService $openAiSettingsService, | ||
private IL10N $l, | ||
private ?string $userId, | ||
) { | ||
} | ||
|
||
public function getId(): string { | ||
return Application::APP_ID . '-contextwrite'; | ||
} | ||
|
||
public function getName(): string { | ||
return $this->openAiAPIService->getServiceName(); | ||
} | ||
|
||
public function getTaskTypeId(): string { | ||
return ContextWrite::ID; | ||
} | ||
|
||
public function getExpectedRuntime(): int { | ||
return $this->openAiAPIService->getExpTextProcessingTime(); | ||
} | ||
|
||
public function getOptionalInputShape(): array { | ||
return [ | ||
'temperature' => new ShapeDescriptor( | ||
$this->l->t('Temperature'), | ||
$this->l->t('What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.'), | ||
EShapeType::Number | ||
), | ||
'max_tokens' => new ShapeDescriptor( | ||
$this->l->t('Maximum tokens'), | ||
$this->l->t('The maximum number of tokens that can be generated in the completion.'), | ||
EShapeType::Number | ||
), | ||
]; | ||
} | ||
|
||
public function getOptionalOutputShape(): array { | ||
return []; | ||
} | ||
|
||
public function process(?string $userId, array $input, callable $reportProgress): array { | ||
$startTime = time(); | ||
$adminModel = $this->config->getAppValue(Application::APP_ID, 'default_completion_model_id', Application::DEFAULT_COMPLETION_MODEL_ID) ?: Application::DEFAULT_COMPLETION_MODEL_ID; | ||
|
||
if ( | ||
!isset($input['style_input']) || !is_string($input['style_input']) | ||
|| !isset($input['source_input']) || !is_string($input['source_input']) | ||
) { | ||
throw new RuntimeException('Invalid inputs'); | ||
} | ||
$writingStyle = $input['style_input']; | ||
$sourceMaterial = $input['source_input']; | ||
|
||
$prompt = 'You\'re a professional copywriter tasked with copying an instructed or demonstrated *WRITING STYLE*' | ||
. ' and writing a text on the provided *SOURCE MATERIAL*.' | ||
. " \n*WRITING STYLE*:\n$writingStyle\n\n*SOURCE MATERIAL*:\n\n$sourceMaterial\n\n" | ||
. 'Now write a text in the same style detailed or demonstrated under *WRITING STYLE* using the *SOURCE MATERIAL*' | ||
. ' as source of facts and instruction on what to write about.' | ||
. ' Do not invent any facts or events yourself.' | ||
. ' Also, use the *WRITING STYLE* as a guide for how to write the text ONLY and not as a source of facts or events.' | ||
. ' Detect the language used in the *SOURCE_MATERIAL*. Make sure to use the same language in your response. Do not mention the language explicitly.'; | ||
|
||
$temperature = null; | ||
if (isset($input['temperature']) | ||
&& (is_float($input['temperature']) || is_int($input['temperature']))) { | ||
$temperature = $input['temperature']; | ||
} | ||
|
||
$maxTokens = null; | ||
if (isset($input['max_tokens']) && is_int($input['max_tokens'])) { | ||
$maxTokens = $input['max_tokens']; | ||
} | ||
|
||
$extraParams = $temperature === null | ||
? null | ||
: ['temperature' => $temperature]; | ||
|
||
try { | ||
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) { | ||
$completion = $this->openAiAPIService->createChatCompletion($this->userId, $prompt, 1, $adminModel, $maxTokens, $extraParams); | ||
} else { | ||
$completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, $maxTokens, $extraParams); | ||
} | ||
} catch (Exception $e) { | ||
throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage()); | ||
} | ||
if (count($completion) > 0) { | ||
$endTime = time(); | ||
$this->openAiAPIService->updateExpTextProcessingTime($endTime - $startTime); | ||
return ['output' => array_pop($completion)]; | ||
} | ||
|
||
throw new RuntimeException('No result in OpenAI/LocalAI response.'); | ||
} | ||
} |