Skip to content

Commit

Permalink
initialize PJRT plugins (#444)
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley authored Jan 12, 2025
1 parent 475feaa commit 5d449d5
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions pjrt-plugins/xla-cpu/PjrtPluginXlaCpu.idr
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ export
device : Pjrt Device
device = do
api <- MkPjrtApi <$> primIO prim__getPjrtApi
pjrtPluginInitialize api
MkDevice api <$> pjrtClientCreate api
1 change: 1 addition & 0 deletions pjrt-plugins/xla-cuda/PjrtPluginXlaCuda.idr
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ export
device : Pjrt Device
device = do
api <- MkPjrtApi <$> primIO prim__getPjrtApi
pjrtPluginInitialize api
MkDevice api <$> pjrtClientCreate api
2 changes: 1 addition & 1 deletion spidr/backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.15
0.0.16
11 changes: 11 additions & 0 deletions spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ extern "C" {
return api->PJRT_Error_GetCode(args);
}

PJRT_Plugin_Initialize_Args* PJRT_Plugin_Initialize_Args_new() {
return new PJRT_Plugin_Initialize_Args{
.struct_size = PJRT_Plugin_Initialize_Args_STRUCT_SIZE,
.extension_start = nullptr,
};
}

PJRT_Error* pjrt_plugin_initialize(PJRT_Api* api, PJRT_Plugin_Initialize_Args* args) {
return api->PJRT_Plugin_Initialize(args);
}

PJRT_Event_Destroy_Args* PJRT_Event_Destroy_Args_new(PJRT_Event* event) {
return new PJRT_Event_Destroy_Args{
.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE,
Expand Down
17 changes: 17 additions & 0 deletions spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ try api err onOk = if (isNullPtr err) then right onOk else do
destroyPjrtError api err
left $ MkPjrtError msg $ map pjrtErrorCodeFromCInt code

%foreign (libxla "PJRT_Plugin_Initialize_Args_new")
prim__mkPjrtPluginInitializeArgs : PrimIO AnyPtr

%foreign (libxla "pjrt_plugin_initialize")
prim__pjrtPluginInitialize : AnyPtr -> AnyPtr -> PrimIO AnyPtr

||| For use by plugin developers.
|||
||| Initialize a PJRT plugin. Must be called before the PjrtApi is used.
export
pjrtPluginInitialize : PjrtApi -> Pjrt ()
pjrtPluginInitialize (MkPjrtApi api) = do
args <- primIO prim__mkPjrtPluginInitializeArgs
err <- primIO $ prim__pjrtPluginInitialize api args
free args
try api err ()

||| For internal spidr use only.
export
data PjrtEvent = MkPjrtEvent AnyPtr
Expand Down

0 comments on commit 5d449d5

Please sign in to comment.