Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance code completion, error handling, and performance #212

Merged
merged 10 commits into from
Nov 8, 2024
151 changes: 140 additions & 11 deletions src/main/kotlin/ai/devchat/common/IDEUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,39 @@ import kotlinx.coroutines.runBlocking
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CountDownLatch
import kotlin.system.measureTimeMillis
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.codeInsight.navigation.actions.GotoTypeDeclarationAction
import com.intellij.openapi.fileEditor.FileEditorManager
import java.lang.ref.SoftReference
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
import com.intellij.psi.SmartPointerManager
import com.intellij.psi.SmartPsiElementPointer


object IDEUtils {
private const val MAX_CACHE_SIZE = 1000
private data class CacheEntry(val filePath: String, val offset: Int, val element: SoftReference<SymbolTypeDeclaration>)

private val variableCache = object : LinkedHashMap<String, CacheEntry>(MAX_CACHE_SIZE, 0.75f, true) {
override fun removeEldestEntry(eldest: Map.Entry<String, CacheEntry>): Boolean {
return size > MAX_CACHE_SIZE
}
}
private val cacheLock = ReentrantReadWriteLock()

private data class FoldCacheEntry(
val foldedText: String,
val elementPointer: SmartPsiElementPointer<PsiElement>,
val elementLength: Int,
val elementHash: Int
)

private val foldCache = ConcurrentHashMap<String, SoftReference<FoldCacheEntry>>()


fun <T> runInEdtAndGet(block: () -> T): T {
val app = ApplicationManager.getApplication()
if (app.isDispatchThread) {
Expand Down Expand Up @@ -127,21 +157,120 @@ object IDEUtils {
)

fun PsiElement.findAccessibleVariables(): Sequence<SymbolTypeDeclaration> {
val projectFileIndex = ProjectFileIndex.getInstance(this.project)
return generateSequence(this.parent) { it.parent }
.takeWhile { it !is PsiFile }
.flatMap { it.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>() }
.plus(this.containingFile.children.asSequence().filterIsInstance<PsiNameIdentifierOwner>())
.filter { !it.name.isNullOrEmpty() && it.nameIdentifier != null }
.mapNotNull {
val typeDeclaration = it.getTypeDeclaration() ?: return@mapNotNull null
val virtualFile = typeDeclaration.containingFile.virtualFile ?: return@mapNotNull null
val isProjectContent = projectFileIndex.isInContent(virtualFile)
SymbolTypeDeclaration(it, CodeNode(typeDeclaration, isProjectContent))
val projectFileIndex = ProjectFileIndex.getInstance(project)

// 首先收集所有可能的变量
val allVariables = sequence {
var currentScope: PsiElement? = this@findAccessibleVariables
while (currentScope != null && currentScope !is PsiFile) {
val variablesInScope = PsiTreeUtil.findChildrenOfAnyType(
currentScope,
false,
PsiNameIdentifierOwner::class.java
)

for (variable in variablesInScope) {
if (isLikelyVariable(variable) && !variable.name.isNullOrEmpty() && variable.nameIdentifier != null) {
yield(variable)
}
}

currentScope = currentScope.parent
}

yieldAll([email protected]
.asSequence()
.filterIsInstance<PsiNameIdentifierOwner>()
.filter { isLikelyVariable(it) && !it.name.isNullOrEmpty() && it.nameIdentifier != null })
}.distinct()

// 处理这些变量的类型,使用缓存
return allVariables.mapNotNull { variable ->
val cacheKey = "${variable.containingFile?.virtualFile?.path}:${variable.textRange.startOffset}"

getCachedOrCompute(cacheKey, variable)
}
}

private fun getCachedOrCompute(cacheKey: String, variable: PsiElement): SymbolTypeDeclaration? {
cacheLock.read {
variableCache[cacheKey]?.let { entry ->
entry.element.get()?.let { cached ->
if (cached.symbol.isValid) return cached
}
}
}

val computed = computeSymbolTypeDeclaration(variable) ?: return null

cacheLock.write {
variableCache[cacheKey] = CacheEntry(
variable.containingFile?.virtualFile?.path ?: return null,
variable.textRange.startOffset,
SoftReference(computed)
)
}

return computed
}

private fun computeSymbolTypeDeclaration(variable: PsiElement): SymbolTypeDeclaration? {
val typeDeclaration = getTypeElement(variable) ?: return null
val virtualFile = variable.containingFile?.virtualFile ?: return null
val isProjectContent = ProjectFileIndex.getInstance(variable.project).isInContent(virtualFile)
return SymbolTypeDeclaration(variable as PsiNameIdentifierOwner, CodeNode(typeDeclaration, isProjectContent))
}

// 辅助函数,用于判断一个元素是否可能是变量
private fun isLikelyVariable(element: PsiElement): Boolean {
val elementClass = element.javaClass.simpleName
return elementClass.contains("Variable", ignoreCase = true) ||
elementClass.contains("Parameter", ignoreCase = true) ||
elementClass.contains("Field", ignoreCase = true)
}

// 辅助函数,用于获取变量的类型元素
private fun getTypeElement(element: PsiElement): PsiElement? {
return ReadAction.compute<PsiElement?, Throwable> {
val project = element.project
val editor = FileEditorManager.getInstance(project).selectedTextEditor ?: return@compute null
val offset = element.textOffset

GotoTypeDeclarationAction.findSymbolType(editor, offset)
}
}

fun PsiElement.foldTextOfLevel(foldingLevel: Int = 1): String {
var result: String
val executionTime = measureTimeMillis {
val cacheKey = "${containingFile.virtualFile.path}:${textRange.startOffset}:$foldingLevel"

// 检查缓存
result = foldCache[cacheKey]?.get()?.let { cachedEntry ->
val cachedElement = cachedEntry.elementPointer.element
if (cachedElement != null && cachedElement.isValid &&
text.length == cachedEntry.elementLength &&
text.hashCode() == cachedEntry.elementHash) {
cachedEntry.foldedText
} else null
} ?: run {
// 如果缓存无效或不存在,重新计算
val foldedText = computeFoldedText(foldingLevel)
// 更新缓存
val elementPointer = SmartPointerManager.getInstance(project).createSmartPsiElementPointer(this)
foldCache[cacheKey] = SoftReference(FoldCacheEntry(foldedText, elementPointer, text.length, text.hashCode()))
foldedText
}
}

// 记录执行时间
Log.info("foldTextOfLevel execution time: $executionTime ms")

// 返回计算结果
return result
}

private fun PsiElement.computeFoldedText(foldingLevel: Int): String {
val file = this.containingFile
val document = file.viewProvider.document ?: return text
val fileNode = file.node ?: return text
Expand Down
1 change: 0 additions & 1 deletion src/main/kotlin/ai/devchat/plugin/IDEServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ class IDEServer(private var project: Project): Disposable {

fun stop() {
Log.info("Stopping IDE server...")
Notifier.info("Stopping IDE server...")
server?.stop(1_000, 2_000)
}

Expand Down
73 changes: 44 additions & 29 deletions src/main/kotlin/ai/devchat/plugin/completion/agent/Agent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -156,37 +156,52 @@ class Agent(val scope: CoroutineScope) {
}
}

private fun requestDevChatAPI(prompt: String): Flow<CodeCompletionChunk> = flow {
val devChatEndpoint = CONFIG["providers.devchat.api_base"] as? String
val devChatAPIKey = CONFIG["providers.devchat.api_key"] as? String
val endpoint = "$devChatEndpoint/completions"
val endingChunk = "[DONE]"
val payload = mapOf(
"model" to ((CONFIG["complete_model"] as? String) ?: defaultCompletionModel),
"prompt" to prompt,
"stream" to true,
"stop" to listOf("<|endoftext|>", "<|EOT|>", "<file_sep>", "```", "/", "\n\n"),
"temperature" to 0.2
)
val requestBody = gson.toJson(payload).toRequestBody("application/json; charset=utf-8".toMediaType())
val requestBuilder = Request.Builder().url(endpoint).post(requestBody)
requestBuilder.addHeader("Authorization", "Bearer $devChatAPIKey")
requestBuilder.addHeader("Accept", "text/event-stream")
requestBuilder.addHeader("Content-Type", "application/json")
httpClient.newCall(requestBuilder.build()).execute().use { response ->
if (!response.isSuccessful) throw IllegalArgumentException("Unexpected code $response")
response.body?.charStream()?.buffered()?.use {reader ->
reader.lineSequence().asFlow()
.filter {it.isNotEmpty()}
.takeWhile { it.startsWith("data:") }
.map { it.drop(5).trim() }
.takeWhile { it.uppercase() != endingChunk }
.map { gson.fromJson(it, CompletionResponseChunk::class.java) }
.takeWhile {it != null}
.collect { emit(CodeCompletionChunk(it.id, it.choices[0].text!!)) }
private fun requestDevChatAPI(prompt: String): Flow<CodeCompletionChunk> = flow {
val devChatEndpoint = CONFIG["providers.devchat.api_base"] as? String
val devChatAPIKey = CONFIG["providers.devchat.api_key"] as? String
val endpoint = "$devChatEndpoint/completions"
val endingChunk = "[DONE]"
val payload = mapOf(
"model" to ((CONFIG["complete_model"] as? String) ?: defaultCompletionModel),
"prompt" to prompt,
"stream" to true,
"stop" to listOf("<|endoftext|>", "<|EOT|>", "<file_sep>", "```", "/", "\n\n"),
"temperature" to 0.2
)
val requestBody = gson.toJson(payload).toRequestBody("application/json; charset=utf-8".toMediaType())
val requestBuilder = Request.Builder().url(endpoint).post(requestBody)
requestBuilder.addHeader("Authorization", "Bearer $devChatAPIKey")
requestBuilder.addHeader("Accept", "text/event-stream")
requestBuilder.addHeader("Content-Type", "application/json")

httpClient.newCall(requestBuilder.build()).execute().use { response ->
if (!response.isSuccessful) {
val errorBody = response.body?.string() ?: "No error body"
when (response.code) {
500 -> {
if (errorBody.contains("Insufficient Balance")) {
logger.warn("DevChat API error: Insufficient balance. Please check your account.")
} else {
logger.warn("DevChat API server error. Response code: ${response.code}. Body: $errorBody")
}
}
else -> logger.warn("Unexpected response from DevChat API. Code: ${response.code}. Body: $errorBody")
}
return@flow
}

response.body?.charStream()?.buffered()?.use { reader ->
reader.lineSequence().asFlow()
.filter { it.isNotEmpty() }
.takeWhile { it.startsWith("data:") }
.map { it.drop(5).trim() }
.takeWhile { it.uppercase() != endingChunk }
.map { gson.fromJson(it, CompletionResponseChunk::class.java) }
.takeWhile { it != null }
.collect { emit(CodeCompletionChunk(it.id, it.choices[0].text!!)) }
}
}
}

private fun toLines(chunks: Flow<CodeCompletionChunk>): Flow<CodeCompletionChunk> = flow {
var ongoingLine = ""
Expand Down Expand Up @@ -299,7 +314,7 @@ suspend fun provideCompletions(
val llmRequestElapse = System.currentTimeMillis() - startTime
val offset = completionRequest.position
val replaceRange = CompletionResponse.Choice.Range(start = offset, end = offset)
val text = if (completion.text != prevCompletion) completion.text else ""
val text = completion.text
val choice = CompletionResponse.Choice(index = 0, text = text, replaceRange = replaceRange)
val response = CompletionResponse(completion.id, model, listOf(choice), promptBuildingElapse, llmRequestElapse)

Expand Down
65 changes: 50 additions & 15 deletions src/main/kotlin/ai/devchat/plugin/completion/agent/AgentService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,63 @@ import com.intellij.psi.PsiFile
import io.ktor.util.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.ModalityState
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.coroutines.resume
import kotlinx.coroutines.CancellationException

@Service
class AgentService : Disposable {
val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
private var agent: Agent = Agent(scope)

suspend fun provideCompletion(editor: Editor, offset: Int, manually: Boolean = false): Agent.CompletionResponse? {
return ReadAction.compute<PsiFile, Throwable> {
editor.project?.let { project ->
PsiDocumentManager.getInstance(project).getPsiFile(editor.document)
}
}?.let { file ->
agent.provideCompletions(
Agent.CompletionRequest(
file,
file.getLanguageId(),
offset,
manually,
)
)
suspend fun provideCompletion(editor: Editor, offset: Int, manually: Boolean = false): Agent.CompletionResponse? {
println("Entering provideCompletion method")
return withContext(Dispatchers.Default) {
try {
println("Attempting to get PsiFile")
val file = suspendCancellableCoroutine<PsiFile?> { continuation ->
ApplicationManager.getApplication().invokeLater({
val psiFile = ReadAction.compute<PsiFile?, Throwable> {
editor.project?.let { project ->
PsiDocumentManager.getInstance(project).getPsiFile(editor.document)
}
}
continuation.resume(psiFile)
}, ModalityState.defaultModalityState())
}

println("PsiFile obtained: ${file != null}")

file?.let { psiFile ->
println("Calling agent.provideCompletions")
val result = agent.provideCompletions(
Agent.CompletionRequest(
psiFile,
psiFile.getLanguageId(),
offset,
manually,
)
)
println("agent.provideCompletions returned: $result")
result
}
} catch (e: CancellationException) {
// 方案1:以较低的日志级别记录
println("Completion was cancelled: ${e.message}")
// 或者方案2:完全忽略
// // 不做任何处理

null
} catch (e: Exception) {
println("Exception in provideCompletion: ${e.message}")
e.printStackTrace()
null
}
}
}
}

suspend fun postEvent(event: Agent.LogEventRequest) {
agent.postEvent(event)
Expand Down
Loading
Loading