Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(chat-panel): save chat state and reload state when webview reloading #3664

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions clients/tabby-chat-panel/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,22 @@ export interface ClientApiMethods {
* @returns The active selection of active editor.
*/
getActiveEditorSelection: () => Promise<EditorFileContext | null>

/**
* Fetch the saved session state from the client.
* When initialized, the chat panel attempts to fetch the saved session state to restore the session.
* @param keys The keys to be fetched. If not provided, all keys will be returned.
* @return The saved persisted state, or null if no state is found.
*/
fetchSessionState?: (keys?: string[] | undefined) => Promise<Record<string, unknown> | null>

/**
* Save the session state of the chat panel.
* The client is responsible for maintaining the state in case of a webview reload.
* The saved state should be merged and updated by the record key.
* @param state The state to save.
*/
storeSessionState?: (state: Record<string, unknown>) => Promise<void>
}

export interface ClientApi extends ClientApiMethods {
Expand All @@ -303,6 +319,8 @@ export function createClient(target: HTMLIFrameElement, api: ClientApiMethods):
openExternal: api.openExternal,
readWorkspaceGitRepositories: api.readWorkspaceGitRepositories,
getActiveEditorSelection: api.getActiveEditorSelection,
fetchSessionState: api.fetchSessionState,
storeSessionState: api.storeSessionState,
},
})
}
Expand Down
26 changes: 26 additions & 0 deletions ee/tabby-ui/app/chat/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ export default function ChatPage() {
supportsReadWorkspaceGitRepoInfo,
setSupportsReadWorkspaceGitRepoInfo
] = useState(false)
const [
supportsStoreAndFetchSessionState,
setSupportsStoreAndFetchSessionState
] = useState(false)

const executeCommand = (command: ChatCommand) => {
if (chatRef.current) {
Expand Down Expand Up @@ -244,6 +248,14 @@ export default function ChatPage() {
server
?.hasCapability('readWorkspaceGitRepositories')
.then(setSupportsReadWorkspaceGitRepoInfo)
Promise.all([
server?.hasCapability('fetchSessionState'),
server?.hasCapability('storeSessionState')
]).then(results => {
setSupportsStoreAndFetchSessionState(
results.every(result => !!result)
)
})
}

checkCapabilities().then(() => {
Expand Down Expand Up @@ -304,6 +316,14 @@ export default function ChatPage() {
return server?.getActiveEditorSelection() ?? null
}

const fetchSessionState = async () => {
return server?.fetchSessionState?.() ?? null
}

const storeSessionState = async (state: Record<string, any>) => {
return server?.storeSessionState?.(state)
}

const refresh = async () => {
setIsRefreshLoading(true)
await server?.refresh()
Expand Down Expand Up @@ -427,6 +447,12 @@ export default function ChatPage() {
: undefined
}
getActiveEditorSelection={getActiveEditorSelection}
fetchSessionState={
supportsStoreAndFetchSessionState ? fetchSessionState : undefined
}
storeSessionState={
supportsStoreAndFetchSessionState ? storeSessionState : undefined
}
/>
</ErrorBoundary>
)
Expand Down
161 changes: 129 additions & 32 deletions ee/tabby-ui/components/chat/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ interface ChatProps extends React.ComponentProps<'div'> {
supportsOnApplyInEditorV2: boolean
readWorkspaceGitRepositories?: () => Promise<GitRepository[]>
getActiveEditorSelection?: () => Promise<EditorFileContext | null>
fetchSessionState?: () => Promise<SessionState | null>
storeSessionState?: (state: Partial<SessionState>) => Promise<void>
}

/**
* The state used to restore the chat panel, should be json serializable.
* Save this state to client so that the chat panel can be restored across webview reloading.
*/
export interface SessionState {
threadId?: string | undefined
qaPairs?: QuestionAnswerPair[] | undefined
input?: string | undefined
relevantContext?: Context[] | undefined
selectedRepoId?: string | undefined
}

function ChatRenderer(
Expand All @@ -144,7 +158,9 @@ function ChatRenderer(
chatInputRef,
supportsOnApplyInEditorV2,
readWorkspaceGitRepositories,
getActiveEditorSelection
getActiveEditorSelection,
fetchSessionState,
storeSessionState
}: ChatProps,
ref: React.ForwardedRef<ChatRef>
) {
Expand All @@ -158,11 +174,24 @@ function ChatRenderer(
const [activeSelection, setActiveSelection] = React.useState<Context | null>(
null
)

React.useEffect(() => {
if (isDataSetup) {
storeSessionState?.({ input })
}
}, [input, isDataSetup, storeSessionState])

// sourceId
const [selectedRepoId, setSelectedRepoId] = React.useState<
string | undefined
>()

React.useEffect(() => {
if (isDataSetup) {
storeSessionState?.({ selectedRepoId })
}
}, [selectedRepoId, isDataSetup, storeSessionState])

const enableActiveSelection = useChatStore(
state => state.enableActiveSelection
)
Expand Down Expand Up @@ -196,6 +225,9 @@ function ChatRenderer(

const nextQaPairs = qaPairs.filter(o => o.user.id !== userMessageId)
setQaPairs(nextQaPairs)
storeSessionState?.({
qaPairs: nextQaPairs
})

deleteThreadMessagePair(threadId, qaPair?.user.id, qaPair?.assistant?.id)
}
Expand Down Expand Up @@ -226,6 +258,9 @@ function ChatRenderer(
}
]
setQaPairs(nextQaPairs)
storeSessionState?.({
qaPairs: nextQaPairs
})
const [userMessage, threadRunOptions] = generateRequestPayload(
qaPair.user
)
Expand Down Expand Up @@ -254,11 +289,18 @@ function ChatRenderer(
nextClientContext = nextClientContext.concat(userMessage.relevantContext)
}

setRelevantContext(uniqWith(nextClientContext, isEqual))
const updatedRelevantContext = uniqWith(nextClientContext, isEqual)
setRelevantContext(updatedRelevantContext)

// delete message pair
const nextQaPairs = qaPairs.filter(o => o.user.id !== userMessageId)
setQaPairs(nextQaPairs)

storeSessionState?.({
qaPairs: nextQaPairs,
relevantContext: updatedRelevantContext
})

setInput(userMessage.message)
if (userMessage.activeContext) {
openInEditor(getFileLocationFromContext(userMessage.activeContext))
Expand All @@ -282,6 +324,10 @@ function ChatRenderer(
stop(true)
setQaPairs([])
setThreadId(undefined)
storeSessionState?.({
qaPairs: [],
threadId: undefined
})
}

const handleMessageAction = (
Expand Down Expand Up @@ -311,6 +357,9 @@ function ChatRenderer(
// update threadId
if (answer.threadId && !threadId) {
setThreadId(answer.threadId)
storeSessionState?.({
threadId: answer.threadId
})
}

setQaPairs(prev => {
Expand All @@ -334,21 +383,28 @@ function ChatRenderer(
}
]
})
}, [answer, isLoading])

const scrollToBottom = useDebounceCallback(() => {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
})
} else {
window.scrollTo({
top: document.body.offsetHeight,
behavior: 'smooth'
})
if (!isLoading) {
storeSessionState?.({ qaPairs })
}
}, 100)
}, [answer, isLoading])

const scrollToBottom = useDebounceCallback(
(behavior: ScrollBehavior = 'smooth') => {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior
})
} else {
window.scrollTo({
top: document.body.offsetHeight,
behavior
})
}
},
100
)

