improve image generation model endpoint generation and related UX
yomybaby authored and agatha197 committed Jan 9, 2025
1 parent 12a8290 commit 0023202
Showing 7 changed files with 54 additions and 12 deletions.
10 changes: 9 additions & 1 deletion react/src/components/ChatContent.tsx
@@ -13,12 +13,14 @@ import { useLazyLoadQuery } from 'react-relay/hooks';
 interface ChatContentProps {
   endpointId: string;
   endpointUrl: string;
+  endpointName: string;
   basePath: string;
 }
 
 const ChatContent: React.FC<ChatContentProps> = ({
   endpointId,
   endpointUrl,
+  endpointName,
   basePath,
 }) => {
   const { t } = useTranslation();
@@ -56,6 +58,7 @@ const ChatContent: React.FC<ChatContentProps> = ({
       fetchPolicy: 'network-only',
     },
   );
+  const isTextToImageModel = _.includes(endpointName, 'stable-diffusion');
 
   const newestValidToken =
     _.orderBy(endpoint_token_list?.items, ['valid_until'], ['desc'])[0]
@@ -85,7 +88,12 @@ const ChatContent: React.FC<ChatContentProps> = ({
   return (
     <LLMChatCard
       endpointId={endpointId || ''}
-      baseURL={new URL(basePath, endpointUrl).toString()}
+      baseURL={
+        isTextToImageModel
+          ? new URL('/generate-image', endpointUrl || '').toString()
+          : new URL(basePath, endpointUrl || '').toString()
+      }
+      isImageGeneration={isTextToImageModel}
       models={_.map(modelsResult?.data, (m) => ({
         id: m.id,
         name: m.id,
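Note: the branching above leans on how the WHATWG URL constructor resolves its first argument. A minimal sanity check, runnable in Node or a browser console (the endpoint host below is made up for illustration):

// How new URL(path, base) resolves: an absolute first argument replaces the
// base URL's path, while a relative one resolves against it.
const endpointUrl = 'https://my-endpoint.example.com/some/base/'; // hypothetical

// Absolute path: '/generate-image' always lands at the host root.
console.log(new URL('/generate-image', endpointUrl).toString());
// -> https://my-endpoint.example.com/generate-image

// Relative basePath (e.g. 'v1'): resolved against the endpoint's path.
console.log(new URL('v1', endpointUrl).toString());
// -> https://my-endpoint.example.com/some/base/v1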
1 change: 1 addition & 0 deletions react/src/components/ModelCardChat.tsx
@@ -67,6 +67,7 @@ const ModelCardChat: React.FC<ModelCardChatProps> = ({
         <ChatContent
           endpointId={healthyEndpoint[0]?.endpoint_id as string}
           endpointUrl={healthyEndpoint[0]?.url as string}
+          endpointName={healthyEndpoint[0]?.name as string}
           basePath={basePath}
         />
       ) : (
2 changes: 1 addition & 1 deletion react/src/components/ModelCardModal.tsx
@@ -151,7 +151,7 @@ const ModelCardModal: React.FC<ModelCardModalProps> = ({
       wrap="wrap"
       align="stretch"
       gap={'sm'}
-      style={{ width: '100%' }}
+      style={{ width: '100%', minHeight: '50vh' }}
     >
       <Flex
         direction="row"
4 changes: 2 additions & 2 deletions react/src/components/lablupTalkativotUI/ChatMessage.tsx
@@ -79,8 +79,8 @@ const ChatMessage: React.FC<{
             src={attachment?.url}
             alt={attachment?.name}
             style={{
-              maxWidth: '50vw',
-              maxHeight: '12vh',
+              maxWidth: placement === 'left' ? 200 : 300,
+              maxHeight: placement === 'left' ? 200 : 300,
               borderRadius: token.borderRadius,
             }}
           />
12 changes: 11 additions & 1 deletion react/src/components/lablupTalkativotUI/ChatUIModal.tsx
@@ -71,6 +71,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
     graphql`
       fragment ChatUIModalFragment on Endpoint {
         endpoint_id
+        name
         url
         status
       }
@@ -90,6 +91,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
     `,
     endpointTokenFrgmt,
   );
+  const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');
 
   const newestToken = _.maxBy(
     endpointTokenList?.items,
@@ -124,7 +126,14 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
   ) : (
     <LLMChatCard
       endpointId={endpoint?.endpoint_id || ''}
-      baseURL={new URL(basePath, endpoint?.url || '').toString()}
+      baseURL={
+        endpoint?.url
+          ? isTextToImageModel
+            ? new URL('/generate-image', endpoint?.url || '').toString()
+            : new URL(basePath, endpoint?.url || '').toString()
+          : ''
+      }
+      isImageGeneration={isTextToImageModel}
       models={_.map(modelsResult?.data, (m) => ({
         id: m.id,
         name: m.id,
@@ -133,6 +142,7 @@ const EndpointChatContent: React.FC<ChatUIBasicProps> = ({
       style={{ flex: 1 }}
       allowCustomModel={_.isEmpty(modelsResult?.data)}
       alert={
+        !isTextToImageModel &&
         _.isEmpty(modelsResult?.data) && (
           <Alert
             type="warning"
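Note: the `stable-diffusion` name check here is the same heuristic this commit also adds to ChatContent and EndpointLLMChatCard. A minimal sketch of that shared logic; isTextToImageEndpoint is a hypothetical helper name, not something this commit introduces:

import _ from 'lodash';

// Hypothetical helper capturing the heuristic duplicated across these diffs:
// an endpoint is treated as text-to-image iff its name contains
// 'stable-diffusion'. _.includes does a substring check for strings and is
// nil-safe, so a missing name simply yields false.
const isTextToImageEndpoint = (endpointName?: string | null): boolean =>
  _.includes(endpointName, 'stable-diffusion');

isTextToImageEndpoint('stable-diffusion-3m'); // true
isTextToImageEndpoint(undefined); // false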
13 changes: 10 additions & 3 deletions react/src/components/lablupTalkativotUI/EndpointLLMChatCard.tsx
@@ -49,6 +49,7 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
     fragment EndpointLLMChatCard_endpoint on Endpoint {
       endpoint_id
       url
+      name
     }
   `,
   endpointFrgmt,
@@ -63,6 +64,8 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
     chatSubmitKeyInfoState,
   );
 
+  const isTextToImageModel = _.includes(endpoint?.name, 'stable-diffusion');
+
   const { data: modelsResult } = useSuspenseTanQuery<{
     data: Array<Model>;
   }>({
@@ -93,9 +96,12 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
       {...cardProps}
       baseURL={
         endpoint?.url
-          ? new URL(basePath, endpoint?.url ?? undefined).toString()
-          : undefined
+          ? isTextToImageModel
+            ? new URL('/generate-image', endpoint?.url).toString()
+            : new URL(basePath, endpoint?.url).toString()
+          : ''
       }
+      isImageGeneration={isTextToImageModel}
       models={models}
       fetchOnClient
       leftExtra={
@@ -147,7 +153,8 @@ const EndpointLLMChatCard: React.FC<EndpointLLMChatCardProps> = ({
       }
       allowCustomModel={_.isEmpty(models)}
       alert={
-        _.isEmpty(models) && (
+        !isTextToImageModel &&
+        _.isEmpty(modelsResult?.data) && (
           <Alert
             type="warning"
             showIcon
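Note: both alert hunks in this commit gate the "no models" warning behind !isTextToImageModel, relying on && short-circuiting (React renders nothing for a false child). A plain-TypeScript sketch of the resulting predicate, under the assumption that the warning only matters for chat endpoints:

import _ from 'lodash';

// Predicate equivalent of the alert={...} expression: text-to-image
// endpoints never warn, chat endpoints warn only when no model is served.
const shouldWarnNoModels = (
  isTextToImageModel: boolean,
  models: unknown[],
): boolean => !isTextToImageModel && _.isEmpty(models);

shouldWarnNoModels(true, []); // false — image endpoints skip the warning
shouldWarnNoModels(false, []); // true — chat endpoint with no models warns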
24 changes: 20 additions & 4 deletions react/src/components/lablupTalkativotUI/LLMChatCard.tsx
@@ -230,7 +230,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
     setLoadingImageGeneration(true);
     try {
       const response = await fetch(
-        'https://stable-diffusion-3m.asia03.app.backend.ai/generate-image',
+        customModelFormRef.current?.getFieldValue('baseURL'),
         {
           method: 'POST',
           headers: {
@@ -244,7 +244,9 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
       );
       if (response.ok) {
         const responseData = await response.json();
-        return 'data:image/png;base64,' + responseData.image_base64;
+        return _.startsWith(responseData.image_base64, 'data:image/png;base64,')
+          ? responseData.image_base64
+          : 'data:image/png;base64,' + responseData.image_base64;
       } else {
         throw new Error('Error generating image');
       }
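Note: the new return makes the data-URL prefixing idempotent, so a backend that already returns a full data URL is not double-prefixed. A minimal sketch of that normalization, assuming the { image_base64: string } response shape used above:

import _ from 'lodash';

// Only prepend the data-URL header when the backend has not already done so,
// so the transformation is safe to apply to either response style.
const toPngDataUrl = (image_base64: string): string =>
  _.startsWith(image_base64, 'data:image/png;base64,')
    ? image_base64
    : 'data:image/png;base64,' + image_base64;

toPngDataUrl('iVBORw0KGgo='); // -> 'data:image/png;base64,iVBORw0KGgo='
toPngDataUrl('data:image/png;base64,iVBORw0KGgo='); // returned unchanged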
@@ -414,8 +416,8 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
   });
 
   if (isImageGeneration) {
+    const generationId = _.uniqueId();
     try {
-      const imageBase64 = await generateImage(input, 'accessKey');
       setMessages((prevMessages) => [
         ...prevMessages,
         {
@@ -424,7 +426,20 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
content: input,
},
{
id: _.uniqueId(),
id: generationId,
role: 'assistant',
content: 'Processing...',
},
]);
setInput('');
const imageBase64 = await generateImage(input, 'accessKey');
setMessages((prevMessages) => [
..._.filter(
prevMessages,
(message) => message.id !== generationId,
),
{
id: generationId,
role: 'assistant',
content: '',
experimental_attachments: [
@@ -510,6 +525,7 @@ const LLMChatCard: React.FC<LLMChatCardProps> = ({
             required: true,
           },
         ]}
+        hidden={isImageGeneration}
       >
         <Input placeholder="llm-model" />
       </Form.Item>
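Note: the last two hunks implement an optimistic-placeholder flow: a "Processing..." assistant message is pushed under a stable id before the request, then replaced by filtering on that id once generateImage resolves. A condensed, self-contained sketch of the pattern (the Message type is a simplified stand-in for the card's real message shape):

import _ from 'lodash';

interface Message {
  id: string;
  role: 'user' | 'assistant';
  content: string;
}

// Step 1: append the user's input plus a placeholder that shares the
// generation's stable id.
const addPlaceholder = (
  messages: Message[],
  input: string,
  generationId: string,
): Message[] => [
  ...messages,
  { id: _.uniqueId(), role: 'user', content: input },
  { id: generationId, role: 'assistant', content: 'Processing...' },
];

// Step 2: once the image arrives, drop the placeholder by id and append the
// finished assistant message under that same id.
const resolvePlaceholder = (
  messages: Message[],
  generationId: string,
  content: string,
): Message[] => [
  ..._.filter(messages, (message) => message.id !== generationId),
  { id: generationId, role: 'assistant', content },
];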