Skip to content

Commit

Permalink
Refactor backend handlers and update pgwire dependency. (#258)
Browse files Browse the repository at this point in the history
Replaces custom handler implementations with a streamlined `PgWireServerHandlers` factory. Updates `pgwire` to version 0.28.0. This improves modularity and compatibility, simplifying query processing and error handling.
  • Loading branch information
loloxwg authored Dec 31, 2024
1 parent 468928b commit 85cf196
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ clap = { version = "4.5", features = ["derive"], optional = tru
env_logger = { version = "0.11", optional = true }
futures = { version = "0.3", optional = true }
log = { version = "0.4", optional = true }
pgwire = { version = "0.19", optional = true }
pgwire = { version = "0.28.0", optional = true }
tokio = { version = "1.36", features = ["full"], optional = true }


Expand Down
124 changes: 73 additions & 51 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,16 @@ use clap::Parser;
use fnck_sql::db::{DBTransaction, DataBaseBuilder, Database, ResultIter};
use fnck_sql::errors::DatabaseError;
use fnck_sql::storage::rocksdb::RocksStorage;
use fnck_sql::types::tuple::{Schema, Tuple};
use fnck_sql::types::tuple::{Schema, SchemaRef, Tuple};
use fnck_sql::types::LogicalType;
use futures::stream;
use log::{error, info, LevelFilter};
use parking_lot::Mutex;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::auth::StartupHandler;
use pgwire::api::query::{
ExtendedQueryHandler, PlaceholderExtendedQueryHandler, SimpleQueryHandler,
};
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
use pgwire::api::MakeHandler;
use pgwire::api::{ClientInfo, StatelessMakeHandler, Type};
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::tokio::process_socket;
use std::fmt::Debug;
Expand Down Expand Up @@ -83,29 +80,67 @@ pub struct FnckSQLBackend {
inner: Arc<Database<RocksStorage>>,
}

impl FnckSQLBackend {
pub fn new(path: impl Into<PathBuf> + Send) -> Result<FnckSQLBackend, DatabaseError> {
let database = DataBaseBuilder::path(path).build()?;

Ok(FnckSQLBackend {
inner: Arc::new(database),
})
}
}

pub struct SessionBackend {
inner: Arc<Database<RocksStorage>>,
tx: Mutex<Option<TransactionPtr>>,
}

impl MakeHandler for FnckSQLBackend {
type Handler = Arc<SessionBackend>;

fn make(&self) -> Self::Handler {
Arc::new(SessionBackend {
inner: Arc::clone(&self.inner),
impl SessionBackend {
pub fn new(inner: Arc<Database<RocksStorage>>) -> SessionBackend {
SessionBackend {
inner,
tx: Mutex::new(None),
})
}
}
}

impl FnckSQLBackend {
pub fn new(path: impl Into<PathBuf> + Send) -> Result<FnckSQLBackend, DatabaseError> {
let database = DataBaseBuilder::path(path).build()?;
impl NoopStartupHandler for SessionBackend {}

Ok(FnckSQLBackend {
inner: Arc::new(database),
})
struct CustomBackendFactory {
handler: Arc<SessionBackend>,
}

impl CustomBackendFactory {
pub fn new(handler: Arc<SessionBackend>) -> CustomBackendFactory {
CustomBackendFactory { handler }
}
}

impl PgWireServerHandlers for CustomBackendFactory {
type StartupHandler = SessionBackend;
type SimpleQueryHandler = SessionBackend;
type ExtendedQueryHandler = PlaceholderExtendedQueryHandler;
type CopyHandler = NoopCopyHandler;
type ErrorHandler = NoopErrorHandler;

fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.handler.clone()
}

fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
Arc::new(PlaceholderExtendedQueryHandler)
}

fn startup_handler(&self) -> Arc<Self::StartupHandler> {
self.handler.clone()
}

fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}

fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(NoopErrorHandler)
}
}

Expand Down Expand Up @@ -179,7 +214,10 @@ impl SimpleQueryHandler for SessionBackend {
for tuple in iter.by_ref() {
tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?);
}
encode_tuples(iter.schema(), tuples)?
let schema = iter.schema().clone();
iter.done()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
encode_tuples(&schema, tuples)?
} else {
let mut iter = self
.inner
Expand All @@ -188,15 +226,18 @@ impl SimpleQueryHandler for SessionBackend {
for tuple in iter.by_ref() {
tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?);
}
encode_tuples(iter.schema(), tuples)?
let schema = iter.schema().clone();
iter.done()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
encode_tuples(&schema, tuples)?
};
Ok(vec![Response::Query(response)])
}
}
}
}

fn encode_tuples<'a>(schema: &Schema, tuples: Vec<Tuple>) -> PgWireResult<QueryResponse<'a>> {
fn encode_tuples<'a>(schema: &SchemaRef, tuples: Vec<Tuple>) -> PgWireResult<QueryResponse<'a>> {
if tuples.is_empty() {
return Ok(QueryResponse::new(Arc::new(vec![]), stream::empty()));
}
Expand Down Expand Up @@ -268,7 +309,7 @@ fn into_pg_type(data_type: &LogicalType) -> PgWireResult<Type> {
LogicalType::Date | LogicalType::DateTime => Type::DATE,
LogicalType::Char(..) => Type::CHAR,
LogicalType::Time => Type::TIME,
LogicalType::Decimal(_, _) => Type::FLOAT8,
LogicalType::Decimal(_, _) => Type::NUMERIC,
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
Expand Down Expand Up @@ -318,17 +359,14 @@ async fn main() {
);

let backend = FnckSQLBackend::new(args.path).unwrap();
let processor = Arc::new(backend);
// We have not implemented extended query in this server, use placeholder instead
let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
PlaceholderExtendedQueryHandler,
)));
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
let factory = Arc::new(CustomBackendFactory::new(Arc::new(SessionBackend::new(
backend.inner,
))));
let server_addr = format!("{}:{}", args.ip, args.port);
let listener = TcpListener::bind(server_addr).await.unwrap();

tokio::select! {
res = server_run(processor, placeholder, authenticator, listener) => {
res = server_run(listener,factory) => {
if let Err(err) = res {
error!("[Listener][Failed To Accept]: {}", err);
}
Expand All @@ -337,32 +375,16 @@ async fn main() {
}
}

async fn server_run<
A: MakeHandler<Handler = Arc<impl StartupHandler + 'static>>,
Q: MakeHandler<Handler = Arc<impl SimpleQueryHandler + 'static>>,
EQ: MakeHandler<Handler = Arc<impl ExtendedQueryHandler + 'static>>,
>(
processor: Arc<Q>,
placeholder: Arc<EQ>,
authenticator: Arc<A>,
async fn server_run(
listener: TcpListener,
factory_ref: Arc<CustomBackendFactory>,
) -> io::Result<()> {
loop {
let incoming_socket = listener.accept().await?;
let authenticator_ref = authenticator.make();
let processor_ref = processor.make();
let placeholder_ref = placeholder.make();
let factory_ref = factory_ref.clone();

tokio::spawn(async move {
if let Err(err) = process_socket(
incoming_socket.0,
None,
authenticator_ref,
processor_ref,
placeholder_ref,
)
.await
{
if let Err(err) = process_socket(incoming_socket.0, None, factory_ref).await {
error!("Failed To Process: {}", err);
}
});
Expand Down

0 comments on commit 85cf196

Please sign in to comment.