diff --git a/src/function_lua.c b/src/function_lua.c index fa9983bf7e..b535528906 100644 --- a/src/function_lua.c +++ b/src/function_lua.c @@ -64,17 +64,14 @@ typedef struct luaFunctionCtx { } luaFunctionCtx; typedef struct loadCtx { - functionLibInfo *li; + list *functions; monotime start_time; size_t timeout; } loadCtx; -typedef struct registerFunctionArgs { - sds name; - sds desc; - luaFunctionCtx *lua_f_ctx; - uint64_t f_flags; -} registerFunctionArgs; +static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + void *compiled_function); /* Hook for FUNCTION LOAD execution. * Used to cancel the execution in case of a timeout (500ms). @@ -93,15 +90,42 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) { } } +static void freeCompiledFunc(ValkeyModuleCtx *module_ctx, + luaEngineCtx *lua_engine_ctx, + void *compiled_func) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); + + compiledFunction *func = compiled_func; + decrRefCount(func->name); + if (func->desc) { + decrRefCount(func->desc); + } + luaEngineFreeFunction(module_ctx, lua_engine_ctx, func->function); + zfree(func); +} + /* - * Compile a given blob and save it on the registry. - * Return a function ctx with Lua ref that allows to later retrieve the - * function from the registry. + * Compile a given script code by generating a set of compiled functions. These + * functions are also saved into the the registry of the Lua environment. + * + * Returns an array of compiled functions. The `compileFunction` struct stores a + * Lua ref that allows to later retrieve the function from the registry. + * In the `out_num_compiled_functions` parameter is returned the size of the + * array. * * Return NULL on compilation error and set the error to the err variable */ -static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) { - int ret = C_ERR; +static compiledFunction **luaEngineCreate(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + char **err) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); + + compiledFunction **compiled_functions = NULL; luaEngineCtx *lua_engine_ctx = engine_ctx; lua_State *lua = lua_engine_ctx->lua; @@ -114,15 +138,15 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size lua_pop(lua, 1); /* pop the metatable */ /* compile the code */ - if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) { - *err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1)); + if (luaL_loadbuffer(lua, code, strlen(code), "@user_function")) { + *err = valkey_asprintf("Error compiling function: %s", lua_tostring(lua, -1)); lua_pop(lua, 1); /* pops the error */ goto done; } serverAssert(lua_isfunction(lua, -1)); loadCtx load_ctx = { - .li = li, + .functions = listCreate(), .start_time = getMonotonicUs(), .timeout = timeout, }; @@ -133,13 +157,31 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size if (lua_pcall(lua, 0, 0, 0)) { errorInfo err_info = {0}; luaExtractErrorInformation(lua, &err_info); - *err = sdscatprintf(sdsempty(), "Error registering functions: %s", err_info.msg); + *err = valkey_asprintf("Error registering functions: %s", err_info.msg); lua_pop(lua, 1); /* pops the error */ luaErrorInformationDiscard(&err_info); + listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD); + listNode *node = NULL; + while ((node = listNext(iter)) != NULL) { + freeCompiledFunc(module_ctx, lua_engine_ctx, listNodeValue(node)); + } + listReleaseIterator(iter); + listRelease(load_ctx.functions); goto done; } - ret = C_OK; + compiled_functions = + zcalloc(sizeof(compiledFunction *) * listLength(load_ctx.functions)); + listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD); + listNode *node = NULL; + *out_num_compiled_functions = 0; + while ((node = listNext(iter)) != NULL) { + compiledFunction *func = listNodeValue(node); + compiled_functions[*out_num_compiled_functions] = func; + (*out_num_compiled_functions)++; + } + listReleaseIterator(iter); + listRelease(load_ctx.functions); done: /* restore original globals */ @@ -152,19 +194,23 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size lua_sethook(lua, NULL, 0, 0); /* Disable hook */ luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL); - return ret; + return compiled_functions; } /* * Invole the give function with the given keys and args */ -static void luaEngineCall(scriptRunCtx *run_ctx, - void *engine_ctx, +static void luaEngineCall(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + functionCtx *func_ctx, void *compiled_function, robj **keys, size_t nkeys, robj **args, size_t nargs) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); + luaEngineCtx *lua_engine_ctx = engine_ctx; lua_State *lua = lua_engine_ctx->lua; luaFunctionCtx *f_ctx = compiled_function; @@ -177,25 +223,38 @@ static void luaEngineCall(scriptRunCtx *run_ctx, serverAssert(lua_isfunction(lua, -1)); + scriptRunCtx *run_ctx = (scriptRunCtx *)func_ctx; luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0); lua_pop(lua, 1); /* Pop error handler */ } -static size_t luaEngineGetUsedMemoy(void *engine_ctx) { +static engineMemoryInfo luaEngineGetMemoryInfo(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); + luaEngineCtx *lua_engine_ctx = engine_ctx; - return luaMemory(lua_engine_ctx->lua); + + return (engineMemoryInfo){ + .used_memory = luaMemory(lua_engine_ctx->lua), + .engine_memory_overhead = zmalloc_size(lua_engine_ctx), + }; } -static size_t luaEngineFunctionMemoryOverhead(void *compiled_function) { +static size_t luaEngineFunctionMemoryOverhead(ValkeyModuleCtx *module_ctx, + void *compiled_function) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); + return zmalloc_size(compiled_function); } -static size_t luaEngineMemoryOverhead(void *engine_ctx) { - luaEngineCtx *lua_engine_ctx = engine_ctx; - return zmalloc_size(lua_engine_ctx); -} +static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + void *compiled_function) { + /* The lua engine is implemented in the core, and not in a Valkey Module */ + serverAssert(module_ctx == NULL); -static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) { luaEngineCtx *lua_engine_ctx = engine_ctx; lua_State *lua = lua_engine_ctx->lua; luaFunctionCtx *f_ctx = compiled_function; @@ -203,26 +262,19 @@ static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) { zfree(f_ctx); } -static void luaRegisterFunctionArgsInitialize(registerFunctionArgs *register_f_args, - sds name, - sds desc, +static void luaRegisterFunctionArgsInitialize(compiledFunction *func, + robj *name, + robj *desc, luaFunctionCtx *lua_f_ctx, uint64_t flags) { - *register_f_args = (registerFunctionArgs){ + *func = (compiledFunction){ .name = name, .desc = desc, - .lua_f_ctx = lua_f_ctx, + .function = lua_f_ctx, .f_flags = flags, }; } -static void luaRegisterFunctionArgsDispose(lua_State *lua, registerFunctionArgs *register_f_args) { - sdsfree(register_f_args->name); - if (register_f_args->desc) sdsfree(register_f_args->desc); - lua_unref(lua, register_f_args->lua_f_ctx->lua_function_ref); - zfree(register_f_args->lua_f_ctx); -} - /* Read function flags located on the top of the Lua stack. * On success, return C_OK and set the flags to 'flags' out parameter * Return C_ERR if encounter an unknown flag. */ @@ -267,10 +319,11 @@ static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) { return ret; } -static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs *register_f_args) { +static int luaRegisterFunctionReadNamedArgs(lua_State *lua, + compiledFunction *func) { char *err = NULL; - sds name = NULL; - sds desc = NULL; + robj *name = NULL; + robj *desc = NULL; luaFunctionCtx *lua_f_ctx = NULL; uint64_t flags = 0; if (!lua_istable(lua, 1)) { @@ -287,14 +340,15 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs err = "named argument key given to server.register_function is not a string"; goto error; } + const char *key = lua_tostring(lua, -2); if (!strcasecmp(key, "function_name")) { - if (!(name = luaGetStringSds(lua, -1))) { + if (!(name = luaGetStringObject(lua, -1))) { err = "function_name argument given to server.register_function must be a string"; goto error; } } else if (!strcasecmp(key, "description")) { - if (!(desc = luaGetStringSds(lua, -1))) { + if (!(desc = luaGetStringObject(lua, -1))) { err = "description argument given to server.register_function must be a string"; goto error; } @@ -335,13 +389,17 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs goto error; } - luaRegisterFunctionArgsInitialize(register_f_args, name, desc, lua_f_ctx, flags); + luaRegisterFunctionArgsInitialize(func, + name, + desc, + lua_f_ctx, + flags); return C_OK; error: - if (name) sdsfree(name); - if (desc) sdsfree(desc); + if (name) decrRefCount(name); + if (desc) decrRefCount(desc); if (lua_f_ctx) { lua_unref(lua, lua_f_ctx->lua_function_ref); zfree(lua_f_ctx); @@ -350,11 +408,12 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs return C_ERR; } -static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctionArgs *register_f_args) { +static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, + compiledFunction *func) { char *err = NULL; - sds name = NULL; + robj *name = NULL; luaFunctionCtx *lua_f_ctx = NULL; - if (!(name = luaGetStringSds(lua, 1))) { + if (!(name = luaGetStringObject(lua, 1))) { err = "first argument to server.register_function must be a string"; goto error; } @@ -369,17 +428,17 @@ static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctio lua_f_ctx = zmalloc(sizeof(*lua_f_ctx)); lua_f_ctx->lua_function_ref = lua_function_ref; - luaRegisterFunctionArgsInitialize(register_f_args, name, NULL, lua_f_ctx, 0); + luaRegisterFunctionArgsInitialize(func, name, NULL, lua_f_ctx, 0); return C_OK; error: - if (name) sdsfree(name); + if (name) decrRefCount(name); luaPushError(lua, err); return C_ERR; } -static int luaRegisterFunctionReadArgs(lua_State *lua, registerFunctionArgs *register_f_args) { +static int luaRegisterFunctionReadArgs(lua_State *lua, compiledFunction *func) { int argc = lua_gettop(lua); if (argc < 1 || argc > 2) { luaPushError(lua, "wrong number of arguments to server.register_function"); @@ -387,33 +446,28 @@ static int luaRegisterFunctionReadArgs(lua_State *lua, registerFunctionArgs *reg } if (argc == 1) { - return luaRegisterFunctionReadNamedArgs(lua, register_f_args); + return luaRegisterFunctionReadNamedArgs(lua, func); } else { - return luaRegisterFunctionReadPositionalArgs(lua, register_f_args); + return luaRegisterFunctionReadPositionalArgs(lua, func); } } static int luaRegisterFunction(lua_State *lua) { - registerFunctionArgs register_f_args = {0}; + compiledFunction *func = zcalloc(sizeof(*func)); loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME); if (!load_ctx) { + zfree(func); luaPushError(lua, "server.register_function can only be called on FUNCTION LOAD command"); return luaError(lua); } - if (luaRegisterFunctionReadArgs(lua, ®ister_f_args) != C_OK) { + if (luaRegisterFunctionReadArgs(lua, func) != C_OK) { + zfree(func); return luaError(lua); } - sds err = NULL; - if (functionLibCreateFunction(register_f_args.name, register_f_args.lua_f_ctx, load_ctx->li, register_f_args.desc, - register_f_args.f_flags, &err) != C_OK) { - luaRegisterFunctionArgsDispose(lua, ®ister_f_args); - luaPushError(lua, err); - sdsfree(err); - return luaError(lua); - } + listAddNodeTail(load_ctx->functions, func); return 0; } @@ -494,16 +548,17 @@ int luaEngineInitEngine(void) { lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the new global table */ lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */ - - engine *lua_engine = zmalloc(sizeof(*lua_engine)); - *lua_engine = (engine){ - .engine_ctx = lua_engine_ctx, - .create = luaEngineCreate, - .call = luaEngineCall, - .get_used_memory = luaEngineGetUsedMemoy, + engineMethods lua_engine_methods = { + .version = VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION, + .create_functions_library = luaEngineCreate, + .call_function = luaEngineCall, .get_function_memory_overhead = luaEngineFunctionMemoryOverhead, - .get_engine_memory_overhead = luaEngineMemoryOverhead, .free_function = luaEngineFreeFunction, + .get_memory_info = luaEngineGetMemoryInfo, }; - return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine); + + return functionsRegisterEngine(LUA_ENGINE_NAME, + NULL, + lua_engine_ctx, + &lua_engine_methods); } diff --git a/src/functions.c b/src/functions.c index feb82d4ab7..0d003f7fac 100644 --- a/src/functions.c +++ b/src/functions.c @@ -31,6 +31,7 @@ #include "sds.h" #include "dict.h" #include "adlist.h" +#include "module.h" #define LOAD_TIMEOUT_MS 500 @@ -117,9 +118,28 @@ static dict *engines = NULL; /* Libraries Ctx. */ static functionsLibCtx *curr_functions_lib_ctx = NULL; +static void setupEngineModuleCtx(engineInfo *ei, client *c) { + if (ei->engineModule != NULL) { + serverAssert(ei->module_ctx != NULL); + moduleScriptingEngineInitContext(ei->module_ctx, ei->engineModule, c); + } +} + +static void teardownEngineModuleCtx(engineInfo *ei) { + if (ei->engineModule != NULL) { + serverAssert(ei->module_ctx != NULL); + moduleFreeContext(ei->module_ctx); + } +} + static size_t functionMallocSize(functionInfo *fi) { - return zmalloc_size(fi) + sdsAllocSize(fi->name) + (fi->desc ? sdsAllocSize(fi->desc) : 0) + - fi->li->ei->engine->get_function_memory_overhead(fi->function); + setupEngineModuleCtx(fi->li->ei, NULL); + size_t size = zmalloc_size(fi) + + sdsAllocSize(fi->name) + + (fi->desc ? sdsAllocSize(fi->desc) : 0) + + fi->li->ei->engine->get_function_memory_overhead(fi->li->ei->module_ctx, fi->function); + teardownEngineModuleCtx(fi->li->ei); + return size; } static size_t libraryMallocSize(functionLibInfo *li) { @@ -141,8 +161,12 @@ static void engineFunctionDispose(void *obj) { if (fi->desc) { sdsfree(fi->desc); } + setupEngineModuleCtx(fi->li->ei, NULL); engine *engine = fi->li->ei->engine; - engine->free_function(engine->engine_ctx, fi->function); + engine->free_function(fi->li->ei->module_ctx, + engine->engine_ctx, + fi->function); + teardownEngineModuleCtx(fi->li->ei); zfree(fi); } @@ -233,6 +257,15 @@ functionsLibCtx *functionsLibCtxCreate(void) { return ret; } +void functionsAddEngineStats(engineInfo *ei) { + serverAssert(curr_functions_lib_ctx != NULL); + dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, ei->name); + if (entry == NULL) { + functionsLibEngineStats *stats = zcalloc(sizeof(*stats)); + dictAdd(curr_functions_lib_ctx->engines_stats, ei->name, stats); + } +} + /* * Creating a function inside the given library. * On success, return C_OK. @@ -242,24 +275,34 @@ functionsLibCtx *functionsLibCtxCreate(void) { * the function will verify that the given name is following the naming format * and return an error if its not. */ -int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err) { - if (functionsVerifyName(name) != C_OK) { - *err = sdsnew("Library names can only contain letters, numbers, or underscores(_) and must be at least one " - "character long"); +static int functionLibCreateFunction(robj *name, + void *function, + functionLibInfo *li, + robj *desc, + uint64_t f_flags, + sds *err) { + serverAssert(name->type == OBJ_STRING); + serverAssert(desc == NULL || desc->type == OBJ_STRING); + + if (functionsVerifyName(name->ptr) != C_OK) { + *err = sdsnew("Function names can only contain letters, numbers, or " + "underscores(_) and must be at least one character long"); return C_ERR; } - if (dictFetchValue(li->functions, name)) { + sds name_sds = sdsdup(name->ptr); + if (dictFetchValue(li->functions, name_sds)) { *err = sdsnew("Function already exists in the library"); + sdsfree(name_sds); return C_ERR; } functionInfo *fi = zmalloc(sizeof(*fi)); *fi = (functionInfo){ - .name = name, + .name = name_sds, .function = function, .li = li, - .desc = desc, + .desc = desc ? sdsdup(desc->ptr) : NULL, .f_flags = f_flags, }; @@ -403,11 +446,24 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l return ret; } -/* Register an engine, should be called once by the engine on startup and give the following: +/* Register an engine, should be called once by the engine on startup and give + * the following: * * - engine_name - name of the engine to register - * - engine_ctx - the engine ctx that should be used by the server to interact with the engine */ -int functionsRegisterEngine(const char *engine_name, engine *engine) { + * + * - engine_module - the valkey module that implements this engine + * + * - engine_ctx - the engine ctx that should be used by the server to interact + * with the engine. + * + * - engine_methods - the struct with the scripting engine callback functions + * pointers. + * + */ +int functionsRegisterEngine(const char *engine_name, + ValkeyModule *engine_module, + engineCtx *engine_ctx, + engineMethods *engine_methods) { sds engine_name_sds = sdsnew(engine_name); if (dictFetchValue(engines, engine_name_sds)) { serverLog(LL_WARNING, "Same engine was registered twice"); @@ -415,6 +471,16 @@ int functionsRegisterEngine(const char *engine_name, engine *engine) { return C_ERR; } + engine *eng = zmalloc(sizeof(engine)); + *eng = (engine){ + .engine_ctx = engine_ctx, + .create = engine_methods->create_functions_library, + .call = engine_methods->call_function, + .get_function_memory_overhead = engine_methods->get_function_memory_overhead, + .free_function = engine_methods->free_function, + .get_memory_info = engine_methods->get_memory_info, + }; + client *c = createClient(NULL); c->flag.deny_blocking = 1; c->flag.script = 1; @@ -422,15 +488,64 @@ int functionsRegisterEngine(const char *engine_name, engine *engine) { engineInfo *ei = zmalloc(sizeof(*ei)); *ei = (engineInfo){ .name = engine_name_sds, - .engine = engine, + .engineModule = engine_module, + .module_ctx = engine_module ? moduleAllocateContext() : NULL, + .engine = eng, .c = c, }; dictAdd(engines, engine_name_sds, ei); - engine_cache_memory += zmalloc_size(ei) + sdsAllocSize(ei->name) + zmalloc_size(engine) + - engine->get_engine_memory_overhead(engine->engine_ctx); + functionsAddEngineStats(ei); + + setupEngineModuleCtx(ei, NULL); + engineMemoryInfo mem_info = eng->get_memory_info(ei->module_ctx, + eng->engine_ctx); + engine_cache_memory += zmalloc_size(ei) + + sdsAllocSize(ei->name) + + zmalloc_size(eng) + + mem_info.engine_memory_overhead; + + teardownEngineModuleCtx(ei); + + return C_OK; +} + +/* Removes a scripting engine from the server. + * + * - engine_name - name of the engine to remove + */ +int functionsUnregisterEngine(const char *engine_name) { + sds engine_name_sds = sdsnew(engine_name); + dictEntry *entry = dictFind(engines, engine_name_sds); + if (entry == NULL) { + serverLog(LL_WARNING, "There's no engine registered with name %s", engine_name); + sdsfree(engine_name_sds); + return C_ERR; + } + + engineInfo *ei = dictGetVal(entry); + + dictIterator *iter = dictGetSafeIterator(curr_functions_lib_ctx->libraries); + while ((entry = dictNext(iter))) { + functionLibInfo *li = dictGetVal(entry); + if (li->ei == ei) { + libraryUnlink(curr_functions_lib_ctx, li); + engineLibraryFree(li); + } + } + dictReleaseIterator(iter); + + zfree(ei->engine); + sdsfree(ei->name); + freeClient(ei->c); + if (ei->engineModule != NULL) { + serverAssert(ei->module_ctx != NULL); + zfree(ei->module_ctx); + } + zfree(ei); + sdsfree(engine_name_sds); return C_OK; } @@ -649,11 +764,19 @@ static void fcallCommandGeneric(client *c, int ro) { } scriptRunCtx run_ctx; - if (scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return; - - engine->call(&run_ctx, engine->engine_ctx, fi->function, c->argv + 3, numkeys, c->argv + 3 + numkeys, + setupEngineModuleCtx(fi->li->ei, run_ctx.original_client); + + engine->call(fi->li->ei->module_ctx, + engine->engine_ctx, + &run_ctx, + fi->function, + c->argv + 3, + numkeys, + c->argv + 3 + numkeys, c->argc - 3 - numkeys); + + teardownEngineModuleCtx(fi->li->ei); scriptResetRun(&run_ctx); } @@ -953,14 +1076,40 @@ void functionFreeLibMetaData(functionsLibMetaData *md) { if (md->engine) sdsfree(md->engine); } +static void freeCompiledFunctions(engineInfo *ei, + compiledFunction **compiled_functions, + size_t num_compiled_functions, + size_t free_function_from_idx) { + setupEngineModuleCtx(ei, NULL); + + for (size_t i = 0; i < num_compiled_functions; i++) { + compiledFunction *func = compiled_functions[i]; + decrRefCount(func->name); + if (func->desc) { + decrRefCount(func->desc); + } + if (i >= free_function_from_idx) { + ei->engine->free_function(ei->module_ctx, + ei->engine->engine_ctx, + func->function); + } + zfree(func); + } + + zfree(compiled_functions); + + teardownEngineModuleCtx(ei); +} + /* Compile and save the given library, return the loaded library name on success * and NULL on failure. In case on failure the err out param is set with relevant error message */ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout) { dictIterator *iter = NULL; dictEntry *entry = NULL; - functionLibInfo *new_li = NULL; functionLibInfo *old_li = NULL; functionsLibMetaData md = {0}; + functionLibInfo *new_li = NULL; + if (functionExtractLibMetaData(code, &md, err) != C_OK) { return NULL; } @@ -990,10 +1139,47 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC } new_li = engineLibraryCreate(md.name, ei, code); - if (engine->create(engine->engine_ctx, new_li, md.code, timeout, err) != C_OK) { + size_t num_compiled_functions = 0; + char *compile_error = NULL; + setupEngineModuleCtx(ei, NULL); + compiledFunction **compiled_functions = + engine->create(ei->module_ctx, + engine->engine_ctx, + md.code, + timeout, + &num_compiled_functions, + &compile_error); + teardownEngineModuleCtx(ei); + if (compiled_functions == NULL) { + serverAssert(num_compiled_functions == 0); + serverAssert(compile_error != NULL); + *err = sdsnew(compile_error); + zfree(compile_error); goto error; } + for (size_t i = 0; i < num_compiled_functions; i++) { + compiledFunction *func = compiled_functions[i]; + int ret = functionLibCreateFunction(func->name, + func->function, + new_li, + func->desc, + func->f_flags, + err); + if (ret == C_ERR) { + freeCompiledFunctions(ei, + compiled_functions, + num_compiled_functions, + i); + goto error; + } + } + + freeCompiledFunctions(ei, + compiled_functions, + num_compiled_functions, + num_compiled_functions); + if (dictSize(new_li->functions) == 0) { *err = sdsnew("No functions registered"); goto error; @@ -1063,6 +1249,7 @@ void functionLoadCommand(client *c) { timeout = 0; } if (!(library_name = functionsCreateWithLibraryCtx(code->ptr, replace, &err, curr_functions_lib_ctx, timeout))) { + serverAssert(err != NULL); addReplyErrorSds(c, err); return; } @@ -1080,7 +1267,11 @@ unsigned long functionsMemory(void) { while ((entry = dictNext(iter))) { engineInfo *ei = dictGetVal(entry); engine *engine = ei->engine; - engines_memory += engine->get_used_memory(engine->engine_ctx); + setupEngineModuleCtx(ei, NULL); + engineMemoryInfo mem_info = engine->get_memory_info(ei->module_ctx, + engine->engine_ctx); + engines_memory += mem_info.used_memory; + teardownEngineModuleCtx(ei); } dictReleaseIterator(iter); @@ -1120,12 +1311,11 @@ size_t functionsLibCtxFunctionsLen(functionsLibCtx *functions_ctx) { int functionsInit(void) { engines = dictCreate(&engineDictType); + curr_functions_lib_ctx = functionsLibCtxCreate(); + if (luaEngineInitEngine() != C_OK) { return C_ERR; } - /* Must be initialized after engines initialization */ - curr_functions_lib_ctx = functionsLibCtxCreate(); - return C_OK; } diff --git a/src/functions.h b/src/functions.h index b199fbd06e..89e39fdc56 100644 --- a/src/functions.h +++ b/src/functions.h @@ -54,53 +54,68 @@ typedef struct functionLibInfo functionLibInfo; +/* ValkeyModule type aliases for scripting engine structs and types. */ +typedef ValkeyModuleScriptingEngineCtx engineCtx; +typedef ValkeyModuleScriptingEngineFunctionCtx functionCtx; +typedef ValkeyModuleScriptingEngineCompiledFunction compiledFunction; +typedef ValkeyModuleScriptingEngineMemoryInfo engineMemoryInfo; +typedef ValkeyModuleScriptingEngineMethods engineMethods; + typedef struct engine { /* engine specific context */ - void *engine_ctx; - - /* Create function callback, get the engine_ctx, and function code - * engine_ctx - opaque struct that was created on engine initialization - * li - library information that need to be provided and when add functions - * code - the library code - * timeout - timeout for the library creation (0 for no timeout) - * err - description of error (if occurred) - * returns C_ERR on error and set err to be the error message */ - int (*create)(void *engine_ctx, functionLibInfo *li, sds code, size_t timeout, sds *err); - - /* Invoking a function, r_ctx is an opaque object (from engine POV). - * The r_ctx should be used by the engine to interaction with the server, + engineCtx *engine_ctx; + + /* Compiles the script code and returns an array of compiled functions + * registered in the script./ + * + * Returns NULL on error and set err to be the error message */ + compiledFunction **(*create)( + ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + char **err); + + /* Invoking a function, func_ctx is an opaque object (from engine POV). + * The func_ctx should be used by the engine to interaction with the server, * such interaction could be running commands, set resp, or set * replication mode */ - void (*call)(scriptRunCtx *r_ctx, - void *engine_ctx, + void (*call)(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + functionCtx *func_ctx, void *compiled_function, robj **keys, size_t nkeys, robj **args, size_t nargs); - /* get current used memory by the engine */ - size_t (*get_used_memory)(void *engine_ctx); + /* free the given function */ + void (*free_function)(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx, + void *compiled_function); /* Return memory overhead for a given function, * such memory is not counted as engine memory but as general * structs memory that hold different information */ - size_t (*get_function_memory_overhead)(void *compiled_function); + size_t (*get_function_memory_overhead)(ValkeyModuleCtx *module_ctx, + void *compiled_function); - /* Return memory overhead for engine (struct size holding the engine)*/ - size_t (*get_engine_memory_overhead)(void *engine_ctx); + /* Get the current used memory by the engine */ + engineMemoryInfo (*get_memory_info)(ValkeyModuleCtx *module_ctx, + engineCtx *engine_ctx); - /* free the given function */ - void (*free_function)(void *engine_ctx, void *compiled_function); } engine; /* Hold information about an engine. * Used on rdb.c so it must be declared here. */ typedef struct engineInfo { - sds name; /* Name of the engine */ - engine *engine; /* engine callbacks that allows to interact with the engine */ - client *c; /* Client that is used to run commands */ + sds name; /* Name of the engine */ + ValkeyModule *engineModule; /* the module that implements the scripting engine */ + ValkeyModuleCtx *module_ctx; /* Scripting engine module context */ + engine *engine; /* engine callbacks that allows to interact with the engine */ + client *c; /* Client that is used to run commands */ } engineInfo; /* Hold information about the specific function. @@ -123,7 +138,12 @@ struct functionLibInfo { sds code; /* Library code */ }; -int functionsRegisterEngine(const char *engine_name, engine *engine_ctx); +int functionsRegisterEngine(const char *engine_name, + ValkeyModule *engine_module, + void *engine_ctx, + engineMethods *engine_methods); +int functionsUnregisterEngine(const char *engine_name); + sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout); unsigned long functionsMemory(void); unsigned long functionsMemoryOverhead(void); @@ -138,8 +158,6 @@ void functionsLibCtxFree(functionsLibCtx *functions_lib_ctx); void functionsLibCtxClear(functionsLibCtx *lib_ctx, void(callback)(dict *)); void functionsLibCtxSwapWithCurrent(functionsLibCtx *new_lib_ctx, int async); -int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err); - int luaEngineInitEngine(void); int functionsInit(void); diff --git a/src/module.c b/src/module.c index 541ae490ab..db493dd8bc 100644 --- a/src/module.c +++ b/src/module.c @@ -62,6 +62,7 @@ #include "crc16_slottable.h" #include "valkeymodule.h" #include "io_threads.h" +#include "functions.h" #include #include #include @@ -879,6 +880,15 @@ void moduleCallCommandUnblockedHandler(client *c) { moduleReleaseTempClient(c); } +/* Allocates the memory necessary to hold the ValkeyModuleCtx structure, and + * returns the pointer to the allocated memory. + * + * Used by the scripting engines implementation to cache the context structure. + */ +ValkeyModuleCtx *moduleAllocateContext(void) { + return (ValkeyModuleCtx *)zcalloc(sizeof(ValkeyModuleCtx)); +} + /* Create a module ctx and keep track of the nesting level. * * Note: When creating ctx for threads (VM_GetThreadSafeContext and @@ -921,6 +931,16 @@ void moduleCreateContext(ValkeyModuleCtx *out_ctx, ValkeyModule *module, int ctx } } +/* Initialize a module context to be used by scripting engines callback + * functions. + */ +void moduleScriptingEngineInitContext(ValkeyModuleCtx *out_ctx, + ValkeyModule *module, + client *client) { + moduleCreateContext(out_ctx, module, VALKEYMODULE_CTX_NONE); + out_ctx->client = client; +} + /* This command binds the normal command invocation with commands * exported by modules. */ void ValkeyModuleCommandDispatcher(client *c) { @@ -13074,6 +13094,60 @@ int VM_RdbSave(ValkeyModuleCtx *ctx, ValkeyModuleRdbStream *stream, int flags) { return VALKEYMODULE_OK; } +/* Registers a new scripting engine in the server. + * + * - `module_ctx`: the module context object. + * + * - `engine_name`: the name of the scripting engine. This name will match + * against the engine name specified in the script header using a shebang. + * + * - `engine_ctx`: engine specific context pointer. + * + * - `engine_methods`: the struct with the scripting engine callback functions + * pointers. + * + * Returns VALKEYMODULE_OK if the engine is successfully registered, and + * VALKEYMODULE_ERR in case some failure occurs. In case of a failure, an error + * message is logged. + */ +int VM_RegisterScriptingEngine(ValkeyModuleCtx *module_ctx, + const char *engine_name, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineMethods *engine_methods) { + serverLog(LL_DEBUG, "Registering a new scripting engine: %s", engine_name); + + if (engine_methods->version > VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION) { + serverLog(LL_WARNING, "The engine implementation version is greater " + "than what this server supports. Server ABI " + "Version: %lu, Engine ABI version: %lu", + VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION, + (unsigned long)engine_methods->version); + return VALKEYMODULE_ERR; + } + + if (functionsRegisterEngine(engine_name, + module_ctx->module, + engine_ctx, + engine_methods) != C_OK) { + return VALKEYMODULE_ERR; + } + + return VALKEYMODULE_OK; +} + +/* Removes the scripting engine from the server. + * + * `engine_name` is the name of the scripting engine. + * + * Returns VALKEYMODULE_OK. + * + */ +int VM_UnregisterScriptingEngine(ValkeyModuleCtx *ctx, const char *engine_name) { + UNUSED(ctx); + functionsUnregisterEngine(engine_name); + return VALKEYMODULE_OK; +} + /* MODULE command. * * MODULE LIST @@ -13944,4 +14018,6 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(RdbStreamFree); REGISTER_API(RdbLoad); REGISTER_API(RdbSave); + REGISTER_API(RegisterScriptingEngine); + REGISTER_API(UnregisterScriptingEngine); } diff --git a/src/module.h b/src/module.h new file mode 100644 index 0000000000..f61ef1e3cb --- /dev/null +++ b/src/module.h @@ -0,0 +1,17 @@ +#ifndef _MODULE_H_ +#define _MODULE_H_ + +/* This header file exposes a set of functions defined in module.c that are + * not part of the module API, but are used by the core to interact with modules + */ + +typedef struct ValkeyModuleCtx ValkeyModuleCtx; +typedef struct ValkeyModule ValkeyModule; + +ValkeyModuleCtx *moduleAllocateContext(void); +void moduleScriptingEngineInitContext(ValkeyModuleCtx *out_ctx, + ValkeyModule *module, + client *client); +void moduleFreeContext(ValkeyModuleCtx *ctx); + +#endif /* _MODULE_H_ */ diff --git a/src/script.h b/src/script.h index 7fff34a40b..194cc8bd05 100644 --- a/src/script.h +++ b/src/script.h @@ -67,6 +67,8 @@ #define SCRIPT_ALLOW_CROSS_SLOT (1ULL << 8) /* Indicate that the current script may access keys from multiple slots */ typedef struct scriptRunCtx scriptRunCtx; +/* This struct stores the necessary information to manage the execution of + * scripts using EVAL and FCALL. */ struct scriptRunCtx { const char *funcname; client *c; diff --git a/src/script_lua.c b/src/script_lua.c index 5093fa944f..29d352d44b 100644 --- a/src/script_lua.c +++ b/src/script_lua.c @@ -1258,15 +1258,15 @@ static void luaLoadLibraries(lua_State *lua) { /* Return sds of the string value located on stack at the given index. * Return NULL if the value is not a string. */ -sds luaGetStringSds(lua_State *lua, int index) { +robj *luaGetStringObject(lua_State *lua, int index) { if (!lua_isstring(lua, index)) { return NULL; } size_t len; const char *str = lua_tolstring(lua, index, &len); - sds str_sds = sdsnewlen(str, len); - return str_sds; + robj *str_obj = createStringObject(str, len); + return str_obj; } static int luaProtectedTableError(lua_State *lua) { diff --git a/src/script_lua.h b/src/script_lua.h index 35edf46af6..6c60754bbc 100644 --- a/src/script_lua.h +++ b/src/script_lua.h @@ -67,7 +67,7 @@ typedef struct errorInfo { } errorInfo; void luaRegisterServerAPI(lua_State *lua); -sds luaGetStringSds(lua_State *lua, int index); +robj *luaGetStringObject(lua_State *lua, int index); void luaRegisterGlobalProtectionFunction(lua_State *lua); void luaSetErrorMetatable(lua_State *lua); void luaSetAllowListProtection(lua_State *lua); diff --git a/src/util.c b/src/util.c index 6d99d47e5a..6e44392ce1 100644 --- a/src/util.c +++ b/src/util.c @@ -50,6 +50,7 @@ #include "util.h" #include "sha256.h" #include "config.h" +#include "zmalloc.h" #include "valkey_strtod.h" @@ -1380,3 +1381,23 @@ int snprintf_async_signal_safe(char *to, size_t n, const char *fmt, ...) { va_end(args); return result; } + +/* A printf-like function that returns a freshly allocated string. + * + * This function is similar to asprintf function, but it uses zmalloc for + * allocating the string buffer. */ +char *valkey_asprintf(char const *fmt, ...) { + va_list args; + + va_start(args, fmt); + size_t str_len = vsnprintf(NULL, 0, fmt, args) + 1; + va_end(args); + + char *str = zmalloc(str_len); + + va_start(args, fmt); + vsnprintf(str, str_len, fmt, args); + va_end(args); + + return str; +} diff --git a/src/util.h b/src/util.h index 51eb38f0b4..61095ddb65 100644 --- a/src/util.h +++ b/src/util.h @@ -99,5 +99,6 @@ int snprintf_async_signal_safe(char *to, size_t n, const char *fmt, ...); #endif size_t valkey_strlcpy(char *dst, const char *src, size_t dsize); size_t valkey_strlcat(char *dst, const char *src, size_t dsize); +char *valkey_asprintf(char const *fmt, ...); #endif diff --git a/src/valkeymodule.h b/src/valkeymodule.h index 7c3adfd477..1d99d2ff7a 100644 --- a/src/valkeymodule.h +++ b/src/valkeymodule.h @@ -783,6 +783,7 @@ typedef enum { } ValkeyModuleACLLogEntryReason; /* Incomplete structures needed by both the core and modules. */ +typedef struct ValkeyModuleCtx ValkeyModuleCtx; typedef struct ValkeyModuleIO ValkeyModuleIO; typedef struct ValkeyModuleDigest ValkeyModuleDigest; typedef struct ValkeyModuleInfoCtx ValkeyModuleInfoCtx; @@ -794,6 +795,93 @@ typedef void (*ValkeyModuleInfoFunc)(ValkeyModuleInfoCtx *ctx, int for_crash_rep typedef void (*ValkeyModuleDefragFunc)(ValkeyModuleDefragCtx *ctx); typedef void (*ValkeyModuleUserChangedFunc)(uint64_t client_id, void *privdata); +/* Current ABI version for scripting engine modules. */ +#define VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION 1UL + +/* Type definitions for implementing scripting engines modules. */ +typedef void ValkeyModuleScriptingEngineCtx; +typedef void ValkeyModuleScriptingEngineFunctionCtx; + +/* This struct represents a scripting engine function that results from the + * compilation of a script by the engine implementation. + * + * IMPORTANT: If we ever need to add/remove fields from this struct, we need + * to bump the version number defined in the + * `VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION` constant. + */ +typedef struct ValkeyModuleScriptingEngineCompiledFunction { + ValkeyModuleString *name; /* Function name */ + void *function; /* Opaque object representing a function, usually it' + the function compiled code. */ + ValkeyModuleString *desc; /* Function description */ + uint64_t f_flags; /* Function flags */ +} ValkeyModuleScriptingEngineCompiledFunction; + +/* This struct is used to return the memory information of the scripting + * engine. */ +typedef struct ValkeyModuleScriptingEngineMemoryInfo { + /* The memory used by the scripting engine runtime. */ + size_t used_memory; + /* The memory used by the scripting engine data structures. */ + size_t engine_memory_overhead; +} ValkeyModuleScriptingEngineMemoryInfo; + +typedef ValkeyModuleScriptingEngineCompiledFunction **(*ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc)( + ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + char **err); + +typedef void (*ValkeyModuleScriptingEngineCallFunctionFunc)( + ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineFunctionCtx *func_ctx, + void *compiled_function, + ValkeyModuleString **keys, + size_t nkeys, + ValkeyModuleString **args, + size_t nargs); + +typedef size_t (*ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc)( + ValkeyModuleCtx *module_ctx, + void *compiled_function); + +typedef void (*ValkeyModuleScriptingEngineFreeFunctionFunc)( + ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + void *compiled_function); + +typedef ValkeyModuleScriptingEngineMemoryInfo (*ValkeyModuleScriptingEngineGetMemoryInfoFunc)( + ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx); + +typedef struct ValkeyModuleScriptingEngineMethodsV1 { + uint64_t version; /* Version of this structure for ABI compat. */ + + /* Library create function callback. When a new script is loaded, this + * callback will be called with the script code, and returns a list of + * ValkeyModuleScriptingEngineCompiledFunc objects. */ + ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc create_functions_library; + + /* Function callback to free the memory of a registered engine function. */ + ValkeyModuleScriptingEngineFreeFunctionFunc free_function; + + /* The callback function called when `FCALL` command is called on a function + * registered in this engine. */ + ValkeyModuleScriptingEngineCallFunctionFunc call_function; + + /* Function callback to return memory overhead for a given function. */ + ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc get_function_memory_overhead; + + /* Function callback to get the used memory by the engine. */ + ValkeyModuleScriptingEngineGetMemoryInfoFunc get_memory_info; + +} ValkeyModuleScriptingEngineMethodsV1; + +#define ValkeyModuleScriptingEngineMethods ValkeyModuleScriptingEngineMethodsV1 + /* ------------------------- End of common defines ------------------------ */ /* ----------- The rest of the defines are only for modules ----------------- */ @@ -826,7 +914,6 @@ typedef void (*ValkeyModuleUserChangedFunc)(uint64_t client_id, void *privdata); #endif /* Incomplete structures for compiler checks but opaque access. */ -typedef struct ValkeyModuleCtx ValkeyModuleCtx; typedef struct ValkeyModuleCommand ValkeyModuleCommand; typedef struct ValkeyModuleCallReply ValkeyModuleCallReply; typedef struct ValkeyModuleType ValkeyModuleType; @@ -1650,6 +1737,14 @@ VALKEYMODULE_API int (*ValkeyModule_RdbSave)(ValkeyModuleCtx *ctx, ValkeyModuleRdbStream *stream, int flags) VALKEYMODULE_ATTR; +VALKEYMODULE_API int (*ValkeyModule_RegisterScriptingEngine)(ValkeyModuleCtx *module_ctx, + const char *engine_name, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineMethods *engine_methods) VALKEYMODULE_ATTR; + +VALKEYMODULE_API int (*ValkeyModule_UnregisterScriptingEngine)(ValkeyModuleCtx *module_ctx, + const char *engine_name) VALKEYMODULE_ATTR; + #define ValkeyModule_IsAOFClient(id) ((id) == UINT64_MAX) /* This is included inline inside each Valkey module. */ @@ -2017,6 +2112,8 @@ static int ValkeyModule_Init(ValkeyModuleCtx *ctx, const char *name, int ver, in VALKEYMODULE_GET_API(RdbStreamFree); VALKEYMODULE_GET_API(RdbLoad); VALKEYMODULE_GET_API(RdbSave); + VALKEYMODULE_GET_API(RegisterScriptingEngine); + VALKEYMODULE_GET_API(UnregisterScriptingEngine); if (ValkeyModule_IsModuleNameBusy && ValkeyModule_IsModuleNameBusy(name)) return VALKEYMODULE_ERR; ValkeyModule_SetModuleAttribs(ctx, name, ver, apiver); diff --git a/tests/modules/CMakeLists.txt b/tests/modules/CMakeLists.txt index 0cac0c4cb6..e98a878c9d 100644 --- a/tests/modules/CMakeLists.txt +++ b/tests/modules/CMakeLists.txt @@ -40,6 +40,7 @@ list(APPEND MODULES_LIST "moduleauthtwo") list(APPEND MODULES_LIST "rdbloadsave") list(APPEND MODULES_LIST "crash") list(APPEND MODULES_LIST "cluster") +list(APPEND MODULES_LIST "helloscripting") foreach (MODULE_NAME ${MODULES_LIST}) message(STATUS "Building test module: ${MODULE_NAME}") diff --git a/tests/modules/Makefile b/tests/modules/Makefile index 82813bb6f7..963546a9ff 100644 --- a/tests/modules/Makefile +++ b/tests/modules/Makefile @@ -65,7 +65,8 @@ TEST_MODULES = \ moduleauthtwo.so \ rdbloadsave.so \ crash.so \ - cluster.so + cluster.so \ + helloscripting.so .PHONY: all diff --git a/tests/modules/helloscripting.c b/tests/modules/helloscripting.c new file mode 100644 index 0000000000..fdca6c8e91 --- /dev/null +++ b/tests/modules/helloscripting.c @@ -0,0 +1,383 @@ +#include "valkeymodule.h" + +#include +#include +#include + +/* + * This module implements a very simple stack based scripting language. + * It's purpose is only to test the valkey module API to implement scripting + * engines. + * + * The language is called HELLO, and a program in this language is formed by + * a list of function definitions. + * The language only supports 32-bit integer, and it only allows to return an + * integer constant, or return the value passed as the first argument to the + * function. + * + * Example of a program: + * + * ``` + * FUNCTION foo # declaration of function 'foo' + * ARGS 0 # pushes the value in the first argument to the top of the + * # stack + * RETURN # returns the current value on the top of the stack and marks + * # the end of the function declaration + * + * FUNCTION bar # declaration of function 'bar' + * CONSTI 432 # pushes the value 432 to the top of the stack + * RETURN # returns the current value on the top of the stack and marks + * # the end of the function declaration. + * ``` + */ + +/* + * List of instructions of the HELLO language. + */ +typedef enum HelloInstKind { + FUNCTION = 0, + CONSTI, + ARGS, + RETURN, + _NUM_INSTRUCTIONS, // Not a real instruction. +} HelloInstKind; + +/* + * String representations of the instructions above. + */ +const char *HelloInstKindStr[] = { + "FUNCTION", + "CONSTI", + "ARGS", + "RETURN", +}; + +/* + * Struct that represents an instance of an instruction. + * Instructions may have at most one parameter. + */ +typedef struct HelloInst { + HelloInstKind kind; + union { + uint32_t integer; + const char *string; + } param; +} HelloInst; + +/* + * Struct that represents an instance of a function. + * A function is just a list of instruction instances. + */ +typedef struct HelloFunc { + char *name; + HelloInst instructions[256]; + uint32_t num_instructions; +} HelloFunc; + +/* + * Struct that represents an instance of an HELLO program. + * A program is just a list of function instances. + */ +typedef struct HelloProgram { + HelloFunc *functions[16]; + uint32_t num_functions; +} HelloProgram; + +/* + * Struct that represents the runtime context of an HELLO program. + */ +typedef struct HelloLangCtx { + HelloProgram *program; +} HelloLangCtx; + + +static HelloLangCtx *hello_ctx = NULL; + + +static uint32_t str2int(const char *str) { + char *end; + errno = 0; + uint32_t val = (uint32_t)strtoul(str, &end, 10); + ValkeyModule_Assert(errno == 0); + return val; +} + +/* + * Parses the kind of instruction that the current token points to. + */ +static HelloInstKind helloLangParseInstruction(const char *token) { + for (HelloInstKind i = 0; i < _NUM_INSTRUCTIONS; i++) { + if (strcmp(HelloInstKindStr[i], token) == 0) { + return i; + } + } + return _NUM_INSTRUCTIONS; +} + +/* + * Parses the function param. + */ +static void helloLangParseFunction(HelloFunc *func) { + char *token = strtok(NULL, " \n"); + ValkeyModule_Assert(token != NULL); + func->name = ValkeyModule_Alloc(sizeof(char) * strlen(token) + 1); + strcpy(func->name, token); +} + +/* + * Parses an integer parameter. + */ +static void helloLangParseIntegerParam(HelloFunc *func) { + char *token = strtok(NULL, " \n"); + func->instructions[func->num_instructions].param.integer = str2int(token); +} + +/* + * Parses the CONSTI instruction parameter. + */ +static void helloLangParseConstI(HelloFunc *func) { + helloLangParseIntegerParam(func); + func->num_instructions++; +} + +/* + * Parses the ARGS instruction parameter. + */ +static void helloLangParseArgs(HelloFunc *func) { + helloLangParseIntegerParam(func); + func->num_instructions++; +} + +/* + * Parses an HELLO program source code. + */ +static HelloProgram *helloLangParseCode(const char *code, + HelloProgram *program) { + char *_code = ValkeyModule_Alloc(sizeof(char) * strlen(code) + 1); + strcpy(_code, code); + + HelloFunc *currentFunc = NULL; + + char *token = strtok(_code, " \n"); + while (token != NULL) { + HelloInstKind kind = helloLangParseInstruction(token); + + if (currentFunc != NULL) { + currentFunc->instructions[currentFunc->num_instructions].kind = kind; + } + + switch (kind) { + case FUNCTION: + ValkeyModule_Assert(currentFunc == NULL); + currentFunc = ValkeyModule_Alloc(sizeof(HelloFunc)); + memset(currentFunc, 0, sizeof(HelloFunc)); + program->functions[program->num_functions++] = currentFunc; + helloLangParseFunction(currentFunc); + break; + case CONSTI: + ValkeyModule_Assert(currentFunc != NULL); + helloLangParseConstI(currentFunc); + break; + case ARGS: + ValkeyModule_Assert(currentFunc != NULL); + helloLangParseArgs(currentFunc); + break; + case RETURN: + ValkeyModule_Assert(currentFunc != NULL); + currentFunc->num_instructions++; + currentFunc = NULL; + break; + default: + ValkeyModule_Assert(0); + } + + token = strtok(NULL, " \n"); + } + + ValkeyModule_Free(_code); + + return program; +} + +/* + * Executes an HELLO function. + */ +static uint32_t executeHelloLangFunction(HelloFunc *func, + ValkeyModuleString **args, int nargs) { + uint32_t stack[64]; + int sp = 0; + + for (uint32_t pc = 0; pc < func->num_instructions; pc++) { + HelloInst instr = func->instructions[pc]; + switch (instr.kind) { + case CONSTI: + stack[sp++] = instr.param.integer; + break; + case ARGS: + uint32_t idx = instr.param.integer; + ValkeyModule_Assert(idx < (uint32_t)nargs); + size_t len; + const char *argStr = ValkeyModule_StringPtrLen(args[idx], &len); + uint32_t arg = str2int(argStr); + stack[sp++] = arg; + break; + case RETURN: + uint32_t val = stack[--sp]; + ValkeyModule_Assert(sp == 0); + return val; + case FUNCTION: + default: + ValkeyModule_Assert(0); + } + } + + ValkeyModule_Assert(0); + return 0; +} + +static ValkeyModuleScriptingEngineMemoryInfo engineGetMemoryInfo(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx) { + VALKEYMODULE_NOT_USED(module_ctx); + HelloLangCtx *ctx = (HelloLangCtx *)engine_ctx; + ValkeyModuleScriptingEngineMemoryInfo mem_info = {0}; + + if (ctx->program != NULL) { + mem_info.used_memory += ValkeyModule_MallocSize(ctx->program); + + for (uint32_t i = 0; i < ctx->program->num_functions; i++) { + HelloFunc *func = ctx->program->functions[i]; + mem_info.used_memory += ValkeyModule_MallocSize(func); + mem_info.used_memory += ValkeyModule_MallocSize(func->name); + } + } + + mem_info.engine_memory_overhead = ValkeyModule_MallocSize(ctx); + if (ctx->program != NULL) { + mem_info.engine_memory_overhead += ValkeyModule_MallocSize(ctx->program); + } + + return mem_info; +} + +static size_t engineFunctionMemoryOverhead(ValkeyModuleCtx *module_ctx, + void *compiled_function) { + VALKEYMODULE_NOT_USED(module_ctx); + HelloFunc *func = (HelloFunc *)compiled_function; + return ValkeyModule_MallocSize(func->name); +} + +static void engineFreeFunction(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + void *compiled_function) { + VALKEYMODULE_NOT_USED(module_ctx); + VALKEYMODULE_NOT_USED(engine_ctx); + HelloFunc *func = (HelloFunc *)compiled_function; + ValkeyModule_Free(func->name); + func->name = NULL; + ValkeyModule_Free(func); +} + +static ValkeyModuleScriptingEngineCompiledFunction **createHelloLangEngine(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + const char *code, + size_t timeout, + size_t *out_num_compiled_functions, + char **err) { + VALKEYMODULE_NOT_USED(module_ctx); + VALKEYMODULE_NOT_USED(timeout); + VALKEYMODULE_NOT_USED(err); + + HelloLangCtx *ctx = (HelloLangCtx *)engine_ctx; + + if (ctx->program == NULL) { + ctx->program = ValkeyModule_Alloc(sizeof(HelloProgram)); + memset(ctx->program, 0, sizeof(HelloProgram)); + } else { + ctx->program->num_functions = 0; + } + + ctx->program = helloLangParseCode(code, ctx->program); + + ValkeyModuleScriptingEngineCompiledFunction **compiled_functions = + ValkeyModule_Alloc(sizeof(ValkeyModuleScriptingEngineCompiledFunction *) * ctx->program->num_functions); + + for (uint32_t i = 0; i < ctx->program->num_functions; i++) { + HelloFunc *func = ctx->program->functions[i]; + + ValkeyModuleScriptingEngineCompiledFunction *cfunc = + ValkeyModule_Alloc(sizeof(ValkeyModuleScriptingEngineCompiledFunction)); + *cfunc = (ValkeyModuleScriptingEngineCompiledFunction) { + .name = ValkeyModule_CreateString(NULL, func->name, strlen(func->name)), + .function = func, + .desc = NULL, + .f_flags = 0, + }; + + compiled_functions[i] = cfunc; + } + + *out_num_compiled_functions = ctx->program->num_functions; + + return compiled_functions; +} + +static void +callHelloLangFunction(ValkeyModuleCtx *module_ctx, + ValkeyModuleScriptingEngineCtx *engine_ctx, + ValkeyModuleScriptingEngineFunctionCtx *func_ctx, + void *compiled_function, + ValkeyModuleString **keys, size_t nkeys, + ValkeyModuleString **args, size_t nargs) { + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(func_ctx); + VALKEYMODULE_NOT_USED(keys); + VALKEYMODULE_NOT_USED(nkeys); + + HelloFunc *func = (HelloFunc *)compiled_function; + uint32_t result = executeHelloLangFunction(func, args, nargs); + + ValkeyModule_ReplyWithLongLong(module_ctx, result); +} + +int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, + int argc) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + + if (ValkeyModule_Init(ctx, "helloengine", 1, VALKEYMODULE_APIVER_1) == + VALKEYMODULE_ERR) + return VALKEYMODULE_ERR; + + hello_ctx = ValkeyModule_Alloc(sizeof(HelloLangCtx)); + hello_ctx->program = NULL; + + ValkeyModuleScriptingEngineMethods methods = { + .version = VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION, + .create_functions_library = createHelloLangEngine, + .call_function = callHelloLangFunction, + .get_function_memory_overhead = engineFunctionMemoryOverhead, + .free_function = engineFreeFunction, + .get_memory_info = engineGetMemoryInfo, + }; + + ValkeyModule_RegisterScriptingEngine(ctx, + "HELLO", + hello_ctx, + &methods); + + return VALKEYMODULE_OK; +} + +int ValkeyModule_OnUnload(ValkeyModuleCtx *ctx) { + if (ValkeyModule_UnregisterScriptingEngine(ctx, "HELLO") != VALKEYMODULE_OK) { + ValkeyModule_Log(ctx, "error", "Failed to unregister engine"); + return VALKEYMODULE_ERR; + } + + ValkeyModule_Free(hello_ctx->program); + hello_ctx->program = NULL; + ValkeyModule_Free(hello_ctx); + hello_ctx = NULL; + + return VALKEYMODULE_OK; +} diff --git a/tests/unit/functions.tcl b/tests/unit/functions.tcl index 7ddd36dd7d..1636baaf6d 100644 --- a/tests/unit/functions.tcl +++ b/tests/unit/functions.tcl @@ -604,7 +604,7 @@ start_server {tags {"scripting"}} { } } e set _ $e - } {*Library names can only contain letters, numbers, or underscores(_) and must be at least one character long*} + } {*Function names can only contain letters, numbers, or underscores(_) and must be at least one character long*} test {LIBRARIES - test registration with empty name} { catch { @@ -613,7 +613,7 @@ start_server {tags {"scripting"}} { } } e set _ $e - } {*Library names can only contain letters, numbers, or underscores(_) and must be at least one character long*} + } {*Function names can only contain letters, numbers, or underscores(_) and must be at least one character long*} test {LIBRARIES - math.random from function load} { catch { diff --git a/tests/unit/moduleapi/scriptingengine.tcl b/tests/unit/moduleapi/scriptingengine.tcl new file mode 100644 index 0000000000..c350633dd8 --- /dev/null +++ b/tests/unit/moduleapi/scriptingengine.tcl @@ -0,0 +1,126 @@ +set testmodule [file normalize tests/modules/helloscripting.so] + +set HELLO_PROGRAM "#!hello name=mylib\nFUNCTION foo\nARGS 0\nRETURN\nFUNCTION bar\nCONSTI 432\nRETURN" + +start_server {tags {"modules"}} { + r module load $testmodule + + r function load $HELLO_PROGRAM + + test {Load script with invalid library name} { + assert_error {ERR Library names can only contain letters, numbers, or underscores(_) and must be at least one character long} {r function load "#!hello name=my-lib\nFUNCTION foo\nARGS 0\nRETURN"} + } + + test {Load script with existing library} { + assert_error {ERR Library 'mylib' already exists} {r function load $HELLO_PROGRAM} + } + + test {Load script with invalid engine} { + assert_error {ERR Engine 'wasm' not found} {r function load "#!wasm name=mylib2\nFUNCTION foo\nARGS 0\nRETURN"} + } + + test {Load script with no functions} { + assert_error {ERR No functions registered} {r function load "#!hello name=mylib2\n"} + } + + test {Load script with duplicate function} { + assert_error {ERR Function foo already exists} {r function load "#!hello name=mylib2\nFUNCTION foo\nARGS 0\nRETURN"} + } + + test {Load script with no metadata header} { + assert_error {ERR Missing library metadata} {r function load "FUNCTION foo\nARGS 0\nRETURN"} + } + + test {Load script with header without lib name} { + assert_error {ERR Library name was not given} {r function load "#!hello \n"} + } + + test {Load script with header with unknown param} { + assert_error {ERR Invalid metadata value given: nme=mylib} {r function load "#!hello nme=mylib\n"} + } + + test {Load script with header with lib name passed twice} { + assert_error {ERR Invalid metadata value, name argument was given multiple times} {r function load "#!hello name=mylib2 name=mylib3\n"} + } + + test {Load script with invalid function name} { + assert_error {ERR Function names can only contain letters, numbers, or underscores(_) and must be at least one character long} {r function load "#!hello name=mylib2\nFUNCTION foo-bar\nARGS 0\nRETURN"} + } + + test {Load script with duplicate function} { + assert_error {ERR Function already exists in the library} {r function load "#!hello name=mylib2\nFUNCTION foo\nARGS 0\nRETURN\nFUNCTION foo\nARGS 0\nRETURN"} + } + + test {Call scripting engine function: calling foo works} { + r fcall foo 0 134 + } {134} + + test {Call scripting engine function: calling bar works} { + r fcall bar 0 + } {432} + + test {Replace function library and call functions} { + set result [r function load replace "#!hello name=mylib\nFUNCTION foo\nARGS 0\nRETURN\nFUNCTION bar\nCONSTI 500\nRETURN"] + assert_equal $result "mylib" + + set result [r fcall foo 0 132] + assert_equal $result 132 + + set result [r fcall bar 0] + assert_equal $result 500 + } + + test {List scripting engine functions} { + r function load replace "#!hello name=mylib\nFUNCTION foobar\nARGS 0\nRETURN" + r function list + } {{library_name mylib engine HELLO functions {{name foobar description {} flags {}}}}} + + test {Load a second library and call a function} { + r function load "#!hello name=mylib2\nFUNCTION getarg\nARGS 0\nRETURN" + set result [r fcall getarg 0 456] + assert_equal $result 456 + } + + test {Delete all libraries and functions} { + set result [r function flush] + assert_equal $result {OK} + r function list + } {} + + test {Test the deletion of a single library} { + r function load $HELLO_PROGRAM + r function load "#!hello name=mylib2\nFUNCTION getarg\nARGS 0\nRETURN" + + set result [r function delete mylib] + assert_equal $result {OK} + + set result [r fcall getarg 0 446] + assert_equal $result 446 + } + + test {Test dump and restore function library} { + r function load $HELLO_PROGRAM + + set result [r fcall bar 0] + assert_equal $result 432 + + set dump [r function dump] + + set result [r function flush] + assert_equal $result {OK} + + set result [r function restore $dump] + assert_equal $result {OK} + + set result [r fcall getarg 0 436] + assert_equal $result 436 + + set result [r fcall bar 0] + assert_equal $result 432 + } + + test {Unload scripting engine module} { + set result [r module unload helloengine] + assert_equal $result "OK" + } +}