Skip to content

Commit

Permalink
improve image generation model endpoint generation and related UX
Browse files Browse the repository at this point in the history
  • Loading branch information
yomybaby committed Jan 6, 2025
1 parent 3ec5035 commit 10d365d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
10 changes: 9 additions & 1 deletion react/src/components/lablupTalkativotUI/ChatUIModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
graphql`
fragment ChatUIModalFragment on Endpoint {
endpoint_id
name
url
status
}
Expand All @@ -90,6 +91,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
`,
endpointTokenFrgmt,
);
const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');

const newestToken = _.maxBy(
endpointTokenList?.items,
Expand Down Expand Up @@ -124,7 +126,12 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
) : (
<LLMChatCard
endpointId={endpoint?.endpoint_id || ''}
baseURL={new URL(basePath, endpoint?.url || '').toString()}
baseURL={
isTextToImageModel
? new URL('/generate-image', endpoint?.url || '').toString()
: new URL(basePath, endpoint?.url || '').toString()
}
isImageGeneration={isTextToImageModel}
models={_.map(modelsResult?.data, (m) => ({
id: m.id,
name: m.id,
Expand All @@ -133,6 +140,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
style={{ flex: 1 }}
allowCustomModel={_.isEmpty(modelsResult?.data)}
alert={
!isTextToImageModel &&
_.isEmpty(modelsResult?.data) && (
<Alert
type="warning"
Expand Down
20 changes: 17 additions & 3 deletions react/src/components/lablupTalkativotUI/LLMChatCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
setLoadingImageGeneration(true);
try {
const response = await fetch(
'https://stable-diffusion-3m.asia03.app.backend.ai/generate-image',
customModelFormRef.current?.getFieldValue('baseURL'),
{
method: 'POST',
headers: {
Expand Down Expand Up @@ -414,8 +414,8 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
});

if (isImageGeneration) {
const generationId = _.uniqueId();
try {
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
...prevMessages,
{
Expand All @@ -424,7 +424,20 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
content: input,
},
{
id: _.uniqueId(),
id: generationId,
role: 'assistant',
content: 'Processing...',
},
]);
setInput('');
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
..._.filter(
prevMessages,
(message) => message.id !== generationId,
),
{
id: generationId,
role: 'assistant',
content: '',
experimental_attachments: [
Expand Down Expand Up @@ -510,6 +523,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
required: true,
},
]}
hidden={isImageGeneration}
>
<Input placeholder="llm-model" />
</Form.Item>
Expand Down

0 comments on commit 10d365d

Please sign in to comment.