Skip to content
This repository has been archived by the owner on Feb 11, 2024. It is now read-only.

Commit

Permalink
fix: remove authentication for get queries (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xavier Basty authored May 5, 2023
1 parent 4a63f3a commit 34c0b55
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 68 deletions.
6 changes: 0 additions & 6 deletions src/handlers/get_messages.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use {
crate::{
auth::AuthBearer,
error,
increment_counter,
increment_counter_with,
Expand All @@ -11,7 +10,6 @@ use {
extract::{Query, State},
Json,
},
relay_rpc::auth::Jwt,
serde::{Deserialize, Serialize},
std::{cmp, sync::Arc},
};
Expand Down Expand Up @@ -73,18 +71,15 @@ pub struct GetMessagesResponse {
/// The handler for the get messages endpoint.
pub async fn handler(
State(state): State<Arc<AppState>>,
AuthBearer(token): AuthBearer,
query: Query<GetMessagesBody>,
) -> Result<Json<GetMessagesResponse>, error::Error> {
let client_id = Jwt(token).decode(&state.auth_aud.clone())?;
let direction = query.direction.unwrap_or(Direction::Forward);

let StoreMessages { messages, next_id } = match (&query.origin_id, direction) {
(origin_id, Direction::Forward) => {
state
.messages_store
.get_messages_after(
client_id.value(),
query.topic.as_ref(),
origin_id.as_deref(),
query.message_count.limit(),
Expand All @@ -95,7 +90,6 @@ pub async fn handler(
state
.messages_store
.get_messages_before(
client_id.value(),
query.topic.as_ref(),
origin_id.as_deref(),
query.message_count.limit(),
Expand Down
3 changes: 1 addition & 2 deletions src/relay/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<S, B, T> FromRequest<S, B> for RequireValidSignature<T>
where
// these bounds are required by
// `async_trait`
B: Send + 'static + body::HttpBody + From<hyper::body::Bytes>,
B: Send + 'static + body::HttpBody + From<body::Bytes>,
B::Data: Send,
S: Send + Sync + State,
T: FromRequest<S, B>,
Expand Down Expand Up @@ -82,7 +82,6 @@ where
(Some(_), None) => Err(MissingTimestampHeader),
(None, Some(_)) => Err(MissingSignatureHeader),
(None, None) => Err(MissingAllSignatureHeader),
_ => Err(MissingAllSignatureHeader),
}
}
}
Expand Down
8 changes: 3 additions & 5 deletions src/store/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use {
super::StoreError,
async_trait::async_trait,
serde::{Deserialize, Serialize},
std::sync::Arc,
std::{fmt::Debug, sync::Arc},
wither::{
bson::{self, doc, oid::ObjectId},
Model,
Expand All @@ -14,7 +14,7 @@ use {
collection_name = "Messages",
index(keys = r#"doc!{"ts": 1}"#),
index(keys = r#"doc!{"ts": -1}"#),
index(keys = r#"doc!{"client_id": 1, "topic": 1}"#),
index(keys = r#"doc!{"topic": 1}"#),
index(
keys = r#"doc!{"client_id": 1, "topic": 1, "message_id": 1}"#,
options = r#"doc!{"unique": true}"#
Expand Down Expand Up @@ -46,7 +46,7 @@ pub struct StoreMessages {
}

#[async_trait]
pub trait MessagesStore: 'static + std::fmt::Debug + Send + Sync {
pub trait MessagesStore: 'static + Send + Sync {
async fn upsert_message(
&self,
method: &str,
Expand All @@ -57,14 +57,12 @@ pub trait MessagesStore: 'static + std::fmt::Debug + Send + Sync {
) -> Result<(), StoreError>;
async fn get_messages_after(
&self,
client_id: &str,
topic: &str,
origin: Option<&str>,
message_count: usize,
) -> Result<StoreMessages, StoreError>;
async fn get_messages_before(
&self,
client_id: &str,
topic: &str,
origin: Option<&str>,
message_count: usize,
Expand Down
11 changes: 3 additions & 8 deletions src/store/mongo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use {
},
};

#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct MongoStore {
db: Database,
}
Expand Down Expand Up @@ -64,7 +64,6 @@ impl MongoStore {

async fn get_messages(
&self,
client_id: &str,
topic: &str,
origin: Option<&str>,
message_count: usize,
Expand All @@ -73,13 +72,11 @@ impl MongoStore {
) -> Result<StoreMessages, StoreError> {
let filter: Result<Document, StoreError> = match origin {
None => Ok(doc! {
"client_id": &client_id,
"topic": &topic,
}),
Some(origin) => {
let ts = self.get_message_timestamp(topic, origin).await?;
Ok(doc! {
"client_id": &client_id,
"topic": &topic,
"ts": { comparator: ts }
})
Expand Down Expand Up @@ -147,23 +144,21 @@ impl MessagesStore for MongoStore {

async fn get_messages_after(
&self,
client_id: &str,
topic: &str,
origin: Option<&str>,
message_count: usize,
) -> Result<StoreMessages, StoreError> {
self.get_messages(client_id, topic, origin, message_count, "$gte", 1)
self.get_messages(topic, origin, message_count, "$gte", 1)
.await
}

async fn get_messages_before(
&self,
client_id: &str,
topic: &str,
origin: Option<&str>,
message_count: usize,
) -> Result<StoreMessages, StoreError> {
self.get_messages(client_id, topic, origin, message_count, "$lte", -1)
self.get_messages(topic, origin, message_count, "$lte", -1)
.await
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/store/registrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct Registration {
}

#[async_trait]
pub trait RegistrationStore: 'static + std::fmt::Debug + Send + Sync {
pub trait RegistrationStore: 'static + Send + Sync {
async fn upsert_registration(
&self,
client_id: &str,
Expand Down
38 changes: 8 additions & 30 deletions tests/storage/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn test_after_no_origin(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_after(TEST_CLIENT_ID, topic, None, TEST_QUERY_SIZE)
.get_messages_after(topic, None, TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -58,12 +58,7 @@ async fn test_after_origin(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_after(
TEST_CLIENT_ID,
topic,
Some(&origin.to_string()),
TEST_QUERY_SIZE,
)
.get_messages_after(topic, Some(&origin.to_string()), TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -104,12 +99,7 @@ async fn test_after_origin_overflow(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_after(
TEST_CLIENT_ID,
topic,
Some(&origin.to_string()),
TEST_QUERY_SIZE,
)
.get_messages_after(topic, Some(&origin.to_string()), TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -140,7 +130,7 @@ async fn test_before_no_origin(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_before(TEST_CLIENT_ID, topic, None, TEST_QUERY_SIZE)
.get_messages_before(topic, None, TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -177,12 +167,7 @@ async fn test_before_origin(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_before(
TEST_CLIENT_ID,
topic,
Some(&origin.to_string()),
TEST_QUERY_SIZE,
)
.get_messages_before(topic, Some(&origin.to_string()), TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -223,12 +208,7 @@ async fn test_before_origin_overflow(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_before(
TEST_CLIENT_ID,
topic,
Some(&origin.to_string()),
TEST_QUERY_SIZE,
)
.get_messages_before(topic, Some(&origin.to_string()), TEST_QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -267,7 +247,7 @@ async fn test_multi_topic(ctx: &StoreContext) {
let result = ctx
.storage
.store
.get_messages_after(TEST_CLIENT_ID, topic.as_str(), None, QUERY_SIZE)
.get_messages_after(topic.as_str(), None, QUERY_SIZE)
.await
.unwrap();

Expand Down Expand Up @@ -309,12 +289,10 @@ async fn test_multi_clients(ctx: &StoreContext) {
}

for t in 0..NUM_CLIENTS {
let client_id = format!("{}-{}", TEST_CLIENT_ID, t + 1);

let result = ctx
.storage
.store
.get_messages_after(client_id.as_str(), topic, None, QUERY_SIZE)
.get_messages_after(topic, None, QUERY_SIZE)
.await
.unwrap();

Expand Down
16 changes: 0 additions & 16 deletions tests/storage/mocks/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,10 @@ impl MessagesStore for MockMessageStore {

async fn get_messages_after(
&self,
client_id: &str,
_topic: &str,
_origin: Option<&str>,
_message_count: usize,
) -> Result<StoreMessages, StoreError> {
if self.client_id.is_some() && self.client_id != Some(client_id.to_string()) {
return Err(StoreError::NotFound(
"messages".to_string(),
client_id.to_string(),
));
}

Ok(StoreMessages {
messages: self.test_get_messages(),
next_id: Some(Arc::from("after")),
Expand All @@ -97,18 +89,10 @@ impl MessagesStore for MockMessageStore {

async fn get_messages_before(
&self,
client_id: &str,
_topic: &str,
_origin: Option<&str>,
_message_count: usize,
) -> Result<StoreMessages, StoreError> {
if self.client_id.is_some() && self.client_id != Some(client_id.to_string()) {
return Err(StoreError::NotFound(
"messages".to_string(),
client_id.to_string(),
));
}

Ok(StoreMessages {
messages: self.test_get_messages(),
next_id: Some(Arc::from("before")),
Expand Down

0 comments on commit 34c0b55

Please sign in to comment.