Skip to content

Commit

Permalink
Add source search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikWin committed Jan 14, 2025
1 parent b65913d commit c05a6b8
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 3 deletions.
3 changes: 3 additions & 0 deletions vidformer-igni/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ async fn igni_http_req_api(
let source_id = r.unwrap().captures(&uri).unwrap().get(1).unwrap().as_str();
api::get_source(req, global, source_id, &user_auth).await
}
(hyper::Method::POST, "/v2/source/search") => {
api::search_source(req, global, &user_auth).await
}
(hyper::Method::POST, "/v2/source") // /v2/source
=> {
api::push_source(req, global, &user_auth).await
Expand Down
53 changes: 53 additions & 0 deletions vidformer-igni/src/server/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,59 @@ pub(crate) async fn get_source(
)))?)
}

pub(crate) async fn search_source(
req: hyper::Request<impl hyper::body::Body>,
global: std::sync::Arc<IgniServerGlobal>,
user: &super::UserAuth,
) -> Result<hyper::Response<http_body_util::Full<hyper::body::Bytes>>, IgniError> {
#[derive(serde::Deserialize)]
struct Request {
name: String,
stream_idx: i32,
storage_service: String,
storage_config: serde_json::Value,
}

let req: Request = match req.collect().await {
Err(_err) => {
error!("Error reading request body");
return Ok(hyper::Response::builder()
.status(hyper::StatusCode::BAD_REQUEST)
.body(http_body_util::Full::new(hyper::body::Bytes::from(
"Error reading request body",
)))?);
}
Ok(req) => match serde_json::from_slice(&req.to_bytes().to_vec()) {
Err(err) => {
error!("Error parsing request body");
return Ok(hyper::Response::builder()
.status(hyper::StatusCode::BAD_REQUEST)
.body(http_body_util::Full::new(hyper::body::Bytes::from(
format!("Bad request: {}", err),
)))?);
}
Ok(req) => req,
},
};

let rows: Vec<(Uuid,)> = sqlx::query_as("SELECT id FROM source WHERE name = $1 AND stream_idx = $2 AND storage_service = $3 AND storage_config = $4 AND user_id = $5")
.bind(req.name)
.bind(req.stream_idx)
.bind(req.storage_service)
.bind(req.storage_config)
.bind(user.user_id)
.fetch_all(&global.pool)
.await?;

let res: Vec<String> = rows.iter().map(|(id,)| id.to_string()).collect();

Ok(hyper::Response::builder()
.header("Content-Type", "application/json")
.body(http_body_util::Full::new(hyper::body::Bytes::from(
serde_json::to_string(&res).unwrap(),
)))?)
}