React.useLayoutEffect(() => {
// scroll to bottom when a request is sent
Expand All @@ -361,7 +417,7 @@ function ChatRenderer(
if (error && qaPairs?.length) {
setQaPairs(prev => {
let lastQaPairs = prev[prev.length - 1]
return [
const nextQaPairs = [
...prev.slice(0, prev.length - 1),
{
...lastQaPairs,
Expand All @@ -373,6 +429,10 @@ function ChatRenderer(
}
}
]
storeSessionState?.({
qaPairs: nextQaPairs
})
return nextQaPairs
})
}

Expand Down Expand Up @@ -467,6 +527,9 @@ function ChatRenderer(
]

setQaPairs(nextQaPairs)
storeSessionState?.({
qaPairs: nextQaPairs
})

sendUserMessage(...generateRequestPayload(newUserMessage))
}
Expand Down Expand Up @@ -494,10 +557,19 @@ function ChatRenderer(
relevantContext: relevantContext
})
setRelevantContext([])
storeSessionState?.({
relevantContext: []
})
}

const handleAddRelevantContext = useLatest((context: Context) => {
setRelevantContext(oldValue => appendContextAndDedupe(oldValue, context))
setRelevantContext(oldValue => {
const updatedValue = appendContextAndDedupe(oldValue, context)
storeSessionState?.({
relevantContext: updatedValue
})
return updatedValue
})
})

const addRelevantContext = (editorContext: EditorContext) => {
Expand All @@ -509,6 +581,9 @@ function ChatRenderer(
const newRelevantContext = [...relevantContext]
newRelevantContext.splice(index, 1)
setRelevantContext(newRelevantContext)
storeSessionState?.({
relevantContext: newRelevantContext
})
}

React.useEffect(() => {
Expand Down Expand Up @@ -542,26 +617,48 @@ function ChatRenderer(

React.useEffect(() => {
const init = async () => {
const [workspaceGitRepositories, activeEditorSelecition] =
await Promise.all([
fetchWorkspaceGitRepo(),
initActiveEditorSelection()
])
const [persistedState, activeEditorSelection] = await Promise.all([
fetchSessionState?.(),
initActiveEditorSelection()
])

if (persistedState?.threadId) {
setThreadId(persistedState.threadId)
}
if (persistedState?.qaPairs) {
setQaPairs(persistedState.qaPairs)
}
if (persistedState?.input) {
setInput(persistedState.input)
}
if (persistedState?.relevantContext) {
setRelevantContext(persistedState.relevantContext)
}
scrollToBottom.run('instant')

// get default repository
if (workspaceGitRepositories?.length && repos?.length) {
const defaultGitUrl = workspaceGitRepositories[0].url
const repo = findClosestGitRepository(
repos.map(x => ({ url: x.gitUrl, sourceId: x.sourceId })),
defaultGitUrl
)
if (repo) {
setSelectedRepoId(repo.sourceId)
if (
persistedState?.selectedRepoId &&
repos?.find(x => x.sourceId === persistedState.selectedRepoId)
) {
setSelectedRepoId(persistedState.selectedRepoId)
} else {
const workspaceGitRepositories = await fetchWorkspaceGitRepo()
if (workspaceGitRepositories?.length && repos?.length) {
const defaultGitUrl = workspaceGitRepositories[0].url
const repo = findClosestGitRepository(
repos.map(x => ({ url: x.gitUrl, sourceId: x.sourceId })),
defaultGitUrl
)
if (repo) {
setSelectedRepoId(repo.sourceId)
}
}
}

// update active selection
if (activeEditorSelecition) {
const context = convertEditorContext(activeEditorSelecition)
if (activeEditorSelection) {
const context = convertEditorContext(activeEditorSelection)
setActiveSelection(context)
}
}
Expand Down
Loading