Skip to content

Commit

Permalink
refactor nvs, add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
indexds committed Dec 27, 2024
1 parent 5a823c7 commit da876fe
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 114 deletions.
8 changes: 4 additions & 4 deletions src/http/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ pub fn index_html(nvs: &NvsWireguard) -> anyhow::Result<String> {
<script src="index.js"></script>
</html>
"###,
nvs.wg_addr.clean_string().as_str(),
nvs.wg_port.clean_string().as_str(),
nvs.wg_cli_pri.clean_string().as_str(),
nvs.wg_serv_pub.clean_string().as_str(),
nvs.address.clean_string().as_str(),
nvs.port.clean_string().as_str(),
nvs.client_private_key.clean_string().as_str(),
nvs.server_public_key.clean_string().as_str(),
))
}
4 changes: 1 addition & 3 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ pub fn start_http_server(
http_server.fn_handler("/", Method::Get, {
let nvs = Arc::clone(&nvs);
move |mut request| {
let nvs = nvs.lock().unwrap();

let wg_conf = NvsWireguard::new(&nvs)?;
let wg_conf = NvsWireguard::new(Arc::clone(&nvs))?;

let html = index::index_html(&wg_conf)?;

Expand Down
28 changes: 8 additions & 20 deletions src/http/wg_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use esp_idf_svc::http::server::{EspHttpServer, Method};
use esp_idf_svc::nvs::{EspNvs, NvsDefault};
use esp_idf_svc::wifi::EspWifi;

use crate::utils::nvs::{NvsKeys, NvsWireguard};
use crate::utils::nvs::NvsWireguard;
use crate::wireguard;
use crate::wireguard::ctx::WG_CTX;

Expand All @@ -17,7 +17,6 @@ pub fn set_routes(
) -> anyhow::Result<()> {
// Handler to connect to a wireguard peer
http_server.fn_handler("/connect-wg", Method::Post, {
let nvs_set = Arc::clone(&nvs);
let wifi_check = Arc::clone(&wifi);

// This is so fucking stupid but we can't do otherwise
Expand Down Expand Up @@ -47,14 +46,7 @@ pub fn set_routes(

let wg_conf: NvsWireguard = serde_urlencoded::from_str(String::from_utf8(body)?.as_str())?;

let mut nvs_set = nvs_set.lock().unwrap();

NvsWireguard::set_field(&mut nvs_set, NvsKeys::WG_ADDR, wg_conf.wg_addr.clean_string().as_str())?;
NvsWireguard::set_field(&mut nvs_set, NvsKeys::WG_PORT, wg_conf.wg_port.clean_string().as_str())?;
NvsWireguard::set_field(&mut nvs_set, NvsKeys::WG_CLI_PRI, wg_conf.wg_cli_pri.clean_string().as_str())?;
NvsWireguard::set_field(&mut nvs_set, NvsKeys::WG_SERV_PUB, wg_conf.wg_serv_pub.clean_string().as_str())?;

drop(nvs_set);
NvsWireguard::set_fields(Arc::clone(&nvs), wg_conf)?;

// Yeah..
let wifi = Arc::clone(&wifi);
Expand Down Expand Up @@ -107,11 +99,10 @@ pub fn set_routes(
None => "disconnected",
};

let nvs = nvs.lock().unwrap();
let nvs = NvsWireguard::new(&nvs)?;
let nvs = NvsWireguard::new(Arc::clone(&nvs))?;

let status = match *ctx {
Some(_) => nvs.wg_addr.as_str(),
Some(_) => nvs.address.as_str(),
None => "Disconnected",
};

Expand All @@ -128,15 +119,12 @@ pub fn set_routes(
.as_str(),
);

match *ctx {
Some(_) => {
html.push_str(
r###"
if (*ctx).is_some() {
html.push_str(
r###"
<button id="disconnect-wg-button" onclick="disconnectWg()">Disconnect</button>
"###,
);
}
None => {}
);
};

connection.write(html.as_bytes())?;
Expand Down
11 changes: 2 additions & 9 deletions src/http/wifi_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use esp_idf_svc::nvs::{EspNvs, NvsDefault};
use esp_idf_svc::wifi::{AuthMethod, EspWifi};

use crate::network::wifi;
use crate::utils::nvs::{NvsKeys, NvsWifi};
use crate::utils::nvs::NvsWifi;

pub fn set_routes(
http_server: &mut EspHttpServer<'static>,
Expand All @@ -31,7 +31,6 @@ pub fn set_routes(

// Handler to connect to wifi
http_server.fn_handler("/connect-wifi", Method::Post, {
let nvs_set = Arc::clone(&nvs);
let wifi = Arc::clone(&wifi);

move |mut request| {
Expand All @@ -48,13 +47,7 @@ pub fn set_routes(

let wifi_conf: NvsWifi = serde_urlencoded::from_str(String::from_utf8(body)?.as_str())?;

let mut nvs_set = nvs_set.lock().unwrap();

NvsWifi::set_field(&mut nvs_set, NvsKeys::STA_SSID, wifi_conf.sta_ssid.clean_string().as_str())?;
NvsWifi::set_field(&mut nvs_set, NvsKeys::STA_PASSWD, wifi_conf.sta_passwd.clean_string().as_str())?;
NvsWifi::set_field(&mut nvs_set, NvsKeys::STA_AUTH, wifi_conf.sta_auth.clean_string().as_str())?;

drop(nvs_set);
NvsWifi::set_fields(Arc::clone(&nvs), wifi_conf)?;

let nvs_thread = Arc::clone(&nvs);
let wifi = Arc::clone(&wifi);
Expand Down
14 changes: 7 additions & 7 deletions src/network/wifi.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::str::FromStr;
use std::sync::{Arc, Mutex};

use esp_idf_svc::eventloop::EspSystemEventLoop;
use esp_idf_svc::hal::modem::Modem;
use esp_idf_svc::netif::{EspNetif, NetifConfiguration, NetifStack};
use esp_idf_svc::nvs::{EspDefaultNvsPartition, EspNvs, NvsDefault};
use esp_idf_svc::wifi::{AuthMethod, ClientConfiguration, Configuration, EspWifi, WifiDriver};
use esp_idf_svc::wifi::{ClientConfiguration, Configuration, EspWifi, WifiDriver};

use crate::utils::nvs::{NvsKeys, NvsWifi};
use crate::utils::nvs::NvsWifi;

pub fn init_netif(
modem: Modem,
Expand Down Expand Up @@ -38,12 +37,13 @@ pub fn set_configuration(
log::info!("Setting wifi configuration...");

let mut wifi = wifi.lock().unwrap();
let nvs = nvs.lock().unwrap();

let nvs = NvsWifi::new(Arc::clone(&nvs))?;

let wifi_config = Configuration::Client(ClientConfiguration {
ssid: NvsWifi::get_field::<32>(&nvs, NvsKeys::STA_SSID)?,
password: NvsWifi::get_field::<64>(&nvs, NvsKeys::STA_PASSWD)?,
auth_method: AuthMethod::from_str(NvsWifi::get_field::<32>(&nvs, NvsKeys::STA_AUTH)?.as_str())?,
ssid: nvs.sta_ssid.0,
password: nvs.sta_passwd.0,
auth_method: nvs.sta_auth.as_str().try_into()?,
..Default::default()
});

Expand Down
7 changes: 2 additions & 5 deletions src/utils/heapless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ffi::CString;
use heapless::String;
use serde::Deserialize;

//This wrapper is necessary to juggle wifi stuff
#[derive(Deserialize, Default)]
pub struct HeaplessString<const N: usize>(pub String<N>);

Expand All @@ -11,10 +12,6 @@ impl<const N: usize> HeaplessString<N> {
Self(String::<N>::new())
}

pub fn inner(&self) -> String<N> {
self.0.clone()
}

pub fn push_str(&mut self, s: &str) -> anyhow::Result<()> {
if s.len() > N {
return Err(anyhow::anyhow!("String too long."));
Expand All @@ -29,7 +26,7 @@ impl<const N: usize> HeaplessString<N> {
self.0.as_str()
}

pub fn chars(&self) -> std::str::Chars<'_> {
fn chars(&self) -> std::str::Chars<'_> {
self.0.chars()
}

Expand Down
126 changes: 69 additions & 57 deletions src/utils/nvs.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,37 @@
#![allow(unused)]

use std::sync::MutexGuard;
use std::sync::{Arc, Mutex, MutexGuard};

use esp_idf_svc::nvs::{EspNvs, NvsDefault};
use heapless::String;
use serde::Deserialize;

use crate::utils::heapless::HeaplessString;

const DEFAULT_STA_SSID: &str = "";
const DEFAULT_STA_PASSWD: &str = "";
const DEFAULT_STA_AUTH: &str = "wpa2personal";

const DEFAULT_WG_ADDR: &str = "";
const DEFAULT_WG_PORT: &str = "51820";

const DEFAULT_WG_CLIENT_PRI: &str = "";
const DEFAULT_WG_SERVER_PUB: &str = "";

pub struct NvsKeys;

impl NvsKeys {
pub const STA_AUTH: &'static str = "AUTH";
pub const STA_PASSWD: &'static str = "PASSWD";
pub const STA_SSID: &'static str = "SSID";
pub const WG_ADDR: &'static str = "ADDR";
pub const WG_CLI_PRI: &'static str = "PRIVKEY";
pub const WG_PORT: &'static str = "PORT";
pub const WG_SERV_PUB: &'static str = "PUBKEY";
}

#[derive(Deserialize)]
pub struct NvsWireguard {
#[serde(rename = "address")]
pub wg_addr: HeaplessString<32>,
pub address: HeaplessString<32>,

#[serde(rename = "port")]
pub wg_port: HeaplessString<16>,
pub port: HeaplessString<16>,

#[serde(rename = "privkey")]
pub wg_cli_pri: HeaplessString<64>,
pub client_private_key: HeaplessString<64>,

#[serde(rename = "pubkey")]
pub wg_serv_pub: HeaplessString<64>,
pub server_public_key: HeaplessString<64>,
}

impl NvsWireguard {
pub fn get_field<const N: usize>(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>, key: &str) -> anyhow::Result<String<N>> {
const ADDR: &'static str = "ADDR";
const CLIENT_PRIV: &'static str = "PRIVKEY";
const DEFAULT_ADDR: &str = "";
const DEFAULT_CLIENT_PRIV: &str = "";
const DEFAULT_PORT: &str = "51820";
const DEFAULT_SERVER_PUB: &str = "";
const PORT: &'static str = "PORT";
const SERVER_PUB: &'static str = "PUBKEY";

fn get_field<const N: usize>(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>, key: &str) -> anyhow::Result<String<N>> {
let mut buf = [0u8; N];
nvs.get_str(key, &mut buf)?;

Expand All @@ -57,35 +42,45 @@ impl NvsWireguard {
let mut value = HeaplessString::<N>::new();
value.push_str(raw_value)?;

Ok(value.clean_string().inner())
Ok(value.clean_string().0)
}

pub fn set_field(nvs: &mut MutexGuard<'_, EspNvs<NvsDefault>>, key: &str, value: &str) -> anyhow::Result<()> {
nvs.set_str(key, value.trim())?;
/// Call to set the Wireguard configuration in nvs.
pub fn set_fields(nvs: Arc<Mutex<EspNvs<NvsDefault>>>, keys: NvsWireguard) -> anyhow::Result<()> {
let mut nvs = nvs.lock().unwrap();

nvs.set_str(Self::ADDR, keys.address.clean_string().as_str())?;
nvs.set_str(Self::PORT, keys.port.clean_string().as_str())?;
nvs.set_str(Self::CLIENT_PRIV, keys.client_private_key.clean_string().as_str())?;
nvs.set_str(Self::SERVER_PUB, keys.server_public_key.clean_string().as_str())?;

Ok(())
}

pub fn new(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>) -> anyhow::Result<Self> {
/// Call to get an instance of NvsWireguard containing the current stored
/// Wireguard configs.
pub fn new(nvs: Arc<Mutex<EspNvs<NvsDefault>>>) -> anyhow::Result<Self> {
let nvs = nvs.lock().unwrap();

Ok(Self {
wg_addr: HeaplessString(
NvsWireguard::get_field::<32>(nvs, NvsKeys::WG_ADDR)
.unwrap_or_else(|_| DEFAULT_WG_ADDR.try_into().unwrap()),
address: HeaplessString(
NvsWireguard::get_field::<32>(&nvs, Self::ADDR)
.unwrap_or_else(|_| Self::DEFAULT_ADDR.try_into().unwrap()),
),

wg_port: HeaplessString(
NvsWireguard::get_field::<16>(nvs, NvsKeys::WG_PORT)
.unwrap_or_else(|_| DEFAULT_WG_PORT.try_into().unwrap()),
port: HeaplessString(
NvsWireguard::get_field::<16>(&nvs, Self::PORT)
.unwrap_or_else(|_| Self::DEFAULT_PORT.try_into().unwrap()),
),

wg_cli_pri: HeaplessString(
NvsWireguard::get_field::<64>(nvs, NvsKeys::WG_CLI_PRI)
.unwrap_or_else(|_| DEFAULT_WG_CLIENT_PRI.try_into().unwrap()),
client_private_key: HeaplessString(
NvsWireguard::get_field::<64>(&nvs, Self::CLIENT_PRIV)
.unwrap_or_else(|_| Self::DEFAULT_CLIENT_PRIV.try_into().unwrap()),
),

wg_serv_pub: HeaplessString(
NvsWireguard::get_field::<64>(nvs, NvsKeys::WG_SERV_PUB)
.unwrap_or_else(|_| DEFAULT_WG_SERVER_PUB.try_into().unwrap()),
server_public_key: HeaplessString(
NvsWireguard::get_field::<64>(&nvs, Self::SERVER_PUB)
.unwrap_or_else(|_| Self::DEFAULT_SERVER_PUB.try_into().unwrap()),
),
})
}
Expand All @@ -104,7 +99,14 @@ pub struct NvsWifi {
}

impl NvsWifi {
pub fn get_field<const N: usize>(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>, key: &str) -> anyhow::Result<String<N>> {
const DEFAULT_STA_AUTH: &str = "wpa2personal";
const DEFAULT_STA_PASSWD: &str = "";
const DEFAULT_STA_SSID: &str = "";
const STA_AUTH: &'static str = "AUTH";
const STA_PASSWD: &'static str = "PASSWD";
const STA_SSID: &'static str = "SSID";

fn get_field<const N: usize>(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>, key: &str) -> anyhow::Result<String<N>> {
let mut buf = [0u8; N];
nvs.get_str(key, &mut buf)?;

Expand All @@ -115,30 +117,40 @@ impl NvsWifi {
let mut value = HeaplessString::<N>::new();
value.push_str(raw_value)?;

Ok(value.clean_string().inner())
Ok(value.clean_string().0)
}

pub fn set_field(nvs: &mut MutexGuard<'_, EspNvs<NvsDefault>>, key: &str, value: &str) -> anyhow::Result<()> {
nvs.set_str(key, value.trim())?;
/// Call to set the wifi configuration in nvs.
pub fn set_fields(nvs: Arc<Mutex<EspNvs<NvsDefault>>>, keys: NvsWifi) -> anyhow::Result<()> {
let mut nvs = nvs.lock().unwrap();

nvs.set_str(Self::STA_SSID, keys.sta_ssid.clean_string().as_str())?;
nvs.set_str(Self::STA_PASSWD, keys.sta_passwd.clean_string().as_str())?;
nvs.set_str(Self::STA_AUTH, keys.sta_auth.clean_string().as_str())?;

Ok(())
}

pub fn new(nvs: &MutexGuard<'_, EspNvs<NvsDefault>>) -> anyhow::Result<Self> {
/// Call to get an instance of NvsWifi containing the current stored wifi
/// configs.
pub fn new(nvs: Arc<Mutex<EspNvs<NvsDefault>>>) -> anyhow::Result<Self> {
let nvs = nvs.lock().unwrap();

// These cannot fail, so we don't care about the unwraps
Ok(Self {
sta_ssid: HeaplessString(
NvsWifi::get_field::<32>(nvs, NvsKeys::STA_SSID)
.unwrap_or_else(|_| DEFAULT_STA_SSID.try_into().unwrap()),
NvsWifi::get_field::<32>(&nvs, Self::STA_SSID)
.unwrap_or_else(|_| Self::DEFAULT_STA_SSID.try_into().unwrap()),
),

sta_passwd: HeaplessString(
NvsWifi::get_field::<64>(nvs, NvsKeys::STA_PASSWD)
.unwrap_or_else(|_| DEFAULT_STA_PASSWD.try_into().unwrap()),
NvsWifi::get_field::<64>(&nvs, Self::STA_PASSWD)
.unwrap_or_else(|_| Self::DEFAULT_STA_PASSWD.try_into().unwrap()),
),

sta_auth: HeaplessString(
NvsWifi::get_field::<32>(nvs, NvsKeys::STA_AUTH)
.unwrap_or_else(|_| DEFAULT_STA_AUTH.try_into().unwrap()),
NvsWifi::get_field::<32>(&nvs, Self::STA_AUTH)
.unwrap_or_else(|_| Self::DEFAULT_STA_AUTH.try_into().unwrap()),
),
})
}
Expand Down
1 change: 1 addition & 0 deletions src/wireguard/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ impl WireguardCtx {
}
}

// Global hot potato that needs to never ever be dropped
lazy_static::lazy_static!(
pub static ref WG_CTX: Arc<Mutex<Option<WireguardCtx>>> = Arc::new(Mutex::new(None));
);
Loading

0 comments on commit da876fe

Please sign in to comment.