Skip to content

Commit

Permalink
make prompt req, add a few options
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmatthis committed Dec 20, 2024
1 parent 8c9c121 commit 2e57039
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 61 deletions.
10 changes: 6 additions & 4 deletions src/core/ai/openai/openai-image.service.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { OpenAI } from 'openai';
import { ImageGenerationDto } from './dto/image-generation.dto'; // Ensure this path is correct
import { ImageGenerationDto } from './dto/image-generation.dto';
import { OpenaiSecretsService } from './openai-secrets.service';
import { ImagesResponse } from 'openai/resources';

Expand All @@ -23,7 +23,9 @@ export class OpenaiImageService implements OnModuleInit {
}
}

public async generateImage(dto: ImageGenerationDto) {
public async generateImage(
dto: ImageGenerationDto,
): Promise<ImagesResponse | Error> {
try {
const {
prompt,
Expand All @@ -36,7 +38,7 @@ export class OpenaiImageService implements OnModuleInit {
style = 'vivid',
} = dto;

const generationResponse: ImagesResponse =
const generationResponse: ImagesResponse | Error =
await this.openai.images.generate({
prompt,
model,
Expand All @@ -52,7 +54,7 @@ export class OpenaiImageService implements OnModuleInit {
return generationResponse;
} catch (error) {
this._logger.error('Failed to generate image.', error);
throw error;
return error;
}
}
}
156 changes: 99 additions & 57 deletions src/interfaces/discord/commands/discord-image.command.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
import { Injectable } from '@nestjs/common';
import {
BooleanOption,
Context,
Options,
SlashCommand,
SlashCommandContext,
StringOption,
} from 'necord';
import { Context, Options, SlashCommand, SlashCommandContext } from 'necord';
import { OpenaiImageService } from '../../../core/ai/openai/openai-image.service';
import OpenAI from 'openai';
import { AttachmentBuilder } from 'discord.js';
import ImagesResponse = OpenAI.ImagesResponse;
import {
AttachmentBuilder,
CacheType,
ChatInputCommandInteraction,
} from 'discord.js';
import { OpenaiTextGenerationService } from '../../../core/ai/openai/openai-text.service';

export class ImagePromptDto {
@StringOption({
name: 'prompt',
description: 'Starting text for the chat',
required: false,
})
prompt: string = 'Generate a new image';

@BooleanOption({
name: 'use_context',
description:
'Whether to include text from this Thread/Channel in the image generation prompt',
required: false,
})
useContext: boolean;
}
import { ImagePromptDto } from './image-prompt.dto';
import ImagesResponse = OpenAI.ImagesResponse;

@Injectable()
export class DiscordImageCommand {
Expand All @@ -44,51 +25,112 @@ export class DiscordImageCommand {
})
public async handleImageCommand(
@Context() [interaction]: SlashCommandContext,
@Options({ required: false }) imagePromptDto?: ImagePromptDto,
@Options({ required: true }) imagePromptDto?: ImagePromptDto,
) {
await interaction.deferReply();
let promptText = '';
if (!imagePromptDto || !imagePromptDto.prompt) {
promptText = 'Generate a new image';
} else {
promptText = imagePromptDto.prompt;
}
if (imagePromptDto && imagePromptDto.useContext) {
const context = interaction.channel;
const messages = await context.messages.fetch();
const contextText = messages
.map((message) => message.content)
.join(' \n ');
const promptText = await this.generatePrompt(imagePromptDto, interaction);
const initialMessageText = await this.sendInitialReply(
promptText,
interaction,
);

let promptInstructions =
'Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt';
if (imagePromptDto && imagePromptDto.prompt) {
promptInstructions = imagePromptDto.prompt;
}
promptText = await this._openaiTextService.generateText({
prompt: `${promptInstructions}.\n\n--------BEGIN INPUT TEXT\n\n ${contextText} \n\n ---------------END OF INPUT TEXT\n\nREMEMBER! Your task is toyeah, ${promptInstructions}.`,
model: 'gpt-4o',
temperature: 0.5,
max_tokens: 300,
});
}
await interaction.editReply({
content: `Generating image from prompt:\n > ${promptText} \n\n Please wait...`,
});
// generate image
const response: ImagesResponse =
const response: ImagesResponse | Error =
await this._openaiImageService.generateImage({
prompt: promptText,
user: interaction.user.id,
style: imagePromptDto.naturalStyle ? 'natural' : 'vivid',
});

if (response instanceof Error) {
await interaction.editReply({
content: `Error generating image for prompt ${promptText}: \n Error response:\n\n ${response.message}`,
});
return;
}

if (!response.data || response.data.length === 0) {
await interaction.editReply({
content: `No image was generated from the prompt:\n > ${promptText} `,
});
return;
}

const imageBuffer = Buffer.from(response.data[0].b64_json, 'base64');
const imageAttachment = new AttachmentBuilder(imageBuffer, {
name: 'image.png',
});

await interaction.editReply({
content: `Image generated from prompt:\n > ${promptText}`,
content: `${initialMessageText} Revised Prompt: \n > ${response.data[0].revised_prompt}`,
files: [imageAttachment],
});
}

private async sendInitialReply(
promptText: string,
interaction: ChatInputCommandInteraction<CacheType>,
) {
const pleaseWaitText = 'Generating Image, Please wait...';
const initialMessageText = `Original Prompt:\n > ${promptText} \n\n`;
await interaction.editReply({
content: `${initialMessageText} ${pleaseWaitText}`,
});
return initialMessageText;
}

private async generatePrompt(
imagePromptDto: ImagePromptDto,
interaction: ChatInputCommandInteraction<CacheType>,
) {
let promptText = imagePromptDto.prompt;
let contextText = '';
if (imagePromptDto && imagePromptDto.useContext) {
contextText = await this.get_context_text(interaction, imagePromptDto);
promptText = `${promptText}\n${contextText} \n ${promptText}`;
}

if (promptText.length > 200) {
promptText = await this.condenseText(contextText);
}

if (imagePromptDto.useExactPrompt) {
// the instructions on how to disable the revise prompt feature literally say to append this text to the prompt lol: https://platform.openai.com/docs/guides/images#prompting
const dontReviseTextPrompt =
'I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS:';
promptText = `${dontReviseTextPrompt} \n ${promptText}`;
}
return promptText;
}

private async condenseText(contextText: string) {
const promptInstructions =
'Condense the provided INPUT TEXT into a 200 word (or less) prompt that will be used to generate an image. Do not generate any text other than the image generation prompt';

return await this._openaiTextService.generateText({
prompt: `${promptInstructions}.\n\n--------BEGIN INPUT TEXT\n\n ${contextText} \n\n ---------------END OF INPUT TEXT\n\nREMEMBER! Your task is toyeah, ${promptInstructions}.`,
model: 'gpt-4o',
temperature: 0.5,
max_tokens: 300,
});
}

private async get_context_text(
interaction: ChatInputCommandInteraction<CacheType>,
imagePromptDto: ImagePromptDto,
) {
const context = interaction.channel;
const contextLength = imagePromptDto.contextLength || 5;
const messages = await context.messages.fetch();

if (contextLength === -1) {
return messages.map((message) => message.content).join(' \n ');
} else {
const messageArray = Array.from(messages.values());
return messageArray
.slice(0, contextLength)
.map((message) => message.content)
.join(' \n ');
}
}
}
41 changes: 41 additions & 0 deletions src/interfaces/discord/commands/image-prompt.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import { BooleanOption, StringOption } from 'necord';

export class ImagePromptDto {
@StringOption({
name: 'prompt',
description: 'Starting text for the chat',
required: true,
})
prompt: string = '';

@BooleanOption({
name: 'use_context',
description:
'Whether to include text from this Thread/Channel in the image generation prompt',
required: false,
})
useContext: boolean;

@BooleanOption({
name: 'natural_style',
description: 'Set `True` to use the `natural` style (default is `vivid`)',
required: false,
})
naturalStyle: boolean;

@BooleanOption({
name: 'use_exact_prompt',
description:
'Set `True` to use the exact prompt provided without any modifications',
required: false,
})
useExactPrompt: boolean;

@StringOption({
name: 'context_length',
description:
'(if `use_context`) #messages to use in the context window, -1 for all (default: 5)',
required: false,
})
contextLength?: number;
}

0 comments on commit 2e57039

Please sign in to comment.