pub(crate) async fn delete_source(
_req: hyper::Request<impl hyper::body::Body>,
global: std::sync::Arc<IgniServerGlobal>,
Expand Down
10 changes: 8 additions & 2 deletions vidformer-py/vidformer/cv2/vf_cv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,14 @@ def _explicit_terminate(self):

def write(self, frame):
frame = frameify(frame, "frame")
assert frame._fmt["width"] == self._spec._fmt["width"]
assert frame._fmt["height"] == self._spec._fmt["height"]
if frame._fmt["width"] != self._spec._fmt["width"]:
raise Exception(
f"Frame type error; expected width {self._spec._fmt['width']}, got {frame._fmt['width']}"
)
if frame._fmt["height"] != self._spec._fmt["height"]:
raise Exception(
f"Frame type error; expected height {self._spec._fmt['height']}, got {frame._fmt['height']}"
)
if frame._fmt["pix_fmt"] != self._spec._fmt["pix_fmt"]:
f_obj = _filter_scale(frame._f, pix_fmt=self._spec._fmt["pix_fmt"])
frame = Frame(f_obj, self._spec._fmt)
Expand Down
39 changes: 39 additions & 0 deletions vidformer-py/vidformer/igni/vf_igni.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ def delete_source(self, id: str):
response = response.json()
assert response["status"] == "ok"

def search_source(self, name, stream_idx, storage_service, storage_config):
assert type(name) == str
assert type(stream_idx) == int
assert type(storage_service) == str
assert type(storage_config) == dict
for k, v in storage_config.items():
assert type(k) == str
assert type(v) == str
req = {
"name": name,
"stream_idx": stream_idx,
"storage_service": storage_service,
"storage_config": storage_config,
}
response = requests.post(
f"{self._endpoint}/source/search",
json=req,
headers={"Authorization": f"Bearer {self._api_key}"},
)
if not response.ok:
raise Exception(response.text)
response = response.json()
return response

def create_source(self, name, stream_idx, storage_service, storage_config):
assert type(name) == str
assert type(stream_idx) == int
Expand All @@ -81,6 +105,18 @@ def create_source(self, name, stream_idx, storage_service, storage_config):
id = response["id"]
return self.get_source(id)

def source(self, name, stream_idx, storage_service, storage_config):
"""Convenience function for accessing sources.
Tries to find a source with the given name, stream_idx, storage_service, and storage_config.
If no source is found, creates a new source with the given parameters.
"""

sources = self.search_source(name, stream_idx, storage_service, storage_config)
if len(sources) == 0:
return self.create_source(name, stream_idx, storage_service, storage_config)
return self.get_source(sources[0])

def get_spec(self, id: str):
assert type(id) == str
response = requests.get(
Expand Down Expand Up @@ -216,6 +252,9 @@ def __getitem__(self, idx):
raise Exception("Source index must be a Fraction")
return vf.SourceExpr(self, idx, False)

def __repr__(self):
return f"IgniSource({self._name})"


class IgniSpec:
def __init__(self, id, src):
Expand Down
13 changes: 13 additions & 0 deletions viper-den/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ def test_list_sources():
assert source_id in resp


def test_search_source():
source_id = _create_tos_source()
response = requests.post(
ENDPOINT + "v2/source/search", headers=AUTH_HEADERS, json=TOS_SOURCE
)
response.raise_for_status()
resp = response.json()
assert type(resp) == list
for sid in resp:
assert type(sid) == str
assert source_id in resp


def test_delete_source():
source_id = _create_tos_source()
response = requests.delete(
Expand Down
40 changes: 39 additions & 1 deletion viper-den/test_python_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_connect():
server = igni.IgniServer(ENDPOINT, API_KEY)


def test_source():
def test_create_source():
server = igni.IgniServer(ENDPOINT, API_KEY)
tos = server.create_source("../tos_720p.mp4", 0, "fs", {"root": "."})
assert isinstance(tos, igni.IgniSource)
Expand All @@ -21,6 +21,34 @@ def test_source():
assert isinstance(t, Fraction)


def test_source():
server = igni.IgniServer(ENDPOINT, API_KEY)

# delete all specs first (since they depend on sources)
specs = server.list_specs()
for spec in specs:
server.delete_spec(spec)

# delete all sources first
sources = server.list_sources()
for source in sources:
server.delete_source(source)

# Get a source which doesn't exist
tos = server.source("../tos_720p.mp4", 0, "fs", {"root": "."})
assert isinstance(tos, igni.IgniSource)

# Get a source which already exists
tos2 = server.source("../tos_720p.mp4", 0, "fs", {"root": "."})
assert isinstance(tos2, igni.IgniSource)

assert tos.id() == tos2.id()

# check only one source exists
sources = server.list_sources()
assert len(sources) == 1


def test_list_sources():
server = igni.IgniServer(ENDPOINT, API_KEY)
tos = server.create_source("../tos_720p.mp4", 0, "fs", {"root": "."})
Expand All @@ -30,6 +58,16 @@ def test_list_sources():
assert tos.id() in sources


def test_search_source():
server = igni.IgniServer(ENDPOINT, API_KEY)
tos = server.create_source("../tos_720p.mp4", 0, "fs", {"root": "."})
matching_sources = server.search_source("../tos_720p.mp4", 0, "fs", {"root": "."})
assert type(matching_sources) == list
for source in matching_sources:
assert isinstance(source, str)
assert tos.id() in matching_sources


def test_delete_source():
server = igni.IgniServer(ENDPOINT, API_KEY)
tos = server.create_source("../tos_720p.mp4", 0, "fs", {"root": "."})
Expand Down

0 comments on commit c05a6b8

Please sign in to comment.