Skip to content

Commit

Permalink
Merge pull request #167 from nextcloud/enh/noid/chat-with-tools
Browse files Browse the repository at this point in the history
Add ChatWithTools provider
  • Loading branch information
marcelklehr authored Dec 19, 2024
2 parents 12da32f + bdec92c commit 9150fc9
Show file tree
Hide file tree
Showing 22 changed files with 249 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/psalm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
fail-fast: false
matrix:
ocp-version: [ 'dev-master' ]
php-version: [ '8.0', '8.1', '8.2', '8.3' ]
php-version: [ '8.1', '8.2', '8.3' ]


name: Psalm check on PHP ${{ matrix.php-version }} and OCP ${{ matrix.ocp-version }}
Expand Down
1 change: 1 addition & 0 deletions appinfo/routes.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* Nextcloud - OpenAI
*
Expand Down
2 changes: 1 addition & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"composer/package-versions-deprecated": true
},
"platform": {
"php": "8.0"
"php": "8.1"
}
}
}
20 changes: 12 additions & 8 deletions composer.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions lib/AppInfo/Application.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* Nextcloud - OpenAI
*
Expand Down Expand Up @@ -99,6 +100,9 @@ public function register(IRegistrationContext $context): void {
$context->registerTaskProcessingProvider(ReformulateProvider::class);
$context->registerTaskProcessingTaskType(ChangeToneTaskType::class);
$context->registerTaskProcessingProvider(ChangeToneProvider::class);
if (class_exists('OCP\\TaskProcessing\\TaskTypes\\TextToTextChatWithTools')) {
$context->registerTaskProcessingProvider(\OCA\OpenAi\TaskProcessing\TextToTextChatWithToolsProvider::class);
}
}
if ($this->appConfig->getValueString(Application::APP_ID, 't2i_provider_enabled', '1') === '1') {
$context->registerTaskProcessingProvider(TextToImageProvider::class);
Expand Down
1 change: 1 addition & 0 deletions lib/Controller/ConfigController.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* Nextcloud - OpenAI
*
Expand Down
1 change: 1 addition & 0 deletions lib/Controller/OpenAiAPIController.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* Nextcloud - OpenAI
*
Expand Down
2 changes: 2 additions & 0 deletions lib/OldProcessing/Translation/TranslationProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public function detectLanguage(string $text): ?string {
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($this->userId, $adminModel, $prompt, null, null, 1, 100);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, 100);
}
Expand Down Expand Up @@ -137,6 +138,7 @@ public function translate(?string $fromLanguage, string $toLanguage, string $tex

if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($this->userId, $adminModel, $prompt, null, null, 1, PHP_INT_MAX);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($this->userId, $prompt, 1, $adminModel, 4000);
}
Expand Down
78 changes: 61 additions & 17 deletions lib/Service/OpenAiAPIService.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* Nextcloud - OpenAI
*
Expand Down Expand Up @@ -350,24 +351,28 @@ public function createCompletion(
*
* @param string|null $userId
* @param string $model
* @param string $userPrompt
* @param string|null $userPrompt
* @param string|null $systemPrompt
* @param array|null $history
* @param int $n
* @param int|null $maxTokens
* @param array|null $extraParams
* @return string[]
* @param string|null $toolMessage JSON string with role, content, tool_call_id
* @param array|null $tools
* @return array<string, array<string>>
* @throws Exception
*/
public function createChatCompletion(
?string $userId,
string $model,
string $userPrompt,
?string $userPrompt = null,
?string $systemPrompt = null,
?array $history = null,
int $n = 1,
?int $maxTokens = null,
?array $extraParams = null,
?string $toolMessage = null,
?array $tools = null,
): array {
if ($this->isQuotaExceeded($userId, Application::QUOTA_TYPE_TEXT)) {
throw new Exception($this->l10n->t('Text generation quota exceeded'), Http::STATUS_TOO_MANY_REQUESTS);
Expand All @@ -384,28 +389,47 @@ public function createChatCompletion(
}
if ($history !== null) {
foreach ($history as $i => $historyEntry) {
if (str_starts_with($historyEntry, 'system:')) {
$historyEntry = preg_replace('/^system:/', '', $historyEntry);
$messages[] = ['role' => 'system', 'content' => $historyEntry];
} elseif (str_starts_with($historyEntry, 'user:')) {
$historyEntry = preg_replace('/^user:/', '', $historyEntry);
$messages[] = ['role' => 'user', 'content' => $historyEntry];
} elseif (((int)$i) % 2 === 0) {
// we assume even indexes are user messages and odd ones are system ones
$messages[] = ['role' => 'user', 'content' => $historyEntry];
} else {
$messages[] = ['role' => 'system', 'content' => $historyEntry];
$message = json_decode($historyEntry, true);
if ($message['role'] === 'human') {
$message['role'] = 'user';
}
if (isset($message['tool_calls']) && is_array($message['tool_calls'])) {
$message['tool_calls'] = array_map(static function ($toolCall) {
$formattedToolCall = [
'id' => $toolCall['id'],
'type' => 'function',
'function' => $toolCall,
];
$formattedToolCall['function']['arguments'] = json_encode($toolCall['args']);
unset($formattedToolCall['function']['id']);
unset($formattedToolCall['function']['args']);
unset($formattedToolCall['function']['type']);
return $formattedToolCall;
}, $message['tool_calls']);
}
$messages[] = $message;
}
}
if ($userPrompt !== null) {
$messages[] = ['role' => 'user', 'content' => $userPrompt];
}
if ($toolMessage !== null) {
$msgs = json_decode($toolMessage, true);
foreach ($msgs as $msg) {
$msg['role'] = 'tool';
$messages[] = $msg;
}
}
$messages[] = ['role' => 'user', 'content' => $userPrompt];

$params = [
'model' => $model === Application::DEFAULT_MODEL_ID ? Application::DEFAULT_COMPLETION_MODEL_ID : $model,
'messages' => $messages,
'max_tokens' => $maxTokens,
'n' => $n,
];
if ($tools !== null) {
$params['tools'] = $tools;
}
if ($userId !== null && $this->isUsingOpenAi()) {
$params['user'] = $userId;
}
Expand Down Expand Up @@ -434,10 +458,30 @@ public function createChatCompletion(
$this->logger->warning('Could not create quota usage for user: ' . $userId . ' and quota type: ' . Application::QUOTA_TYPE_TEXT . '. Error: ' . $e->getMessage(), ['app' => Application::APP_ID]);
}
}
$completions = [];
$completions = [
'messages' => [],
'tool_calls' => [],
];

foreach ($response['choices'] as $choice) {
$completions[] = $choice['message']['content'];
// get tool calls only if this is the finish reason and it's defined and it's an array
if ($choice['finish_reason'] === 'tool_calls'
&& isset($choice['message']['tool_calls'])
&& is_array($choice['message']['tool_calls'])
) {
// fix the tool_calls format, make it like expected by the context_agent app
$choice['message']['tool_calls'] = array_map(static function ($toolCall) {
$toolCall['function']['id'] = $toolCall['id'];
$toolCall['function']['args'] = json_decode($toolCall['function']['arguments']);
unset($toolCall['function']['arguments']);
return $toolCall['function'];
}, $choice['message']['tool_calls']);
$completions['tool_calls'][] = json_encode($choice['message']['tool_calls']);
}
// always try to get a message
if (isset($choice['message']['content']) && is_string($choice['message']['content'])) {
$completions['messages'][] = $choice['message']['content'];
}
}

return $completions;
Expand Down
1 change: 1 addition & 0 deletions lib/Service/OpenAiSettingsService.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?php

/**
* @copyright Copyright (c) 2023, Sami Finnilä ([email protected])
*
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/ChangeToneProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ public function process(?string $userId, array $input, callable $reportProgress)
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/ContextWriteProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ public function process(?string $userId, array $input, callable $reportProgress)
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/HeadlineProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress)
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/ReformulateProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress)
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/SummaryProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ public function process(?string $userId, array $input, callable $reportProgress)
try {
if ($this->openAiAPIService->isUsingOpenAi() || $this->openAiSettingsService->getChatEndpointEnabled()) {
$completion = $this->openAiAPIService->createChatCompletion($userId, $model, $prompt, null, null, 1, $maxTokens);
$completion = $completion['messages'];
} else {
$completion = $this->openAiAPIService->createCompletion($userId, $prompt, 1, $model, $maxTokens);
}
Expand Down
1 change: 1 addition & 0 deletions lib/TaskProcessing/TextToTextChatProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public function process(?string $userId, array $input, callable $reportProgress)

try {
$completion = $this->openAiAPIService->createChatCompletion($userId, $adminModel, $userPrompt, $systemPrompt, $history, 1, $maxTokens);
$completion = $completion['messages'];
} catch (Exception $e) {
throw new RuntimeException('OpenAI/LocalAI request failed: ' . $e->getMessage());
}
Expand Down
Loading

0 comments on commit 9150fc9

Please sign in to comment.