Skip to content

Commit

Permalink
Merge pull request #156 from nextcloud/fix/noid/chat-polling-switch-s…
Browse files Browse the repository at this point in the history
…ession

Fix polling new chat message when switching sessions
  • Loading branch information
julien-nc authored Nov 25, 2024
2 parents 38c1277 + 23eaccd commit b88fc93
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 36 deletions.
1 change: 1 addition & 0 deletions appinfo/routes.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
['name' => 'chattyLLM#getMessages', 'url' => '/chat/messages', 'verb' => 'GET'],
['name' => 'chattyLLM#generateForSession', 'url' => '/chat/generate', 'verb' => 'GET'],
['name' => 'chattyLLM#regenerateForSession', 'url' => '/chat/regenerate', 'verb' => 'GET'],
['name' => 'chattyLLM#checkSession', 'url' => '/chat/check_session', 'verb' => 'GET'],
['name' => 'chattyLLM#checkMessageGenerationTask', 'url' => '/chat/check_generation', 'verb' => 'GET'],
['name' => 'chattyLLM#generateTitle', 'url' => '/chat/generate_title', 'verb' => 'GET'],
['name' => 'chattyLLM#checkTitleGenerationTask', 'url' => '/chat/check_title_generation', 'verb' => 'GET'],
Expand Down
2 changes: 2 additions & 0 deletions lib/AppInfo/Application.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use OCA\Assistant\Capabilities;
use OCA\Assistant\Listener\BeforeTemplateRenderedListener;
use OCA\Assistant\Listener\ChattyLLMTaskListener;
use OCA\Assistant\Listener\CSPListener;
use OCA\Assistant\Listener\FreePrompt\FreePromptReferenceListener;
use OCA\Assistant\Listener\SpeechToText\SpeechToTextReferenceListener;
Expand Down Expand Up @@ -55,6 +56,7 @@ public function register(IRegistrationContext $context): void {

$context->registerEventListener(TaskSuccessfulEvent::class, TaskSuccessfulListener::class);
$context->registerEventListener(TaskFailedEvent::class, TaskFailedListener::class);
$context->registerEventListener(TaskSuccessfulEvent::class, ChattyLLMTaskListener::class);

$context->registerNotifierService(Notifier::class);

Expand Down
89 changes: 82 additions & 7 deletions lib/Controller/ChattyLLMController.php
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,11 @@ public function generateForSession(int $sessionId): JSONResponse {
. PHP_EOL
. 'assistant: ';

$taskId = $this->scheduleLLMTask($stichedPrompt);
try {
$taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId);
} catch (\Exception $e) {
return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}

return new JSONResponse(['taskId' => $taskId]);
}
Expand Down Expand Up @@ -374,7 +378,7 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
$this->messageMapper->insert($message);
// do not insert here, it is done by the listener
return new JSONResponse($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->warning('Failed to add a chat message into DB', ['exception' => $e]);
Expand All @@ -388,6 +392,56 @@ public function checkMessageGenerationTask(int $taskId, int $sessionId): JSONRes
return new JSONResponse(['error' => 'unknown_error', 'task_status' => $task->getstatus()], Http::STATUS_BAD_REQUEST);
}

