Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change rlua to mlua #64

Merged
merged 6 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,29 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install libudev-sys
run: sudo apt-get install -y libudev-dev
- name: Install libudev-sys and Lua5.4
run: sudo apt-get install -y libudev-dev liblua5.4-dev
- name: Build
run: cargo build --verbose
build-windows:
name: Check on Windows
runs-on: windows-latest
defaults:
run:
shell: msys2 {0}
steps:
- uses: msys2/setup-msys2@v2
- uses: actions/checkout@v3
- name: Install Rust & Lua
run: pacman -S --noconfirm mingw-w64-x86_64-rust mingw-w64-x86_64-lua mingw-w64-x86_64-luajit mingw-w64-x86_64-pkg-config
- name: Build
run: cargo build --verbose
build-macos:
name: Check on MacOS
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
- name: Install Lua5.4
run: brew install lua
- name: Build
run: cargo build --verbose
57 changes: 32 additions & 25 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ serde = "1.0"
serde_yaml = "0.9"
rand = "0.8.5"
clap = { version = "4.1.9", features = ["derive"] }
rlua = "0.19.8"
mlua = { version = "0.9.6", features = ["lua54", "async", "send"] }
anyhow = "1.0.79"
homedir = "0.2.1"
regex = { version = "1.10.3", features = [] }
Expand Down
189 changes: 87 additions & 102 deletions src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::text::TextView;
use anyhow::Result;
use chrono::Local;
use homedir::get_my_home;
use rlua::{Context, Function, Lua, RegistryKey, Table, Thread};
use mlua::{Function, Lua, RegistryKey, Table, Thread};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -120,7 +120,8 @@ impl Plugin {
.to_string();
let code = std::fs::read_to_string(filepath).map_err(|_| "Cannot read plugin file")?;

let lua = Lua::new();
let lua = Lua::new_with(mlua::StdLib::ALL_SAFE, mlua::LuaOptions::default())
.map_err(|_| "Cannot create Lua obj".to_string())?;

Plugin::check_integrity(&lua, &code)?;

Expand Down Expand Up @@ -151,27 +152,32 @@ impl Plugin {
self.name.as_str()
}

fn create_lua_thread(
lua: &Lua,
code: &str,
coroutine_name: &str,
) -> Result<RegistryKey, String> {
Plugin::append_plugins_dir(lua)?;

lua.load(code)
.exec()
.map_err(|_| "Fail to load Lua code".to_string())?;

let serial_rx: Thread = lua
.load(format!("coroutine.create({})", coroutine_name))
.eval()
.map_err(|_| format!("Fail to create coroutine for {}", coroutine_name))?;
let reg = lua
.create_registry_value(serial_rx)
.map_err(|_| format!("Fail to create register for {} coroutines", coroutine_name))?;
Ok(reg)
}

pub fn serial_rx_call(&self, msg: Vec<u8>) -> SerialRxCall {
let lua = Lua::new();
let code = self.code.as_str();

let serial_rx_reg: Result<RegistryKey, String> = lua.context(move |lua_ctx| {
Plugin::append_plugins_dir(&lua_ctx)?;

lua_ctx
.load(code)
.exec()
.map_err(|_| "Fail to load Lua code".to_string())?;

let serial_rx: Thread = lua_ctx
.load(r#"coroutine.create(serial_rx)"#)
.eval()
.map_err(|_| "Fail to create coroutine for serial_rx".to_string())?;
let reg = lua_ctx
.create_registry_value(serial_rx)
.map_err(|_| "Fail to create register for serial_rx coroutines".to_string())?;
Ok(reg)
});
let lua = Lua::new_with(mlua::StdLib::ALL_SAFE, mlua::LuaOptions::default())
.expect("Cannot create Lua obj");

let serial_rx_reg = Self::create_lua_thread(&lua, self.code.as_str(), "serial_rx");

SerialRxCall {
lua,
Expand All @@ -182,26 +188,10 @@ impl Plugin {
}

pub fn user_command_call(&self, arg_list: Vec<String>) -> UserCommandCall {
let lua = Lua::new();
let code = self.code.as_str();

let user_command_reg: Result<RegistryKey, String> = lua.context(move |lua_ctx| {
Plugin::append_plugins_dir(&lua_ctx)?;

lua_ctx
.load(code)
.exec()
.map_err(|_| "Fail to load Lua code".to_string())?;

let user_command: Thread = lua_ctx
.load(r#"coroutine.create(user_command)"#)
.eval()
.map_err(|_| "Fail to create coroutine for user_command".to_string())?;
let reg = lua_ctx
.create_registry_value(user_command)
.map_err(|_| "Fail to create register for user_command coroutines".to_string())?;
Ok(reg)
});
let lua = Lua::new_with(mlua::StdLib::ALL_SAFE, mlua::LuaOptions::default())
.expect("Cannot create Lua obj");

let user_command_reg = Self::create_lua_thread(&lua, self.code.as_str(), "user_command");

UserCommandCall {
lua,
Expand All @@ -211,15 +201,15 @@ impl Plugin {
}
}

fn append_plugins_dir(lua_ctx: &Context) -> Result<(), String> {
fn append_plugins_dir(lua: &Lua) -> Result<(), String> {
let home_dir = get_my_home()
.expect("Cannot get home directory")
.expect("Cannot get home directory")
.to_str()
.expect("Cannot get home directory")
.to_string();

if lua_ctx
if lua
.load(
format!(
"package.path = package.path .. ';{}/.config/scope/plugins/?.lua'",
Expand All @@ -237,33 +227,30 @@ impl Plugin {
}

fn check_integrity(lua: &Lua, code: &str) -> Result<(), String> {
lua.context(|lua_ctx| {
let globals = lua_ctx.globals();
let globals = lua.globals();

Plugin::append_plugins_dir(&lua_ctx)?;
Plugin::append_plugins_dir(&lua)?;

lua_ctx
.load(code)
.exec()
.map_err(|_| "Fail to load Lua code".to_string())?;
lua.load(code)
.exec()
.map_err(|_| "Fail to load Lua code".to_string())?;

globals
.get::<_, Function>("serial_rx")
.map_err(|_| "serial_rx function not found in Lua code")?;
globals
.get::<_, Function>("user_command")
.map_err(|_| "user_command function not found in Lua code")?;
globals
.get::<_, Function>("serial_rx")
.map_err(|_| "serial_rx function not found in Lua code")?;
globals
.get::<_, Function>("user_command")
.map_err(|_| "user_command function not found in Lua code")?;

Ok(())
})
Ok(())
}
}

fn resume_lua_thread<T: for<'a> rlua::ToLuaMulti<'a> + Send>(
thread: &Thread,
data: T,
) -> Option<PluginRequest> {
match thread.resume::<T, Table>(data) {
fn resume_lua_thread<T>(thread: &Thread, data: T) -> Option<PluginRequest>
where
T: for<'a> mlua::IntoLuaMulti<'a>,
{
match thread.resume::<_, Table>(data) {
Ok(req) => {
let req: PluginRequest = match req.try_into() {
Ok(req) => req,
Expand All @@ -283,30 +270,29 @@ impl Iterator for SerialRxCall {
let thread = &self.thread;
let msg = self.msg.clone();

self.lua.context(move |lua_ctx| {
let serial_rx: Thread = lua_ctx
.registry_value(thread)
.expect("Cannot get serial_rx register");
let serial_rx: Thread = self
.lua
.registry_value(thread)
.expect("Cannot get serial_rx register");

let Some(req_result) = req_result else {
return resume_lua_thread(&serial_rx, msg);
};
let Some(req_result) = req_result else {
return resume_lua_thread(&serial_rx, msg);
};

match req_result {
PluginRequestResult::Exec { stdout, stderr } => {
match serial_rx.resume::<_, Table>((msg, stdout, stderr)) {
Ok(req) => {
let req: PluginRequest = match req.try_into() {
Ok(req) => req,
Err(msg) => return Some(PluginRequest::Eprintln { msg }),
};
Some(req)
}
Err(_) => None,
match req_result {
PluginRequestResult::Exec { stdout, stderr } => {
match serial_rx.resume::<_, Table>((msg, stdout, stderr)) {
Ok(req) => {
let req: PluginRequest = match req.try_into() {
Ok(req) => req,
Err(msg) => return Some(PluginRequest::Eprintln { msg }),
};
Some(req)
}
Err(_) => None,
}
}
})
}
}
}

Expand All @@ -318,30 +304,29 @@ impl Iterator for UserCommandCall {
let thread = &self.thread;
let arg_list = self.arg_list.clone();

self.lua.context(move |lua_ctx| {
let user_command: Thread = lua_ctx
.registry_value(thread)
.expect("Cannot get user_command register");
let user_command: Thread = self
.lua
.registry_value(thread)
.expect("Cannot get user_command register");

let Some(req_result) = req_result else {
return resume_lua_thread(&user_command, arg_list);
};
let Some(req_result) = req_result else {
return resume_lua_thread(&user_command, arg_list);
};

match req_result {
PluginRequestResult::Exec { stdout, stderr } => {
match user_command.resume::<_, Table>((arg_list, stdout, stderr)) {
Ok(req) => {
let req: PluginRequest = match req.try_into() {
Ok(req) => req,
Err(msg) => return Some(PluginRequest::Eprintln { msg }),
};
Some(req)
}
Err(_) => None,
match req_result {
PluginRequestResult::Exec { stdout, stderr } => {
match user_command.resume::<_, Table>((arg_list, stdout, stderr)) {
Ok(req) => {
let req: PluginRequest = match req.try_into() {
Ok(req) => req,
Err(msg) => return Some(PluginRequest::Eprintln { msg }),
};
Some(req)
}
Err(_) => None,
}
}
})
}
}
}

Expand Down
Loading