From 4a5a7ac83959f12c7caa951aa0711d96aecfa792 Mon Sep 17 00:00:00 2001 From: Firepup650 <70233190+Firepup6500@users.noreply.github.com> Date: Tue, 27 Feb 2024 18:06:57 -0600 Subject: [PATCH] [FIX] Prevent keys from starting with slashes (#172) * [FIX] Prevent keys from starting with slashes If a key starts with a slash, then it becomes undeletable and prevents database purges from working properly as well. This prevents that from occuring by stripping slashes from the left of the key name. * Double newlines for flake8 * flake8 wanted another newline here * Force `set_bulk_raw` to handle keys with slashes as well * Add tests for keys starting with a slash * Fix a typo I made twice * flake8 * `del self.db[k]` not `self.db.delete(k)` in non-Async * One space for flake8 * These were also wrong * These shouldn't be using `get` * Match format of some of the other tests in TestDatabase * Perhaps the key is corrupted? * Have to `get_raw` for `_raw` calls. * Reassociate _dumps with def dumps * Only call keyStrip at the root of the .set function hierarchy * Clarify that keyStrip is an internal method --------- Co-authored-by: Devon Stewart --- src/replit/database/database.py | 14 +++++++++ tests/test_database.py | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/src/replit/database/database.py b/src/replit/database/database.py index 025e676..41cd1b6 100644 --- a/src/replit/database/database.py +++ b/src/replit/database/database.py @@ -61,6 +61,18 @@ def dumps(val: Any) -> str: _dumps = dumps +def _sanitize_key(key: str) -> str: + """Strip slashes from the beginning of keys. + + Args: + key (str): The key to strip + + Returns: + str: The stripped key + """ + return key.lstrip("/") + + class AsyncDatabase: """Async interface for Replit Database. @@ -195,6 +207,7 @@ async def set_bulk_raw(self, values: Dict[str, str]) -> None: Args: values (Dict[str, str]): The key-value pairs to set. """ + values = {_sanitize_key(k): v for k, v in values.items()} async with self.client.post(self.db_url, data=values) as response: response.raise_for_status() @@ -629,6 +642,7 @@ def set_bulk_raw(self, values: Dict[str, str]) -> None: Args: values (Dict[str, str]): The key-value pairs to set. """ + values = {_sanitize_key(k): v for k, v in values.items()} r = self.sess.post(self.db_url, data=values) r.raise_for_status() diff --git a/tests/test_database.py b/tests/test_database.py index 8c5dce0..0f57f15 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -125,6 +125,33 @@ async def test_bulk_raw(self) -> None: self.assertEqual(await self.db.get_raw("bulk1"), "val1") self.assertEqual(await self.db.get_raw("bulk2"), "val2") + async def test_slash_keys(self) -> None: + """Test that slash keys work.""" + k = "/key" + # set + await self.db.set(k,"val1") + self.assertEqual(await self.db.get(k), "val1") + await self.db.delete(k) + with self.assertRaises(KeyError): + await self.db.get(k) + # set_raw + await self.db.set_raw(k,"val1") + self.assertEqual(await self.db.get_raw(k), "val1") + await self.db.delete(k) + with self.assertRaises(KeyError): + await self.db.get(k) + # set_bulk + await self.db.set_bulk({k: "val1"}) + self.assertEqual(await self.db.get(k), "val1") + await self.db.delete(k) + with self.assertRaises(KeyError): + await self.db.get(k) + # set_bulk_raw + await self.db.set_bulk_raw({k: "val1"}) + self.assertEqual(await self.db.get_raw(k), "val1") + await self.db.delete(k) + with self.assertRaises(KeyError): + await self.db.get(k) class TestDatabase(unittest.TestCase): """Tests for replit.database.Database.""" @@ -259,3 +286,31 @@ def test_bulk_raw(self) -> None: self.db.set_bulk_raw({"bulk1": "val1", "bulk2": "val2"}) self.assertEqual(self.db.get_raw("bulk1"), "val1") self.assertEqual(self.db.get_raw("bulk2"), "val2") + + def test_slash_keys(self) -> None: + """Test that slash keys work.""" + k = "/key" + # set + self.db.set(k,"val1") + self.assertEqual(self.db[k], "val1") + del self.db[k] + with self.assertRaises(KeyError): + self.db[k] + # set_raw + self.db.set_raw(k,"val1") + self.assertEqual(self.db.get_raw(k), "val1") + del self.db[k] + with self.assertRaises(KeyError): + self.db[k] + # set_bulk + self.db.set_bulk({k: "val1"}) + self.assertEqual(self.db.get(k), "val1") + del self.db[k] + with self.assertRaises(KeyError): + self.db[k] + # set_bulk_raw + self.db.set_bulk_raw({k: "val1"}) + self.assertEqual(self.db.get_raw(k), "val1") + del self.db[k] + with self.assertRaises(KeyError): + self.db[k]