/**
* Check the status of a session
*
* Used by the frontend to determine if it should poll a generation task status.
*
* @param int $sessionId
* @return JSONResponse
* @throws \JsonException
* @throws \OCP\DB\Exception
*/
#[NoAdminRequired]
public function checkSession(int $sessionId): JSONResponse {
if ($this->userId === null) {
return new JSONResponse(['error' => $this->l10n->t('User not logged in')], Http::STATUS_UNAUTHORIZED);
}

$sessionExists = $this->sessionMapper->exists($this->userId, $sessionId);
if (!$sessionExists) {
return new JSONResponse(['error' => $this->l10n->t('Session not found')], Http::STATUS_NOT_FOUND);
}

try {
$messageTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-llm:' . $sessionId);
$titleTasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', 'chatty-title:' . $sessionId);
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
return new JSONResponse(['error' => 'task_query_failed'], Http::STATUS_BAD_REQUEST);
}
$messageTasks = array_filter($messageTasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
$titleTasks = array_filter($titleTasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
$session = $this->sessionMapper->getUserSession($this->userId, $sessionId);
$responseData = [
'messageTaskId' => null,
'titleTaskId' => null,
'sessionTitle' => $session->getTitle(),
];
if (!empty($messageTasks)) {
$task = array_pop($messageTasks);
$responseData['messageTaskId'] = $task->getId();
}
if (!empty($titleTasks)) {
$task = array_pop($titleTasks);
$responseData['titleTaskId'] = $task->getId();
}
return new JSONResponse($responseData);
}

/**
* Schedule a task to generate a title for the chat session
*
Expand Down Expand Up @@ -430,7 +484,11 @@ public function generateTitle(int $sessionId): JSONResponse {
. PHP_EOL . PHP_EOL
. $userInstructions;

$taskId = $this->scheduleLLMTask($stichedPrompt);
try {
$taskId = $this->scheduleLLMTask($stichedPrompt, $sessionId, false);
} catch (\Exception $e) {
return new JSONResponse(['error' => $e->getMessage()], Http::STATUS_BAD_REQUEST);
}
return new JSONResponse(['taskId' => $taskId]);
} catch (\OCP\DB\Exception $e) {
$this->logger->warning('Failed to generate a title for the chat session', ['exception' => $e]);
Expand Down Expand Up @@ -475,8 +533,7 @@ public function checkTitleGenerationTask(int $taskId, int $sessionId): JSONRespo
$title = str_replace('"', '', $title);
$title = explode(PHP_EOL, $title)[0];
$title = trim($title);

$this->sessionMapper->updateSessionTitle($this->userId, $sessionId, $title);
// do not write the title here since it's done in the listener

return new JSONResponse(['result' => $title]);
} catch (\OCP\DB\Exception $e) {
Expand Down Expand Up @@ -525,14 +582,32 @@ private function getStichedMessages(int $sessionId): string {
* Schedule the LLM task
*
* @param string $content
* @param int $sessionId
* @param bool $isMessage
* @return int|null
* @throws Exception
* @throws PreConditionNotMetException
* @throws UnauthorizedException
* @throws ValidationException
* @throws \JsonException
*/
private function scheduleLLMTask(string $content): ?int {
$task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId);
private function scheduleLLMTask(string $content, int $sessionId, bool $isMessage = true): ?int {
$customId = ($isMessage
? 'chatty-llm:'
: 'chatty-title:') . $sessionId;
try {
$tasks = $this->taskProcessingManager->getUserTasksByApp($this->userId, Application::APP_ID . ':chatty-llm', $customId);
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
throw new \Exception('task_query_failed');
}
$tasks = array_filter($tasks, static function (Task $task) {
return $task->getStatus() === Task::STATUS_RUNNING || $task->getStatus() === Task::STATUS_SCHEDULED;
});
// prevent scheduling multiple llm tasks simultaneously for one session
if (!empty($tasks)) {
throw new \Exception('session_already_thinking');
}
$task = new Task(TextToText::ID, ['input' => $content], Application::APP_ID . ':chatty-llm', $this->userId, $customId);
$this->taskProcessingManager->scheduleTask($task);
return $task->getId();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Db/ChattyLLM/Session.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
/**
* @method \string getUserId()
* @method \void setUserId(string $userId)
* @method \?string getTitle()
* @method \string|null getTitle()
* @method \void setTitle(?string $title)
* @method \int|null getTimestamp()
* @method \void setTimestamp(?int $timestamp)
Expand Down
21 changes: 21 additions & 0 deletions lib/Db/ChattyLLM/SessionMapper.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

namespace OCA\Assistant\Db\ChattyLLM;

use OCP\AppFramework\Db\DoesNotExistException;
use OCP\AppFramework\Db\MultipleObjectsReturnedException;
use OCP\AppFramework\Db\QBMapper;
use OCP\DB\Exception;
use OCP\DB\QueryBuilder\IQueryBuilder;
use OCP\IDBConnection;

Expand Down Expand Up @@ -59,6 +62,24 @@ public function exists(string $userId, int $sessionId): bool {
}
}

/**
* @param string $userId
* @param int $sessionId
* @return Session
* @throws DoesNotExistException
* @throws MultipleObjectsReturnedException
* @throws Exception
*/
public function getUserSession(string $userId, int $sessionId): Session {
$qb = $this->db->getQueryBuilder();
$qb->select('id', 'title', 'timestamp')
->from($this->getTableName())
->where($qb->expr()->eq('id', $qb->createPositionalParameter($sessionId, IQueryBuilder::PARAM_INT)))
->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId, IQueryBuilder::PARAM_STR)));

return $this->findEntity($qb);
}

/**
* @param string $userId
* @return array
Expand Down
64 changes: 64 additions & 0 deletions lib/Listener/ChattyLLMTaskListener.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?php

declare(strict_types=1);

namespace OCA\Assistant\Listener;

use OCA\Assistant\AppInfo\Application;
use OCA\Assistant\Db\ChattyLLM\Message;
use OCA\Assistant\Db\ChattyLLM\MessageMapper;
use OCA\Assistant\Db\ChattyLLM\SessionMapper;
use OCP\EventDispatcher\Event;
use OCP\EventDispatcher\IEventListener;
use OCP\TaskProcessing\Events\TaskSuccessfulEvent;
use Psr\Log\LoggerInterface;

/**
* @template-implements IEventListener<TaskSuccessfulEvent>
*/
class ChattyLLMTaskListener implements IEventListener {

public function __construct(
private MessageMapper $messageMapper,
private SessionMapper $sessionMapper,
private LoggerInterface $logger,
) {
}

public function handle(Event $event): void {
if (!($event instanceof TaskSuccessfulEvent)) {
return;
}

$task = $event->getTask();
$customId = $task->getCustomId();
$appId = $task->getAppId();

if ($customId === null || $appId !== (Application::APP_ID . ':chatty-llm')) {
return;
}

// title generation
if (preg_match('/^chatty-title:(\d+)$/', $customId, $matches)) {
$sessionId = (int)$matches[1];
$title = trim($task->getOutput()['output'] ?? '');
$this->sessionMapper->updateSessionTitle($task->getUserId(), $sessionId, $title);
}

// message generation
if (preg_match('/^chatty-llm:(\d+)$/', $customId, $matches)) {
$sessionId = (int)$matches[1];

$message = new Message();
$message->setSessionId($sessionId);
$message->setRole('assistant');
$message->setContent(trim($task->getOutput()['output'] ?? ''));
$message->setTimestamp(time());
try {
$this->messageMapper->insert($message);
} catch (\OCP\DB\Exception $e) {
$this->logger->error('Message insertion error in chattyllm task listener', ['exception' => $e]);
}
}
}
}
Loading

0 comments on commit b88fc93

Please sign in to comment.