diff --git a/src/core/ai/openai/openai-chat.service.ts b/src/core/ai/openai/openai-chat.service.ts
index f1dc4a2..f5ab821 100644
--- a/src/core/ai/openai/openai-chat.service.ts
+++ b/src/core/ai/openai/openai-chat.service.ts
@@ -15,6 +15,7 @@ export interface OpenAiChatConfig {
     | 'gpt-3.5-turbo-16k';
   temperature: number;
   stream: boolean;
+  max_tokens: number;
 }
 
 @Injectable()
@@ -62,10 +63,19 @@ export class OpenaiChatService implements OnModuleInit {
     this._storeConfig(chatId, config);
   }
 
-  public getAiResponseStream(chatId: string, humanMessage: string) {
+  public getAiResponseStream(
+    chatId: string,
+    humanMessage: string,
+    imageURLs: string[],
+  ) {
     this.logger.debug(`Getting AI response stream for chatId: ${chatId}`);
     const config = this._getConfigOrThrow(chatId);
-    config.messages.push({ role: 'user', content: humanMessage });
+    const messageContent: any[] = [{ type: 'text', text: humanMessage }];
+    for (const imageURL of imageURLs) {
+      messageContent.push({ type: 'image_url', image_url: { url: imageURL } });
+    }
+
+    config.messages.push({ role: 'user', content: messageContent });
 
     return this.streamResponse(config);
   }
@@ -85,14 +95,15 @@ export class OpenaiChatService implements OnModuleInit {
     // @ts-ignore
     for await (const newChunk of chatStream) {
       allStreamedChunks.push(newChunk);
-      fullAiResponseText += newChunk.choices[0].delta.content;
-      const chunkText = newChunk.choices[0].delta.content;
+      fullAiResponseText += newChunk.choices[0].delta.content || '';
+      const chunkText = newChunk.choices[0].delta.content || '';
       if (chunkText) {
         chunkToYield += chunkText;
       }
       if (
         chunkToYield.length >= yieldAtLength ||
-        newChunk.choices[0].finish_reason === 'stop'
+        newChunk.choices[0].finish_reason === 'stop' ||
+        newChunk.choices[0].finish_reason === 'length'
       ) {
         this.logger.debug(`Streaming text chunk: ${chunkToYield}`);
         yield chunkToYield;
@@ -119,9 +130,10 @@ export class OpenaiChatService implements OnModuleInit {
   private _defaultChatConfig() {
     return {
       messages: [],
-      model: 'gpt-4-1106-preview',
+      model: 'gpt-4-vision-preview',
       temperature: 0.7,
       stream: true,
+      max_tokens: 4096,
     } as OpenAiChatConfig;
   }
   private _reloadMessageHistoryFromAiChatDocument(aiChat: AiChatDocument) {
diff --git a/src/interfaces/discord/services/discord-attachment.service.ts b/src/interfaces/discord/services/discord-attachment.service.ts
index 37130d0..b4135dc 100644
--- a/src/interfaces/discord/services/discord-attachment.service.ts
+++ b/src/interfaces/discord/services/discord-attachment.service.ts
@@ -112,6 +112,32 @@ export class DiscordAttachmentService {
     }
   }
 
+  async getImageDataFromURL(url: string) {
+    try {
+      this.logger.log(`getting Image from URL: ${url}`);
+
+      // Fetching the image data as a binary buffer
+      const response = await axios({
+        method: 'get',
+        url: url,
+        responseType: 'arraybuffer',
+      });
+      // Converting the image data to base64
+      const imageBase64 = Buffer.from(response.data).toString(
+        'base64',
+      );
+
+      // Getting the content type of the image
+      const contentType = response.headers['content-type'];
+
+      // Combining content type and base64 encoding for complete image data
+      return `data:${contentType};base64,${imageBase64}`;
+    } catch (error) {
+      this.logger.error(`Error getting image data from URL: ${error}`);
+      throw error;
+    }
+  }
+
   private async handleTextAttachment(
     tempFilePath: string,
     attachment: Attachment,
diff --git a/src/interfaces/discord/services/discord-message.service.ts b/src/interfaces/discord/services/discord-message.service.ts
index 3e79bca..04abfab 100644
--- a/src/interfaces/discord/services/discord-message.service.ts
+++ b/src/interfaces/discord/services/discord-message.service.ts
@@ -28,11 +28,13 @@ export class DiscordMessageService {
     try {
       let humanInputText = '';
       let attachmentText = '';
+      let imageURLs: string[] = [];
       if (!textToRespondTo) {
-        ({ humanInputText, attachmentText } = await this.extractMessageContent(
-          discordMessage,
-          respondToChannelOrMessage,
-        ));
+        ({ humanInputText, attachmentText, imageURLs } =
+          await this.extractMessageContent(
+            discordMessage,
+            respondToChannelOrMessage,
+          ));
       } else {
         humanInputText = textToRespondTo;
         attachmentText = '';
@@ -48,6 +50,7 @@ export class DiscordMessageService {
         humanUserId,
         humanInputText,
         attachmentText,
+        imageURLs,
         discordMessage,
         isFirstExchange,
         respondToChannelOrMessage,
@@ -103,6 +106,7 @@ export class DiscordMessageService {
     humanUserId: string,
     inputMessageText: string,
     attachmentText: string,
+    imageURLs: string[],
     discordMessage: Message,
     isFirstExchange: boolean = false,
     respondToChannelOrMessage: Message | TextBasedChannel,
@@ -114,6 +118,7 @@ export class DiscordMessageService {
     const aiResponseStream = this._openaiChatService.getAiResponseStream(
       discordMessage.channel.id,
       inputMessageText + attachmentText,
+      imageURLs,
     );
 
     const maxMessageLength = 2000 * 0.9; // discord max message length is 2000 characters (and *0.9 to be safe)
@@ -182,7 +187,7 @@ export class DiscordMessageService {
         isFirstExchange,
       );
     } catch (error) {
-      this.logger.error(`Error in _handleStream: ${error}`);
+      this.logger.error(`Error in _handleStream: ${error}`);
     }
   }
 
@@ -192,6 +197,7 @@ export class DiscordMessageService {
   ) {
     let humanInputText = discordMessage.content;
     let attachmentText = '';
+    const imageURLs: string[] = [];
     if (discordMessage.attachments.size > 0) {
       if (humanInputText.length > 0) {
         humanInputText =
@@ -199,8 +205,19 @@ export class DiscordMessageService {
           humanInputText +
           '\n\nEND TEXT FROM HUMAN INPUT\n\n';
       }
-      attachmentText = 'BEGIN TEXT FROM ATTACHMENTS:\n\n';
       for (const [, attachment] of discordMessage.attachments) {
+        if (attachment.contentType?.startsWith('image/')) {
+          imageURLs.push(
+            await this._discordAttachmentService.getImageDataFromURL(
+              attachment.url, //.split('?')[0]
+            ),
+          );
+          this.logger.debug('pushed img url to attachmentURLs');
+          continue;
+        }
+        if (!attachmentText) {
+          attachmentText = 'BEGIN TEXT FROM ATTACHMENTS:\n\n';
+        }
         const attachmentResponse =
           await this._discordAttachmentService.handleAttachment(attachment);
         attachmentText += attachmentResponse.text;
@@ -229,7 +246,7 @@ export class DiscordMessageService {
         attachmentText += 'END TEXT FROM ATTACHMENTS';
       }
     }
-    return { humanInputText, attachmentText };
+    return { humanInputText, attachmentText, imageURLs };
   }
 
   private async _sendFullResponseAsAttachment(
diff --git a/src/interfaces/discord/services/discord-on-message.service.ts b/src/interfaces/discord/services/discord-on-message.service.ts
index 84e2a2f..6ec7e4f 100644
--- a/src/interfaces/discord/services/discord-on-message.service.ts
+++ b/src/interfaces/discord/services/discord-on-message.service.ts
@@ -47,9 +47,10 @@ export class DiscordOnMessageService {
 
     const chatConfig = {
       messages: [],
-      model: 'gpt-4-1106-preview',
+      model: 'gpt-4-vision-preview',
      temperature: 0.7,
       stream: true,
+      max_tokens: 4096,
     } as OpenAiChatConfig;
     this._openaiChatService.createChat(aiChatId, contextPrompt, chatConfig);
     const aiChatDocument = await this._aiChatsService.createAiChat({
@@ -58,7 +59,7 @@ export class DiscordOnMessageService {
       contextRoute,
       contextInstructions: contextPrompt,
       couplets: [],
-      modelName: 'gpt-4-1106-preview',
+      modelName: 'gpt-4-vision-preview',
     });
 
     this.logger.debug(`Adding threadId ${aiChatId} to active listeners`);
@@ -122,7 +123,7 @@ export class DiscordOnMessageService {
         contextInstructions:
           await this._contextPromptService.getContextPromptFromMessage(message),
         couplets: [],
-        modelName: 'gpt-4-1106-preview',
+        modelName: 'gpt-4-vision-preview',
       },
       populateCouplets,
     );