Skip to content

Commit

Permalink
update chat logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Feb 12, 2024
1 parent f463fe2 commit cd17cf8
Showing 1 changed file with 23 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@
content: string;
}
// TODO: replace with ChatML
interface Conversation {
generated_responses: string[];
past_user_inputs: string[];
}
interface Response {
conversation: Conversation;
interface Response {
generated_text: string;
}
Expand All @@ -41,13 +35,7 @@
}>;
let computeTime = "";
let conversation: {
generated_responses: string[];
past_user_inputs: string[];
} = {
generated_responses: [],
past_user_inputs: [],
};
let chat: Message[] = [];
let error: string = "";
let isLoading = false;
let modelLoading = {
Expand All @@ -69,16 +57,14 @@
return;
}
if (shouldUpdateUrl && !conversation.past_user_inputs.length) {
if (shouldUpdateUrl && !chat.length) {
updateUrl({ text: trimmedText });
}
// const chat = [
// { role: "user", content: "Hello, how are you?" },
// { role: "assistant", content: "I'm doing great. How can I help you today?" },
// { role: "user", content: "I'd like to show off how chat templating works!" },
// ];
const chat: Message[] = []; // TODO
// Add user message to chat
chat = chat.concat([{ role: "user", content: trimmedText }]);
// Render chat template
const chatTemplate = model.config?.tokenizer?.chat_template;
if (chatTemplate === undefined) {
outputJson = "";
Expand All @@ -104,7 +90,7 @@
model.id,
requestBody,
apiToken,
parseOutput,
body => parseOutput(body, chat),
withModelLoading,
includeCredentials,
isOnLoadCall
Expand All @@ -121,7 +107,7 @@
computeTime = res.computeTime;
outputJson = res.outputJson;
if (res.output) {
conversation = res.output.conversation;
chat = res.output.chat;
output = res.output.output;
}
// Emptying input value
Expand All @@ -138,32 +124,27 @@
}
function isValidOutput(arg: any): arg is Response {
return (
arg && Array.isArray(arg?.conversation?.generated_responses) && Array.isArray(arg?.conversation?.past_user_inputs)
);
return typeof(arg?.generated_text) === "string";
}
function parseOutput(body: unknown): {
conversation: Conversation;
function parseOutput(body: unknown, chat: Message[]): {
chat: Message[];
output: Output;
} {
if (isValidOutput(body)) {
const conversation = body.conversation;
const pastUserInputs = conversation.past_user_inputs;
const generatedResponses = conversation.generated_responses;
const output = pastUserInputs
.filter(
(x, i) =>
x !== null && x !== undefined && generatedResponses[i] !== null && generatedResponses[i] !== undefined
)
.map((x, i) => ({
input: x ?? "",
response: generatedResponses[i] ?? "",
}));
return { conversation, output };
const chatWithOutput = chat.concat([{ role: "assistant", content: body.generated_text }]);
const output = chatWithOutput.reduce((acc, message, index) => {
if (index % 2 === 0) {
acc.push({ input: message.content, response: chatWithOutput[index + 1].content });
}
return acc;
}, [] as Output);
return { chat: chatWithOutput, output };
}
throw new TypeError(
"Invalid output: output must be of type <conversation: <generated_responses:Array; past_user_inputs:Array>>"
"Invalid output: output must be of type <generated_text: string>"
);
}
Expand Down

0 comments on commit cd17cf8

Please sign in to comment.