From 9b2c5dcf35f059f804ca4b6d52f50d02bdb56237 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Fri, 27 Sep 2024 04:09:02 +0800 Subject: [PATCH 1/8] Create callee definitions context --- .../kotlin/ai/devchat/common/Constants.kt | 21 +++ src/main/kotlin/ai/devchat/common/IDEUtils.kt | 125 ++++++++++++++++++ .../kotlin/ai/devchat/plugin/IDEServer.kt | 26 +--- .../devchat/plugin/completion/agent/Agent.kt | 14 +- .../plugin/completion/agent/AgentService.kt | 7 +- .../plugin/completion/agent/ContextBuilder.kt | 54 +++++++- .../plugin/hints/ChatCVProviderBase.kt | 16 +-- 7 files changed, 210 insertions(+), 53 deletions(-) create mode 100644 src/main/kotlin/ai/devchat/common/IDEUtils.kt diff --git a/src/main/kotlin/ai/devchat/common/Constants.kt b/src/main/kotlin/ai/devchat/common/Constants.kt index 1c7836f..2cd4701 100644 --- a/src/main/kotlin/ai/devchat/common/Constants.kt +++ b/src/main/kotlin/ai/devchat/common/Constants.kt @@ -3,4 +3,25 @@ package ai.devchat.common object Constants { val ASSISTANT_NAME_ZH = DevChatBundle.message("assistant.name.zh") val ASSISTANT_NAME_EN = DevChatBundle.message("assistant.name.en") + val FUNC_TYPE_NAMES: Set = setOf( + "FUN", // Kotlin + "METHOD", // Java + "FUNCTION_DEFINITION", // C, C++ + "Py:FUNCTION_DECLARATION", // Python + "FUNCTION_DECLARATION", "METHOD_DECLARATION", // Golang + "JS:FUNCTION_DECLARATION", "JS:FUNCTION_EXPRESSION", // JS + "JS:TYPESCRIPT_FUNCTION", "JS:TYPESCRIPT_FUNCTION_EXPRESSION", // TS + "CLASS_METHOD", // PHP + "FUNCTION", // PHP, Rust + "Ruby:METHOD", // Ruby + ) + val CALL_EXPRESSION_ELEMENT_TYPE_NAMES: Set = setOf( + "CALL_EXPRESSION", // Kotlin, C, C++, Python + "METHOD_CALL_EXPRESSION", // Java + "CALL_EXPR", // Go, Rust + "JS_CALL_EXPRESSION", // JS + "TS_CALL_EXPRESSION", // TS + "PHP_METHOD_REFERENCE", // PHP + "CALL", // Ruby + ) } \ No newline at end of file diff --git a/src/main/kotlin/ai/devchat/common/IDEUtils.kt b/src/main/kotlin/ai/devchat/common/IDEUtils.kt new file mode 100644 index 0000000..5c1ade8 --- /dev/null +++ b/src/main/kotlin/ai/devchat/common/IDEUtils.kt @@ -0,0 +1,125 @@ +package ai.devchat.common + +import com.intellij.lang.folding.LanguageFolding +import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.roots.ProjectFileIndex +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiFile +import com.intellij.psi.PsiPolyVariantReference +import com.intellij.psi.util.elementType +import com.intellij.psi.util.findParentInFile +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CountDownLatch + + +object IDEUtils { + fun runInEdtAndGet(block: () -> T): T { + val app = ApplicationManager.getApplication() + if (app.isDispatchThread) { + return block() + } + val future = CompletableFuture() + val latch = CountDownLatch(1) + app.invokeLater { + try { + val result = block() + future.complete(result) + } catch (e: Exception) { + future.completeExceptionally(e) + } finally { + latch.countDown() + } + } + latch.await() + return future.get() + } + + fun findCalleeInParent(element: PsiElement?): List? { + if (element == null) return null + Log.info("Find callee in parent: ${element.elementType}: ${element.text.replace("\n", "\\n")}") + val nearestCallExpression = element.findParentInFile(withSelf = true) { + if (it is PsiFile) false else { + it.elementType.toString() in Constants.CALL_EXPRESSION_ELEMENT_TYPE_NAMES + } + } + + if (nearestCallExpression == null) return null + + Log.info("Nearest call expression: ${nearestCallExpression.elementType}: ${nearestCallExpression.text.replace("\n", "\\n")}") + + val projectFileIndex = ProjectFileIndex.getInstance(element.project) + val callee = nearestCallExpression.children.asSequence() + .mapNotNull {child -> + child.reference?.let{ref -> + if (ref is PsiPolyVariantReference) { + ref.multiResolve(false).mapNotNull { it.element } + } else listOfNotNull(ref.resolve()) + }?.filter { + val containingFile = it.containingFile?.virtualFile + containingFile != null && projectFileIndex.isInContent(containingFile) + } + } + .firstOrNull {it.isNotEmpty()} + + if (callee == null) { + Log.info("Callee not found") + } else { + Log.info("Callee: $callee") + } + + return callee ?: findCalleeInParent(nearestCallExpression.parent) + } + + fun PsiElement.findCalleeInParent(): Sequence> { + val projectFileIndex = ProjectFileIndex.getInstance(this.project) + return generateSequence(this) { it.parent } + .takeWhile { it !is PsiFile } + .filter { it.elementType.toString() in Constants.CALL_EXPRESSION_ELEMENT_TYPE_NAMES } + .mapNotNull { callExpression -> + Log.info("Call expression: ${callExpression.elementType}: ${callExpression.text}") + + callExpression.children + .asSequence() + .mapNotNull { child -> + child.reference?.let { ref -> + when (ref) { + is PsiPolyVariantReference -> ref.multiResolve(false).mapNotNull { it.element } + else -> listOfNotNull(ref.resolve()) + }.filter { resolved -> + resolved.containingFile.virtualFile?.let { file -> + projectFileIndex.isInContent(file) + } == true + } + } + } + .firstOrNull { it.isNotEmpty() } + } + } + + fun PsiElement.getFoldedText(): String { + val file = this.containingFile + val document = file.viewProvider.document ?: return text + + val foldingBuilder = LanguageFolding.INSTANCE.forLanguage(this.language) ?: return text + val descriptors = foldingBuilder.buildFoldRegions(file.node, document) + + // Find the largest folding descriptor that is contained within the element's range + val bodyDescriptor = descriptors + .filter { + textRange.contains(it.range) + && it.element.textRange.startOffset > textRange.startOffset // Exclude the function itself + } + .sortedByDescending { it.range.length } + .getOrNull(0) + ?: return text + + val bodyStart = bodyDescriptor.range.startOffset - textRange.startOffset + val bodyEnd = bodyDescriptor.range.endOffset - textRange.startOffset + + return buildString { + append(text.substring(0, bodyStart)) + append(bodyDescriptor.placeholderText) + append(text.substring(bodyEnd)) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/ai/devchat/plugin/IDEServer.kt b/src/main/kotlin/ai/devchat/plugin/IDEServer.kt index fc6f6cf..2e0e3a2 100644 --- a/src/main/kotlin/ai/devchat/plugin/IDEServer.kt +++ b/src/main/kotlin/ai/devchat/plugin/IDEServer.kt @@ -1,5 +1,6 @@ package ai.devchat.plugin +import ai.devchat.common.IDEUtils.runInEdtAndGet import ai.devchat.common.Log import ai.devchat.common.Notifier import ai.devchat.common.PathUtils @@ -47,8 +48,6 @@ import kotlinx.serialization.Serializable import java.awt.Point import java.io.File import java.net.ServerSocket -import java.util.concurrent.CompletableFuture -import java.util.concurrent.CountDownLatch import kotlin.reflect.full.memberFunctions @@ -446,29 +445,6 @@ fun Editor.diffWith(newText: String, autoEdit: Boolean) { } } -fun runInEdtAndGet(block: () -> T): T { - val app = ApplicationManager.getApplication() - if (app.isDispatchThread) { - return block() - } - val future = CompletableFuture() - val latch = CountDownLatch(1) - app.invokeLater { - try { - val result = block() - future.complete(result) - } catch (e: Exception) { - future.completeExceptionally(e) - } finally { - latch.countDown() - } - } - latch.await() - return future.get() -} - - - fun Project.getPsiFile(filePath: String): PsiFile = runInEdtAndGet { ReadAction.compute { val virtualFile = LocalFileSystem.getInstance().findFileByIoFile(File(filePath)) diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt index 7fa40e5..2934dc9 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt @@ -4,6 +4,7 @@ import ai.devchat.storage.CONFIG import com.google.gson.Gson import com.google.gson.annotations.SerializedName import com.intellij.openapi.diagnostic.Logger +import com.intellij.psi.PsiFile import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.* import kotlinx.coroutines.launch @@ -47,9 +48,8 @@ class Agent(val scope: CoroutineScope) { } data class CompletionRequest( - val filepath: String, + val file: PsiFile, val language: String, - val text: String, val position: Int, val manually: Boolean?, ) @@ -100,8 +100,9 @@ class Agent(val scope: CoroutineScope) { ) { companion object { fun fromCompletionRequest(completionRequest: CompletionRequest): RequestInfo { - val upperPart = completionRequest.text.substring(0, completionRequest.position) - val lowerPart = completionRequest.text.substring(completionRequest.position) + val fileContent = completionRequest.file.text + val upperPart = fileContent.substring(0, completionRequest.position) + val lowerPart = fileContent.substring(completionRequest.position) val currentLinePrefix = upperPart.substringAfterLast(LINE_SEPARATOR, upperPart) val currentLineSuffix = lowerPart.lineSequence().firstOrNull()?.second ?: "" val currentIndent = currentLinePrefix.takeWhile { it.isWhitespace() }.length @@ -112,7 +113,7 @@ class Agent(val scope: CoroutineScope) { i > 0 && v.second.trim().isNotEmpty() }?.value?.second return RequestInfo( - filepath = completionRequest.filepath, + filepath = completionRequest.file.virtualFile.path, language = completionRequest.language, upperPart = upperPart, lowerPart = lowerPart, @@ -277,8 +278,7 @@ class Agent(val scope: CoroutineScope) { val model = CONFIG["complete_model"] as? String var startTime = System.currentTimeMillis() val prompt = ContextBuilder( - completionRequest.filepath, - completionRequest.text, + completionRequest.file, completionRequest.position ).createPrompt(model) val promptBuildingElapse = System.currentTimeMillis() - startTime diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt index ae1992b..45b6fc5 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt @@ -1,6 +1,5 @@ package ai.devchat.plugin.completion.agent -import ai.devchat.storage.CONFIG import com.intellij.lang.Language import com.intellij.openapi.Disposable import com.intellij.openapi.application.ReadAction @@ -9,7 +8,8 @@ import com.intellij.openapi.editor.Editor import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiFile import io.ktor.util.* -import kotlinx.coroutines.* +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers @Service class AgentService : Disposable { @@ -24,9 +24,8 @@ class AgentService : Disposable { }?.let { file -> agent.provideCompletions( Agent.CompletionRequest( - file.virtualFile.path, + file, file.getLanguageId(), - editor.document.text, offset, manually, ) diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index b38d6ac..7042010 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -1,8 +1,19 @@ package ai.devchat.plugin.completion.agent +import ai.devchat.common.IDEUtils.findCalleeInParent +import ai.devchat.common.IDEUtils.getFoldedText +import ai.devchat.common.IDEUtils.runInEdtAndGet +import ai.devchat.common.Log +import com.intellij.psi.PsiFile + const val MAX_CONTEXT_TOKENS = 6000 const val LINE_SEPARATOR = '\n' +data class CodeSnippet( + val filepath: String, + val content: String, +) + fun String.tokenCount(): Int { var count = 0 @@ -52,7 +63,13 @@ fun String.lineSequenceReversed(offset: Int? = null) = sequence { } } -class ContextBuilder(val filepath: String, val content: String, val offset: Int) { +class ContextBuilder(val file: PsiFile, val offset: Int) { + val filepath: String = file.virtualFile.path + val content: String = file.text + // TODO: get comment prefix for different languages + private val commentPrefix: String = "//" + private var tokenCount: Int = 0 + private fun buildFileContext(): Pair { val maxTokens = MAX_CONTEXT_TOKENS * 0.35 @@ -64,6 +81,7 @@ class ContextBuilder(val filepath: String, val content: String, val offset: Int) prefixTokens += numTokens true }.lastOrNull()?.first?.first ?: 0 + tokenCount += prefixTokens val maxSuffixTokens = maxTokens - prefixTokens var suffixTokens = 0 @@ -73,6 +91,7 @@ class ContextBuilder(val filepath: String, val content: String, val offset: Int) suffixTokens += numTokens true }.lastOrNull()?.first?.last ?: content.length + tokenCount += suffixTokens return Pair( content.substring(prefixStart, offset), @@ -80,11 +99,40 @@ class ContextBuilder(val filepath: String, val content: String, val offset: Int) ) } + private fun buildCalleeDefinitionsContext(): String { + fun checkAndUpdateTokenCount(snippet: CodeSnippet): Boolean { + val newCount = tokenCount + snippet.content.tokenCount() + return (newCount <= MAX_CONTEXT_TOKENS).also { if (it) tokenCount = newCount } + } + + return runInEdtAndGet { + file.findElementAt(offset) + ?.findCalleeInParent() + ?.flatMap { elements -> elements.filter { it.containingFile.virtualFile.path != filepath } } + ?.map { CodeSnippet(it.containingFile.virtualFile.path, it.getFoldedText()) } + ?.takeWhile(::checkAndUpdateTokenCount) + ?.joinToString(separator = "") { + "$commentPrefixcall function define:\n\n${it.filepath}\n\n${it.content}\n\n\n\n" + } ?: "" + } + } + + fun createPrompt(model: String?): String { val (prefix, suffix) = buildFileContext() + val extras: String = listOf( +// taskDescriptionContextWithCommentPrefix, +// neighborFileContext, +// recentEditContext, +// symbolContext, + buildCalleeDefinitionsContext() +// similarBlockContext, +// gitDiffContext, + ).joinToString("") + Log.info("Extras completion context:\n$extras") return if (!model.isNullOrEmpty() && model.contains("deepseek")) - "<|fim▁begin|>$filepath\n\n$prefix<|fim▁hole|>$suffix<|fim▁end|>" + "<|fim▁begin|>$extras$filepath\n\n$prefix<|fim▁hole|>$suffix<|fim▁end|>" else - "$filepath\n\n$prefix$suffix" + "$extras$filepath\n\n$prefix$suffix" } } \ No newline at end of file diff --git a/src/main/kotlin/ai/devchat/plugin/hints/ChatCVProviderBase.kt b/src/main/kotlin/ai/devchat/plugin/hints/ChatCVProviderBase.kt index d0a226e..627b858 100644 --- a/src/main/kotlin/ai/devchat/plugin/hints/ChatCVProviderBase.kt +++ b/src/main/kotlin/ai/devchat/plugin/hints/ChatCVProviderBase.kt @@ -1,6 +1,7 @@ package ai.devchat.plugin.hints import ai.devchat.common.Constants.ASSISTANT_NAME_ZH +import ai.devchat.common.Constants.FUNC_TYPE_NAMES import ai.devchat.core.DevChatActions import ai.devchat.core.handlers.SendUserMessageHandler import ai.devchat.plugin.DevChatService @@ -85,17 +86,4 @@ abstract class ChatCVProviderBase : CodeVisionProviderBase() { handleClick(editor, element, event) } } -} - -internal val FUNC_TYPE_NAMES: Set = setOf( - "FUN", // Kotlin - "METHOD", // Java - "FUNCTION_DEFINITION", // C, C++ - "Py:FUNCTION_DECLARATION", // Python - "FUNCTION_DECLARATION", "METHOD_DECLARATION", // Golang - "JS:FUNCTION_DECLARATION", "JS:FUNCTION_EXPRESSION", // JS - "JS:TYPESCRIPT_FUNCTION", "JS:TYPESCRIPT_FUNCTION_EXPRESSION", // TS - "CLASS_METHOD", // PHP - "FUNCTION", // PHP, Rust - "Ruby:METHOD", // Ruby -) \ No newline at end of file +} \ No newline at end of file From 125099dfa52005c56184ca1c9ecf36d1f7a60a34 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 00:15:58 +0800 Subject: [PATCH 2/8] Create symbol type definition context --- src/main/kotlin/ai/devchat/common/IDEUtils.kt | 111 +++++++++++++++--- .../plugin/completion/agent/ContextBuilder.kt | 58 ++++++--- 2 files changed, 135 insertions(+), 34 deletions(-) diff --git a/src/main/kotlin/ai/devchat/common/IDEUtils.kt b/src/main/kotlin/ai/devchat/common/IDEUtils.kt index 5c1ade8..0dca447 100644 --- a/src/main/kotlin/ai/devchat/common/IDEUtils.kt +++ b/src/main/kotlin/ai/devchat/common/IDEUtils.kt @@ -1,13 +1,20 @@ package ai.devchat.common +import com.intellij.codeInsight.navigation.actions.TypeDeclarationProvider +import com.intellij.lang.folding.FoldingDescriptor import com.intellij.lang.folding.LanguageFolding import com.intellij.openapi.application.ApplicationManager +import com.intellij.openapi.application.ReadAction import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiNameIdentifierOwner import com.intellij.psi.PsiPolyVariantReference import com.intellij.psi.util.elementType import com.intellij.psi.util.findParentInFile +import com.intellij.refactoring.suggested.startOffset +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking import java.util.concurrent.CompletableFuture import java.util.concurrent.CountDownLatch @@ -82,10 +89,12 @@ object IDEUtils { .asSequence() .mapNotNull { child -> child.reference?.let { ref -> - when (ref) { - is PsiPolyVariantReference -> ref.multiResolve(false).mapNotNull { it.element } - else -> listOfNotNull(ref.resolve()) - }.filter { resolved -> + if (ref is PsiPolyVariantReference) { + ref.multiResolve(false).mapNotNull { it.element } + } else { + listOfNotNull(ref.resolve()) + } + .filter { resolved -> resolved.containingFile.virtualFile?.let { file -> projectFileIndex.isInContent(file) } == true @@ -96,30 +105,94 @@ object IDEUtils { } } - fun PsiElement.getFoldedText(): String { + + private fun PsiElement.getTypeDeclaration(): PsiElement? = runBlocking(Dispatchers.IO) { + ReadAction.compute { + TypeDeclarationProvider.EP_NAME.extensionList.asSequence() + .mapNotNull { provider -> + provider.getSymbolTypeDeclarations(this@getTypeDeclaration)?.firstOrNull() + } + .firstOrNull() + } + } + + data class CodeNode( + val element: PsiElement, + val isProjectContent: Boolean, + ) + data class SymbolTypeDeclaration( + val symbol: PsiNameIdentifierOwner, + val typeDeclaration: CodeNode + ) + + fun PsiElement.findAccessibleVariables(): Sequence { + val projectFileIndex = ProjectFileIndex.getInstance(this.project) + return generateSequence(this.parent) { it.parent } + .takeWhile { it !is PsiFile } + .flatMap { it.children.asSequence().filterIsInstance() } + .plus(this.containingFile.children.asSequence().filterIsInstance()) + .filter { !it.name.isNullOrEmpty() && it.nameIdentifier != null } + .mapNotNull { + val typeDeclaration = it.getTypeDeclaration() ?: return@mapNotNull null + val virtualFile = typeDeclaration.containingFile.virtualFile ?: return@mapNotNull null + val isProjectContent = projectFileIndex.isInContent(virtualFile) + SymbolTypeDeclaration(it, CodeNode(typeDeclaration, isProjectContent)) + } + } + + fun PsiElement.foldTextOfLevel(foldingLevel: Int = 1): String { val file = this.containingFile val document = file.viewProvider.document ?: return text + val fileNode = file.node ?: return text val foldingBuilder = LanguageFolding.INSTANCE.forLanguage(this.language) ?: return text - val descriptors = foldingBuilder.buildFoldRegions(file.node, document) - - // Find the largest folding descriptor that is contained within the element's range - val bodyDescriptor = descriptors + val descriptors = foldingBuilder.buildFoldRegions(fileNode, document) .filter { textRange.contains(it.range) - && it.element.textRange.startOffset > textRange.startOffset // Exclude the function itself +// && it.element.textRange.startOffset > textRange.startOffset // Exclude the function itself + } + .sortedBy { it.range.startOffset } + .let { + findDescriptorsOfFoldingLevel(it, foldingLevel) } - .sortedByDescending { it.range.length } - .getOrNull(0) - ?: return text + return foldTextByDescriptors(descriptors) + } - val bodyStart = bodyDescriptor.range.startOffset - textRange.startOffset - val bodyEnd = bodyDescriptor.range.endOffset - textRange.startOffset + private fun findDescriptorsOfFoldingLevel( + descriptors: List, + foldingLevel: Int + ): List { + val nestedDescriptors = mutableListOf() + val stack = mutableListOf() - return buildString { - append(text.substring(0, bodyStart)) - append(bodyDescriptor.placeholderText) - append(text.substring(bodyEnd)) + for (descriptor in descriptors.sortedBy { it.range.startOffset }) { + while (stack.isNotEmpty() && !stack.last().range.contains(descriptor.range)) { + stack.removeAt(stack.size - 1) + } + stack.add(descriptor) + if (stack.size == foldingLevel) { + nestedDescriptors.add(descriptor) + } } + + return nestedDescriptors + } + + private fun PsiElement.foldTextByDescriptors(descriptors: List): String { + val sortedDescriptors = descriptors.sortedBy { it.range.startOffset } + val builder = StringBuilder() + var currentIndex = 0 + + for (descriptor in sortedDescriptors) { + val range = descriptor.range.shiftRight(-startOffset) + if (range.startOffset >= currentIndex) { + builder.append(text, currentIndex, range.startOffset) + builder.append(descriptor.placeholderText) + currentIndex = range.endOffset + } + } + builder.append(text.substring(currentIndex)) + + return builder.toString() } } \ No newline at end of file diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index 7042010..e9f7957 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -1,7 +1,8 @@ package ai.devchat.plugin.completion.agent +import ai.devchat.common.IDEUtils.findAccessibleVariables import ai.devchat.common.IDEUtils.findCalleeInParent -import ai.devchat.common.IDEUtils.getFoldedText +import ai.devchat.common.IDEUtils.foldTextOfLevel import ai.devchat.common.IDEUtils.runInEdtAndGet import ai.devchat.common.Log import com.intellij.psi.PsiFile @@ -9,11 +10,6 @@ import com.intellij.psi.PsiFile const val MAX_CONTEXT_TOKENS = 6000 const val LINE_SEPARATOR = '\n' -data class CodeSnippet( - val filepath: String, - val content: String, -) - fun String.tokenCount(): Int { var count = 0 @@ -63,6 +59,11 @@ fun String.lineSequenceReversed(offset: Int? = null) = sequence { } } +data class CodeSnippet ( + val filepath: String, + val content: String +) + class ContextBuilder(val file: PsiFile, val offset: Int) { val filepath: String = file.virtualFile.path val content: String = file.text @@ -99,17 +100,17 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { ) } - private fun buildCalleeDefinitionsContext(): String { - fun checkAndUpdateTokenCount(snippet: CodeSnippet): Boolean { - val newCount = tokenCount + snippet.content.tokenCount() - return (newCount <= MAX_CONTEXT_TOKENS).also { if (it) tokenCount = newCount } - } + private fun checkAndUpdateTokenCount(snippet: CodeSnippet): Boolean { + val newCount = tokenCount + snippet.content.tokenCount() + return (newCount <= MAX_CONTEXT_TOKENS).also { if (it) tokenCount = newCount } + } + private fun buildCalleeDefinitionsContext(): String { return runInEdtAndGet { file.findElementAt(offset) ?.findCalleeInParent() ?.flatMap { elements -> elements.filter { it.containingFile.virtualFile.path != filepath } } - ?.map { CodeSnippet(it.containingFile.virtualFile.path, it.getFoldedText()) } + ?.map { CodeSnippet(it.containingFile.virtualFile.path, it.foldTextOfLevel(1)) } ?.takeWhile(::checkAndUpdateTokenCount) ?.joinToString(separator = "") { "$commentPrefixcall function define:\n\n${it.filepath}\n\n${it.content}\n\n\n\n" @@ -117,15 +118,42 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { } } + private fun buildSymbolsContext(): String { + return runInEdtAndGet { + file.findElementAt(offset) + ?.findAccessibleVariables() + ?.filter { it.typeDeclaration.element.containingFile.virtualFile.path != filepath } + ?.map { + val typeElement = it.typeDeclaration.element + it.symbol.name to CodeSnippet( + typeElement.containingFile.virtualFile.path, + if (it.typeDeclaration.isProjectContent) { + typeElement.foldTextOfLevel(2) + } else { + typeElement.text.lines().first() + "..." + } + ) + } + ?.takeWhile { checkAndUpdateTokenCount(it.second) } + ?.joinToString(separator = "") {(name, snippet) -> + val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { + "$commentPrefix $it" + } + "$commentPrefix Symbol type definition:\n\n" + + "$commentPrefix ${name}\n\n" + + "$commentPrefix ${snippet.filepath}\n\n" + + "$commentPrefix \n$commentedContent\n\n\n\n" + } ?: "" + } + } fun createPrompt(model: String?): String { val (prefix, suffix) = buildFileContext() val extras: String = listOf( // taskDescriptionContextWithCommentPrefix, // neighborFileContext, -// recentEditContext, -// symbolContext, - buildCalleeDefinitionsContext() + buildCalleeDefinitionsContext(), + buildSymbolsContext(), // similarBlockContext, // gitDiffContext, ).joinToString("") From 0e0844ebe4e1f58e0965916455d18944b407b905 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 00:21:03 +0800 Subject: [PATCH 3/8] Create recently open files context --- .../plugin/completion/agent/ContextBuilder.kt | 21 ++++++++ .../ai/devchat/storage/RecentFilesTracker.kt | 50 +++++++++++++++++++ src/main/resources/META-INF/plugin.xml | 1 + 3 files changed, 72 insertions(+) create mode 100644 src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index e9f7957..956fe83 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -5,7 +5,10 @@ import ai.devchat.common.IDEUtils.findCalleeInParent import ai.devchat.common.IDEUtils.foldTextOfLevel import ai.devchat.common.IDEUtils.runInEdtAndGet import ai.devchat.common.Log +import ai.devchat.storage.RecentFilesTracker +import com.intellij.openapi.vfs.isFile import com.intellij.psi.PsiFile +import com.intellij.psi.util.PsiUtilCore.getPsiFile const val MAX_CONTEXT_TOKENS = 6000 const val LINE_SEPARATOR = '\n' @@ -147,6 +150,23 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { } } + private fun buildRecentFilesContext(): String { + val project = file.project + return project.getService(RecentFilesTracker::class.java).getRecentFiles().asSequence() + .filter { it.isFile && it.path != filepath } + .map { CodeSnippet(it.path, getPsiFile(project, it).foldTextOfLevel(3)) } + .filter { it.content.lines().count(String::isBlank) <= 50 } + .takeWhile(::checkAndUpdateTokenCount) + .joinToString(separator = "") {snippet -> + val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { + "$commentPrefix $it" + } + "$commentPrefix Recently open file:\n\n" + + "$commentPrefix ${snippet.filepath}\n\n" + + "$commentedContent\n\n\n\n" + } + } + fun createPrompt(model: String?): String { val (prefix, suffix) = buildFileContext() val extras: String = listOf( @@ -154,6 +174,7 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { // neighborFileContext, buildCalleeDefinitionsContext(), buildSymbolsContext(), + buildRecentFilesContext(), // similarBlockContext, // gitDiffContext, ).joinToString("") diff --git a/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt b/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt new file mode 100644 index 0000000..80488a3 --- /dev/null +++ b/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt @@ -0,0 +1,50 @@ +package ai.devchat.storage + +import ai.devchat.common.Log +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.fileEditor.FileEditorManager +import com.intellij.openapi.fileEditor.FileEditorManagerListener +import com.intellij.openapi.project.Project +import com.intellij.openapi.startup.ProjectActivity +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.openapi.vfs.isFile +import com.intellij.util.messages.MessageBusConnection + + +@Service(Service.Level.PROJECT) +class RecentFilesTracker(private val project: Project) { + private val maxSize = 3 + + private val recentFiles: MutableList = mutableListOf() + + init { + Log.info("RecentFilesTracker initialized for project: ${project.name}") + val connection: MessageBusConnection = project.messageBus.connect() + connection.subscribe(FileEditorManagerListener.FILE_EDITOR_MANAGER, object : FileEditorManagerListener { + override fun fileOpened(source: FileEditorManager, file: VirtualFile) { + if (file.isFile) { + addRecentFile(file) + } + } + }) + } + + private fun addRecentFile(file: VirtualFile) { + recentFiles.remove(file) + recentFiles.add(0, file) + if (recentFiles.size > maxSize) { + recentFiles.removeAt(recentFiles.size - 1) + } + } + + fun getRecentFiles(): List { + return recentFiles.toList() + } +} + +class RecentFilesStartupActivity : ProjectActivity { + override suspend fun execute(project: Project) { + project.service() + } +} diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 7bcf066..2954512 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -36,6 +36,7 @@ + From 4059bb9f4c03d105d0c5a90de5ba89b056ace1c9 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 01:38:34 +0800 Subject: [PATCH 4/8] Filter out files not in project & init with open files --- .../ai/devchat/storage/RecentFilesTracker.kt | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt b/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt index 80488a3..012df94 100644 --- a/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt +++ b/src/main/kotlin/ai/devchat/storage/RecentFilesTracker.kt @@ -1,11 +1,13 @@ package ai.devchat.storage import ai.devchat.common.Log +import com.intellij.openapi.application.runInEdt import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.fileEditor.FileEditorManager import com.intellij.openapi.fileEditor.FileEditorManagerListener import com.intellij.openapi.project.Project +import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.startup.ProjectActivity import com.intellij.openapi.vfs.VirtualFile import com.intellij.openapi.vfs.isFile @@ -14,9 +16,10 @@ import com.intellij.util.messages.MessageBusConnection @Service(Service.Level.PROJECT) class RecentFilesTracker(private val project: Project) { - private val maxSize = 3 + private val maxSize = 10 private val recentFiles: MutableList = mutableListOf() + private val projectFileIndex = ProjectFileIndex.getInstance(this.project) init { Log.info("RecentFilesTracker initialized for project: ${project.name}") @@ -28,13 +31,18 @@ class RecentFilesTracker(private val project: Project) { } } }) + + // Init with open files + FileEditorManager.getInstance(project).openFiles.forEach { addRecentFile(it) } } - private fun addRecentFile(file: VirtualFile) { - recentFiles.remove(file) - recentFiles.add(0, file) - if (recentFiles.size > maxSize) { - recentFiles.removeAt(recentFiles.size - 1) + private fun addRecentFile(file: VirtualFile) = runInEdt { + if (file.isFile && projectFileIndex.isInContent(file)) { + recentFiles.remove(file) + recentFiles.add(0, file) + if (recentFiles.size > maxSize) { + recentFiles.removeAt(recentFiles.size - 1) + } } } From 3772e4e8002947fa4cf05e4163ea78a4a6e0a928 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 01:39:49 +0800 Subject: [PATCH 5/8] Fix EDT error --- .../plugin/completion/agent/ContextBuilder.kt | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index 956fe83..cec5481 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -152,19 +152,21 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { private fun buildRecentFilesContext(): String { val project = file.project - return project.getService(RecentFilesTracker::class.java).getRecentFiles().asSequence() - .filter { it.isFile && it.path != filepath } - .map { CodeSnippet(it.path, getPsiFile(project, it).foldTextOfLevel(3)) } - .filter { it.content.lines().count(String::isBlank) <= 50 } - .takeWhile(::checkAndUpdateTokenCount) - .joinToString(separator = "") {snippet -> - val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { - "$commentPrefix $it" + return runInEdtAndGet { + project.getService(RecentFilesTracker::class.java).getRecentFiles().asSequence() + .filter { it.isFile && it.path != filepath } + .map { CodeSnippet(it.path, getPsiFile(project, it).foldTextOfLevel(2)) } + .filter { it.content.lines().count(String::isBlank) <= 50 } + .takeWhile(::checkAndUpdateTokenCount) + .joinToString(separator = "") {snippet -> + val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { + "$commentPrefix $it" + } + "$commentPrefix Recently open file:\n\n" + + "$commentPrefix ${snippet.filepath}\n\n" + + "$commentedContent\n\n\n\n" } - "$commentPrefix Recently open file:\n\n" + - "$commentPrefix ${snippet.filepath}\n\n" + - "$commentedContent\n\n\n\n" - } + } } fun createPrompt(model: String?): String { From e225fca8385e468b8c11e7bbd47b0963fd96b181 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 23:02:15 +0800 Subject: [PATCH 6/8] Adjust function call context format & measure time taken --- .../kotlin/ai/devchat/common/Constants.kt | 12 ++++++++ src/main/kotlin/ai/devchat/common/IDEUtils.kt | 28 +++++++++++++------ .../plugin/completion/agent/ContextBuilder.kt | 12 +++++--- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/main/kotlin/ai/devchat/common/Constants.kt b/src/main/kotlin/ai/devchat/common/Constants.kt index 2cd4701..b56f3d0 100644 --- a/src/main/kotlin/ai/devchat/common/Constants.kt +++ b/src/main/kotlin/ai/devchat/common/Constants.kt @@ -24,4 +24,16 @@ object Constants { "PHP_METHOD_REFERENCE", // PHP "CALL", // Ruby ) + val LANGUAGE_COMMENT_PREFIX: Map = mapOf( + "kotlin" to "//", + "java" to "//", + "cpp" to "//", + "python" to "#", + "go" to "//", + "javascript" to "//", + "typescript" to "//", + "php" to "//", // PHP also supports `#` for comments + "rust" to "//", + "ruby" to "#" + ) } \ No newline at end of file diff --git a/src/main/kotlin/ai/devchat/common/IDEUtils.kt b/src/main/kotlin/ai/devchat/common/IDEUtils.kt index 0dca447..0b56916 100644 --- a/src/main/kotlin/ai/devchat/common/IDEUtils.kt +++ b/src/main/kotlin/ai/devchat/common/IDEUtils.kt @@ -17,6 +17,7 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking import java.util.concurrent.CompletableFuture import java.util.concurrent.CountDownLatch +import kotlin.system.measureTimeMillis object IDEUtils { @@ -146,16 +147,25 @@ object IDEUtils { val fileNode = file.node ?: return text val foldingBuilder = LanguageFolding.INSTANCE.forLanguage(this.language) ?: return text - val descriptors = foldingBuilder.buildFoldRegions(fileNode, document) - .filter { - textRange.contains(it.range) + var descriptors: List = listOf() + var timeTaken = measureTimeMillis { + descriptors = foldingBuilder.buildFoldRegions(fileNode, document) + .filter { + textRange.contains(it.range) // && it.element.textRange.startOffset > textRange.startOffset // Exclude the function itself - } - .sortedBy { it.range.startOffset } - .let { - findDescriptorsOfFoldingLevel(it, foldingLevel) - } - return foldTextByDescriptors(descriptors) + } + .sortedBy { it.range.startOffset } + .let { + findDescriptorsOfFoldingLevel(it, foldingLevel) + } + } + Log.info("=============> [$this] Time taken to build fold regions: $timeTaken ms, ${file.virtualFile.path}") + var result = "" + timeTaken = measureTimeMillis { + result = foldTextByDescriptors(descriptors) + } + Log.info("=============> [$this] Time taken to fold text: $timeTaken ms, ${file.virtualFile.path}") + return result } private fun findDescriptorsOfFoldingLevel( diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index cec5481..e07beb9 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -1,5 +1,6 @@ package ai.devchat.plugin.completion.agent +import ai.devchat.common.Constants.LANGUAGE_COMMENT_PREFIX import ai.devchat.common.IDEUtils.findAccessibleVariables import ai.devchat.common.IDEUtils.findCalleeInParent import ai.devchat.common.IDEUtils.foldTextOfLevel @@ -70,8 +71,7 @@ data class CodeSnippet ( class ContextBuilder(val file: PsiFile, val offset: Int) { val filepath: String = file.virtualFile.path val content: String = file.text - // TODO: get comment prefix for different languages - private val commentPrefix: String = "//" + private val commentPrefix: String = LANGUAGE_COMMENT_PREFIX[file.language.id.lowercase()] ?: "//" private var tokenCount: Int = 0 private fun buildFileContext(): Pair { @@ -115,8 +115,12 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { ?.flatMap { elements -> elements.filter { it.containingFile.virtualFile.path != filepath } } ?.map { CodeSnippet(it.containingFile.virtualFile.path, it.foldTextOfLevel(1)) } ?.takeWhile(::checkAndUpdateTokenCount) - ?.joinToString(separator = "") { - "$commentPrefixcall function define:\n\n${it.filepath}\n\n${it.content}\n\n\n\n" + ?.joinToString(separator = "") {snippet -> + val commentedContent = snippet.content.lines() + .joinToString(LINE_SEPARATOR.toString()) { "$commentPrefix $it" } + "$commentPrefix Function call definition:\n\n" + + "$commentPrefix ${snippet.filepath}\n\n" + + "$commentPrefix \n$commentedContent\n\n\n\n" } ?: "" } } From e7f8c453d3f9eed4b3e402df26310a6b349fea39 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 23:05:39 +0800 Subject: [PATCH 7/8] Add missing comment prefixes --- .../ai/devchat/plugin/completion/agent/ContextBuilder.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index e07beb9..ce4736b 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -186,8 +186,8 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { ).joinToString("") Log.info("Extras completion context:\n$extras") return if (!model.isNullOrEmpty() && model.contains("deepseek")) - "<|fim▁begin|>$extras$filepath\n\n$prefix<|fim▁hole|>$suffix<|fim▁end|>" + "<|fim▁begin|>$extras$commentPrefix$filepath\n\n$prefix<|fim▁hole|>$suffix<|fim▁end|>" else - "$extras$filepath\n\n$prefix$suffix" + "$extras$commentPrefix$filepath\n\n$prefix$suffix" } } \ No newline at end of file From 61838764dd2a741a9ce4e54c62296524e7e78a62 Mon Sep 17 00:00:00 2001 From: Luo Tim Date: Tue, 8 Oct 2024 23:08:19 +0800 Subject: [PATCH 8/8] Adjust code format --- .../devchat/plugin/completion/agent/ContextBuilder.kt | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt index ce4736b..5c7e228 100644 --- a/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt +++ b/src/main/kotlin/ai/devchat/plugin/completion/agent/ContextBuilder.kt @@ -143,9 +143,8 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { } ?.takeWhile { checkAndUpdateTokenCount(it.second) } ?.joinToString(separator = "") {(name, snippet) -> - val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { - "$commentPrefix $it" - } + val commentedContent = snippet.content.lines() + .joinToString(LINE_SEPARATOR.toString()) { "$commentPrefix $it" } "$commentPrefix Symbol type definition:\n\n" + "$commentPrefix ${name}\n\n" + "$commentPrefix ${snippet.filepath}\n\n" + @@ -163,9 +162,8 @@ class ContextBuilder(val file: PsiFile, val offset: Int) { .filter { it.content.lines().count(String::isBlank) <= 50 } .takeWhile(::checkAndUpdateTokenCount) .joinToString(separator = "") {snippet -> - val commentedContent = snippet.content.lines().joinToString(LINE_SEPARATOR.toString()) { - "$commentPrefix $it" - } + val commentedContent = snippet.content.lines() + .joinToString(LINE_SEPARATOR.toString()) { "$commentPrefix $it" } "$commentPrefix Recently open file:\n\n" + "$commentPrefix ${snippet.filepath}\n\n" + "$commentedContent\n\n\n\n"