diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
index 3b1451333..ddb1e3352 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt
@@ -15,13 +15,44 @@ abstract class CachedTool(
private val timeCachePolicy: Duration = 1.days
) : Tool {
- override suspend fun invoke(input: Input): Output {
- return cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
- }
+ /**
+ * Logic to be executed when the cache is missed.
+ *
+ * @return the output.
+ */
+ abstract suspend fun onCacheMissed(input: Input): Output
+
+ /**
+ * Criteria to check if the cache should be used for the given [input]. By default, it returns
+ * true, meaning always use the cache if available.
+ *
+ * @return true if the cache should be used.
+ */
+ open suspend fun shouldUseCache(input: Input): Boolean = true
+
+ /**
+ * Criteria to check if the result should be cached based on the given [input] and [output]. By
+ * default, it returns true, meaning always cache the result.
+ *
+ * @return true if the result should be cached.
+ */
+ open suspend fun shouldCacheOutput(input: Input, output: Output): Boolean = true
+
+ /**
+ * Caches the result of [onCacheMissed] if [shouldCacheOutput] returns true. Otherwise, returns
+ * the result of [onCacheMissed].
+ *
+ * @return the output.
+ */
+ override suspend fun invoke(input: Input): Output =
+ if (shouldUseCache(input)) cache(CachedToolKey(input, seed)) { onCacheMissed(input) }
+ else onCacheMissed(input)
/**
* Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and
* [timeCachePolicy]. Removes expired cache entries.
+ *
+ * @return the map of input to output.
*/
suspend fun getCache(): Map {
val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds
@@ -42,8 +73,6 @@ abstract class CachedTool(
return withoutExpired.map { it.key.value to it.value.value }.toMap()
}
- abstract suspend fun onCacheMissed(input: Input): Output
-
private suspend fun cache(input: CachedToolKey, block: suspend () -> Output): Output {
val cachedToolInfo = cache.get()[input]
if (cachedToolInfo != null) {
@@ -55,7 +84,9 @@ abstract class CachedTool(
}
}
val response = block()
- cache.get()[input] = CachedToolValue(response, timeInMillis())
+ if (shouldCacheOutput(input.value, response)) {
+ cache.get()[input] = CachedToolValue(response, timeInMillis())
+ }
return response
}
}