Skip to content

Commit

Permalink
feat(client): stream queue status (#71)
Browse files Browse the repository at this point in the history
* feat(client): stream queue status

* chore: remove console log

* fix: accumulative logs when streaming

* fix(client): stream logs on queue update

* chore(apps): remove pollInterval from sample apps
  • Loading branch information
drochetti authored Jun 25, 2024
1 parent c791016 commit 4ea43b4
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ export default function ComfyImageToImagePage() {
prompt: prompt,
loadimage_1: imageFile,
},
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ export default function ComfyImageToVideoPage() {
input: {
loadimage_1: imageFile,
},
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export default function ComfyTextToImagePage() {
input: {
prompt: prompt,
},
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
1 change: 0 additions & 1 deletion apps/demo-nextjs-app-router/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ export default function Home() {
image_url: imageFile,
image_size: 'square_hd',
},
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
152 changes: 152 additions & 0 deletions apps/demo-nextjs-app-router/app/queue/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
'use client';

import * as fal from '@fal-ai/serverless-client';
import { useState } from 'react';

fal.config({
proxyUrl: '/api/fal/proxy',
});

type ErrorProps = {
error: any;
};

function Error(props: ErrorProps) {
if (!props.error) {
return null;
}
return (
<div
className="p-4 mb-4 text-sm text-red-800 rounded bg-red-50 dark:bg-gray-800 dark:text-red-400"
role="alert"
>
<span className="font-medium">Error</span> {props.error.message}
</div>
);
}

export default function Home() {
// Input state
const [endpointId, setEndpointId] = useState<string>('');
const [input, setInput] = useState<string>('{}');
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
const [result, setResult] = useState<any | null>(null);
const [logs, setLogs] = useState<string[]>([]);
const [elapsedTime, setElapsedTime] = useState<number>(0);

const reset = () => {
setLoading(false);
setError(null);
setResult(null);
setLogs([]);
setElapsedTime(0);
};

const run = async () => {
reset();
setLoading(true);
const start = Date.now();
try {
const result: any = await fal.subscribe(endpointId, {
input: JSON.parse(input),
logs: true,
onQueueUpdate(update) {
console.log('queue update');
console.log(update);
setElapsedTime(Date.now() - start);
if (
update.status === 'IN_PROGRESS' ||
update.status === 'COMPLETED'
) {
if (update.logs && update.logs.length > logs.length) {
setLogs((update.logs || []).map((log) => log.message));
}
}
},
});
setResult(result);
} catch (error: any) {
setError(error);
} finally {
setLoading(false);
setElapsedTime(Date.now() - start);
}
};
return (
<div className="min-h-screen dark:bg-gray-900 bg-gray-100">
<main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
<h1 className="text-4xl font-bold mb-8">
<code className="font-light text-pink-600">fal</code>
<code>queue</code>
</h1>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
Endpoint ID
</label>
<input
className="w-full text-base p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10"
id="endpointId"
name="endpointId"
autoComplete="off"
placeholder="Endpoint ID"
value={endpointId}
spellCheck={false}
onChange={(e) => setEndpointId(e.target.value)}
/>
</div>
<div className="text-lg w-full">
<label htmlFor="prompt" className="block mb-2 text-current">
JSON Input
</label>
<textarea
className="w-full text-sm p-2 rounded bg-black/10 dark:bg-white/5 border border-black/20 dark:border-white/10 font-mono"
id="input"
name="Input"
placeholder="JSON"
value={input}
autoComplete="off"
spellCheck={false}
onChange={(e) => setInput(e.target.value)}
rows={6}
></textarea>
</div>

<button
onClick={(e) => {
e.preventDefault();
run();
}}
className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold text-lg py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline"
disabled={loading}
>
{loading ? 'Running...' : 'Run'}
</button>

<Error error={error} />

<div className="w-full flex flex-col space-y-4">
<div className="space-y-2">
<h3 className="text-xl font-light">JSON Result</h3>
<p className="text-sm text-current/80">
{`Elapsed Time (seconds): ${(elapsedTime / 1000).toFixed(2)}`}
</p>
<pre className="text-sm bg-black/70 text-white/80 font-mono h-60 rounded whitespace-pre overflow-auto w-full">
{result
? JSON.stringify(result, null, 2)
: '// result pending...'}
</pre>
</div>

<div className="space-y-2">
<h3 className="text-xl font-light">Logs</h3>
<pre className="text-sm bg-black/70 text-white/80 font-mono h-60 rounded whitespace-pre overflow-auto w-full">
{logs.join('\n')}
</pre>
</div>
</div>
</main>
</div>
);
}
1 change: 0 additions & 1 deletion apps/demo-nextjs-app-router/app/whisper/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ export default function WhisperDemo() {
file_name: 'recording.wav',
audio_url: audioFile,
},
pollInterval: 1000,
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
1 change: 0 additions & 1 deletion apps/demo-nextjs-page-router/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ export function Index() {
model_name: 'stabilityai/stable-diffusion-xl-base-1.0',
image_size: 'square_hd',
},
pollInterval: 3000, // Default is 1000 (every 1s)
logs: true,
onQueueUpdate(update) {
setElapsedTime(Date.now() - start);
Expand Down
2 changes: 1 addition & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.11.0",
"version": "0.12.0",
"license": "MIT",
"repository": {
"type": "git",
Expand Down
105 changes: 63 additions & 42 deletions libs/client/src/function.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { getTemporaryAuthToken } from './auth';
import { dispatchRequest } from './request';
import { storageImpl } from './storage';
import { EnqueueResult, QueueStatus } from './types';
import { FalStream } from './streaming';
import { EnqueueResult, QueueStatus, RequestLog } from './types';
import { ensureAppIdFormat, isUUIDv4, isValidUrl, parseAppId } from './utils';

/**
Expand Down Expand Up @@ -138,36 +140,22 @@ export async function subscribe<Input, Output>(
if (options.onEnqueue) {
options.onEnqueue(requestId);
}
return new Promise<Output>((resolve, reject) => {
let timeoutId: ReturnType<typeof setTimeout>;
const pollInterval = options.pollInterval ?? 1000;
const poll = async () => {
try {
const requestStatus = await queue.status(id, {
requestId,
logs: options.logs ?? false,
});
if (options.onQueueUpdate) {
options.onQueueUpdate(requestStatus);
}
if (requestStatus.status === 'COMPLETED') {
clearTimeout(timeoutId);
try {
const result = await queue.result<Output>(id, { requestId });
resolve(result);
} catch (error) {
reject(error);
}
return;
}
timeoutId = setTimeout(poll, pollInterval);
} catch (error) {
clearTimeout(timeoutId);
reject(error);
const status = await queue.streamStatus(id, {
requestId,
logs: options.logs,
});
const logs: RequestLog[] = [];
status.on('message', (data: QueueStatus) => {
if (options.onQueueUpdate) {
// accumulate logs to match previous polling behavior
if ('logs' in data && Array.isArray(data.logs) && data.logs.length > 0) {
logs.push(...data.logs);
}
};
poll().catch(reject);
options.onQueueUpdate('logs' in data ? { ...data, logs } : data);
}
});
await status.done();
return queue.result<Output>(id, { requestId });
}

/**
Expand All @@ -177,6 +165,9 @@ type QueueSubscribeOptions = {
/**
* The interval (in milliseconds) at which to poll for updates.
* If not provided, a default value of `1000` will be used.
*
* @deprecated starting from v0.12.0 the queue status is streamed
* using the `queue.subscribeToStatus` method.
*/
pollInterval?: number;

Expand Down Expand Up @@ -239,40 +230,48 @@ interface Queue {
/**
* Submits a request to the queue.
*
* @param id - The ID or URL of the function web endpoint.
* @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run.
* @returns A promise that resolves to the result of enqueuing the request.
*/
submit<Input>(
id: string,
endpointId: string,
options: SubmitOptions<Input>
): Promise<EnqueueResult>;

/**
* Retrieves the status of a specific request in the queue.
*
* @param id - The ID or URL of the function web endpoint.
* @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run.
* @returns A promise that resolves to the status of the request.
*/
status(id: string, options: QueueStatusOptions): Promise<QueueStatus>;
status(endpointId: string, options: QueueStatusOptions): Promise<QueueStatus>;

/**
* Retrieves the result of a specific request from the queue.
*
* @param id - The ID or URL of the function web endpoint.
* @param endpointId - The ID or URL of the function web endpoint.
* @param options - Options to configure how the request is run.
* @returns A promise that resolves to the result of the request.
*/
result<Output>(id: string, options: BaseQueueOptions): Promise<Output>;
result<Output>(
endpointId: string,
options: BaseQueueOptions
): Promise<Output>;

/**
* @deprecated Use `fal.subscribe` instead.
*/
subscribe<Input, Output>(
id: string,
endpointId: string,
options: RunOptions<Input> & QueueSubscribeOptions
): Promise<Output>;

streamStatus(
endpointId: string,
options: QueueStatusOptions
): Promise<FalStream<unknown, QueueStatus>>;
}

/**
Expand All @@ -282,11 +281,11 @@ interface Queue {
*/
export const queue: Queue = {
async submit<Input>(
id: string,
endpointId: string,
options: SubmitOptions<Input>
): Promise<EnqueueResult> {
const { webhookUrl, path = '', ...runOptions } = options;
return send(id, {
return send(endpointId, {
...runOptions,
subdomain: 'queue',
method: 'post',
Expand All @@ -295,10 +294,10 @@ export const queue: Queue = {
});
},
async status(
id: string,
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
): Promise<QueueStatus> {
const appId = parseAppId(id);
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
return send(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
Expand All @@ -309,11 +308,33 @@ export const queue: Queue = {
},
});
},
async streamStatus(
endpointId: string,
{ requestId, logs = false }: QueueStatusOptions
): Promise<FalStream<unknown, QueueStatus>> {
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
const token = await getTemporaryAuthToken(endpointId);
const url = buildUrl(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
path: `/requests/${requestId}/status/stream`,
});

const queryParams = new URLSearchParams({
fal_jwt_token: token,
logs: logs ? '1' : '0',
});

return new FalStream<unknown, QueueStatus>(`${url}?${queryParams}`, {
input: {},
method: 'get',
});
},
async result<Output>(
id: string,
endpointId: string,
{ requestId }: BaseQueueOptions
): Promise<Output> {
const appId = parseAppId(id);
const appId = parseAppId(endpointId);
const prefix = appId.namespace ? `${appId.namespace}/` : '';
return send(`${prefix}${appId.owner}/${appId.alias}`, {
subdomain: 'queue',
Expand Down
Loading

0 comments on commit 4ea43b4

Please sign in to comment.