diff --git a/pjrt-plugins/xla-cpu/PjrtPluginXlaCpu.idr b/pjrt-plugins/xla-cpu/PjrtPluginXlaCpu.idr index 830c105c5..788d19eb2 100644 --- a/pjrt-plugins/xla-cpu/PjrtPluginXlaCpu.idr +++ b/pjrt-plugins/xla-cpu/PjrtPluginXlaCpu.idr @@ -27,4 +27,5 @@ export device : Pjrt Device device = do api <- MkPjrtApi <$> primIO prim__getPjrtApi + pjrtPluginInitialize api MkDevice api <$> pjrtClientCreate api diff --git a/pjrt-plugins/xla-cuda/PjrtPluginXlaCuda.idr b/pjrt-plugins/xla-cuda/PjrtPluginXlaCuda.idr index 0760f9d75..3e1e24ecd 100644 --- a/pjrt-plugins/xla-cuda/PjrtPluginXlaCuda.idr +++ b/pjrt-plugins/xla-cuda/PjrtPluginXlaCuda.idr @@ -27,4 +27,5 @@ export device : Pjrt Device device = do api <- MkPjrtApi <$> primIO prim__getPjrtApi + pjrtPluginInitialize api MkDevice api <$> pjrtClientCreate api diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index ceddfb28f..e3b86dd9c 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.15 +0.0.16 diff --git a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp index 18290860c..040efb300 100644 --- a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -68,6 +68,17 @@ extern "C" { return api->PJRT_Error_GetCode(args); } + PJRT_Plugin_Initialize_Args* PJRT_Plugin_Initialize_Args_new() { + return new PJRT_Plugin_Initialize_Args{ + .struct_size = PJRT_Plugin_Initialize_Args_STRUCT_SIZE, + .extension_start = nullptr, + }; + } + + PJRT_Error* pjrt_plugin_initialize(PJRT_Api* api, PJRT_Plugin_Initialize_Args* args) { + return api->PJRT_Plugin_Initialize(args); + } + PJRT_Event_Destroy_Args* PJRT_Event_Destroy_Args_new(PJRT_Event* event) { return new PJRT_Event_Destroy_Args{ .struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE, diff --git a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr index 734f9fa16..9933b6c30 100644 --- a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr +++ b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr @@ -154,6 +154,23 @@ try api err onOk = if (isNullPtr err) then right onOk else do destroyPjrtError api err left $ MkPjrtError msg $ map pjrtErrorCodeFromCInt code +%foreign (libxla "PJRT_Plugin_Initialize_Args_new") +prim__mkPjrtPluginInitializeArgs : PrimIO AnyPtr + +%foreign (libxla "pjrt_plugin_initialize") +prim__pjrtPluginInitialize : AnyPtr -> AnyPtr -> PrimIO AnyPtr + +||| For use by plugin developers. +||| +||| Initialize a PJRT plugin. Must be called before the PjrtApi is used. +export +pjrtPluginInitialize : PjrtApi -> Pjrt () +pjrtPluginInitialize (MkPjrtApi api) = do + args <- primIO prim__mkPjrtPluginInitializeArgs + err <- primIO $ prim__pjrtPluginInitialize api args + free args + try api err () + ||| For internal spidr use only. export data PjrtEvent = MkPjrtEvent AnyPtr