
separate KVId into KeyId and ValueId (#19180)
### Purpose:

Using a single `KVId` type for both key identifiers and value identifiers meant the type checker could not catch a key id being passed where a value id was expected, or vice versa. Splitting the type removes that class of mistakes without changing what is stored.

### Current Behavior:

Key and value identifiers share the single `KVId` type throughout the data layer, so they are freely interchangeable as far as mypy is concerned.

### New Behavior:

`KVId` is replaced by two distinct types, `KeyId` and `ValueId`, over a shared raw `KeyOrValueId`. Code that genuinely handles either kind of id (such as the shared `ids` table lookups in `data_store.py`) stays on `KeyOrValueId`, and callers wrap the result in the specific type at the point where they know which kind it is.

### Testing Notes:

The existing data layer tests in `test_merkle_blob.py` are updated to construct and assert against the new `KeyId`/`ValueId` types, covering the insert, delete, lineage, and proof-of-inclusion paths.
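The new identifier types come from `chia/data_layer/util/merkle_blob.py` (presumably the third changed file; its hunks are not shown on this page). A minimal sketch of the split, assuming `typing.NewType`-style wrappers consistent with the `KeyId(KeyOrValueId(...))` call sites in this diff:

```python
from typing import NewType

# Assumed declarations, consistent with the call sites in this PR;
# the real definitions may differ in detail.
KeyOrValueId = NewType("KeyOrValueId", int)  # raw id as stored in the shared ids table
KeyId = NewType("KeyId", KeyOrValueId)       # an id known to name a key
ValueId = NewType("ValueId", KeyOrValueId)   # an id known to name a value

raw = KeyOrValueId(0x2425262728292A2B)
kid = KeyId(raw)    # identity at runtime, but distinct to the type checker
vid = ValueId(raw)  # mypy rejects a ValueId where a KeyId is expected, and vice versa
```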
altendky authored Jan 25, 2025
2 parents bb408a8 + cfd08e2 commit 137bd80
Showing 3 changed files with 83 additions and 72 deletions.
48 changes: 26 additions & 22 deletions chia/_tests/core/data_layer/test_merkle_blob.py
@@ -16,14 +16,16 @@
from chia.data_layer.data_layer_util import InternalNode, Side, internal_hash
from chia.data_layer.util.merkle_blob import (
InvalidIndexError,
-KVId,
+KeyId,
+KeyOrValueId,
MerkleBlob,
NodeMetadata,
NodeType,
RawInternalMerkleNode,
RawLeafMerkleNode,
RawMerkleNodeProtocol,
TreeIndex,
+ValueId,
data_size,
metadata_size,
null_parent,
@@ -135,8 +137,8 @@ def id(self) -> str:
raw=RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
-key=KVId(0x2425262728292A2B),
-value=KVId(0x2C2D2E2F30313233),
+key=KeyId(KeyOrValueId(0x2425262728292A2B)),
+value=ValueId(KeyOrValueId(0x2C2D2E2F30313233)),
),
),
]
@@ -169,8 +171,8 @@ def test_merkle_blob_one_leaf_loads() -> None:
leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
-key=KVId(0x0405060708090A0B),
-value=KVId(0x0405060708090A1B),
+key=KeyId(KeyOrValueId(0x0405060708090A0B)),
+value=ValueId(KeyOrValueId(0x0405060708090A1B)),
)
blob = bytearray(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(leaf))

@@ -190,14 +192,14 @@ def test_merkle_blob_two_leafs_loads() -> None:
left_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
-key=KVId(0x0405060708090A0B),
-value=KVId(0x0405060708090A1B),
+key=KeyId(KeyOrValueId(0x0405060708090A0B)),
+value=ValueId(KeyOrValueId(0x0405060708090A1B)),
)
right_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
-key=KVId(0x1415161718191A1B),
-value=KVId(0x1415161718191A2B),
+key=KeyId(KeyOrValueId(0x1415161718191A1B)),
+value=ValueId(KeyOrValueId(0x1415161718191A2B)),
)
blob = bytearray()
blob.extend(NodeMetadata(type=NodeType.internal, dirty=True).pack() + pack_raw_node(root))
@@ -218,20 +220,20 @@ def test_merkle_blob_two_leafs_loads() -> None:
son_hash = bytes32(range(32))
root_hash = internal_hash(son_hash, son_hash)
expected_node = InternalNode(root_hash, son_hash, son_hash)
-assert merkle_blob.get_lineage_by_key_id(KVId(0x0405060708090A0B)) == [expected_node]
-assert merkle_blob.get_lineage_by_key_id(KVId(0x1415161718191A1B)) == [expected_node]
+assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x0405060708090A0B))) == [expected_node]
+assert merkle_blob.get_lineage_by_key_id(KeyId(KeyOrValueId(0x1415161718191A1B))) == [expected_node]


-def generate_kvid(seed: int) -> tuple[KVId, KVId]:
-kv_ids: list[KVId] = []
+def generate_kvid(seed: int) -> tuple[KeyId, ValueId]:
+kv_ids: list[KeyOrValueId] = []

for offset in range(2):
seed_bytes = (2 * seed + offset).to_bytes(8, byteorder="big", signed=True)
hash_obj = hashlib.sha256(seed_bytes)
hash_int = int.from_bytes(hash_obj.digest()[:8], byteorder="big", signed=True)
-kv_ids.append(KVId(hash_int))
+kv_ids.append(KeyOrValueId(hash_int))

-return kv_ids[0], kv_ids[1]
+return KeyId(kv_ids[0]), ValueId(kv_ids[1])


def generate_hash(seed: int) -> bytes32:
@@ -245,7 +247,7 @@ def test_insert_delete_loads_all_keys() -> None:
num_keys = 200000
extra_keys = 100000
max_height = 25
-keys_values: dict[KVId, KVId] = {}
+keys_values: dict[KeyId, ValueId] = {}

random = Random()
random.seed(100, version=2)
@@ -304,7 +306,7 @@ def test_small_insert_deletes() -> None:

for repeats in range(num_repeats):
for num_inserts in range(1, max_inserts):
-keys_values: dict[KVId, KVId] = {}
+keys_values: dict[KeyId, ValueId] = {}
for inserts in range(num_inserts):
seed += 1
key, value = generate_kvid(seed)
@@ -330,13 +332,13 @@ def test_proof_of_inclusion_merkle_blob() -> None:
random.seed(100, version=2)

merkle_blob = MerkleBlob(blob=bytearray())
-keys_values: dict[KVId, KVId] = {}
+keys_values: dict[KeyId, ValueId] = {}

for repeats in range(num_repeats):
num_inserts = 1 + repeats * 100
num_deletes = 1 + repeats * 10

-kv_ids: list[tuple[KVId, KVId]] = []
+kv_ids: list[tuple[KeyId, ValueId]] = []
hashes: list[bytes32] = []
for _ in range(num_inserts):
seed += 1
@@ -363,7 +365,7 @@ def test_proof_of_inclusion_merkle_blob() -> None:
with pytest.raises(Exception, match=f"Key {kv_id} not present in the store"):
merkle_blob.get_proof_of_inclusion(kv_id)

-new_keys_values: dict[KVId, KVId] = {}
+new_keys_values: dict[KeyId, ValueId] = {}
for old_kv in keys_values.keys():
seed += 1
_, value = generate_kvid(seed)
@@ -382,7 +384,9 @@ def test_proof_of_inclusion_merkle_blob() -> None:
@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(-1), TreeIndex(1), TreeIndex(null_parent)])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
-merkle_blob.insert(KVId(0x1415161718191A1B), KVId(0x1415161718191A1B), bytes(range(12, data_size)))
+merkle_blob.insert(
+    KeyId(KeyOrValueId(0x1415161718191A1B)), ValueId(KeyOrValueId(0x1415161718191A1B)), bytes(range(12, data_size))
+)

with pytest.raises(InvalidIndexError):
merkle_blob.get_raw_node(index)
@@ -497,6 +501,6 @@ def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
total_time = 0.0
for i in range(100000):
start = time.monotonic()
-merkle_blob.insert(KVId(i), KVId(i), HASH)
+merkle_blob.insert(KeyId(KeyOrValueId(i)), ValueId(KeyOrValueId(i)), HASH)
end = time.monotonic()
total_time += end - start
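With the helper retyped, swapping key and value arguments becomes a static error instead of type-checking silently. A hypothetical illustration (not part of the diff; `merkle_blob`, `seed`, and `HASH` as in the tests above):

```python
key, value = generate_kvid(seed)      # now returns (KeyId, ValueId), not (KVId, KVId)
merkle_blob.insert(key, value, HASH)  # accepted
merkle_blob.insert(value, key, HASH)  # rejected by mypy: ValueId is not a KeyId
```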
52 changes: 28 additions & 24 deletions chia/data_layer/data_store.py
@@ -42,11 +42,13 @@
unspecified,
)
from chia.data_layer.util.merkle_blob import (
-KVId,
+KeyId,
+KeyOrValueId,
MerkleBlob,
RawInternalMerkleNode,
RawLeafMerkleNode,
TreeIndex,
+ValueId,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.batches import to_batches
@@ -191,7 +193,7 @@ async def insert_into_data_store_from_file(
filename: Path,
) -> None:
internal_nodes: dict[bytes32, tuple[bytes32, bytes32]] = {}
-terminal_nodes: dict[bytes32, tuple[KVId, KVId]] = {}
+terminal_nodes: dict[bytes32, tuple[KeyId, ValueId]] = {}

with open(filename, "rb") as reader:
while True:
@@ -409,7 +411,7 @@ async def insert_root_from_merkle_blob(

return await self._insert_root(store_id, root_hash, status)

-async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]:
+async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KeyOrValueId]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT kv_id FROM ids WHERE blob = ? AND store_id = ?",
@@ -423,9 +425,9 @@ async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]:
if row is None:
return None

-return KVId(row[0])
+return KeyOrValueId(row[0])

-async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[bytes]:
+async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Optional[bytes]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT blob FROM ids WHERE kv_id = ? AND store_id = ?",
@@ -441,15 +443,15 @@ async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[b

return bytes(row[0])

-async def get_terminal_node(self, kid: KVId, vid: KVId, store_id: bytes32) -> TerminalNode:
+async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -> TerminalNode:
key = await self.get_blob_from_kvid(kid, store_id)
value = await self.get_blob_from_kvid(vid, store_id)
if key is None or value is None:
raise Exception("Cannot find the key/value pair")

return TerminalNode(hash=leaf_hash(key, value), key=key, value=value)

-async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId:
+async def add_kvid(self, blob: bytes, store_id: bytes32) -> KeyOrValueId:
kv_id = await self.get_kvid(blob, store_id)
if kv_id is not None:
return kv_id
@@ -468,9 +470,9 @@ async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId:
raise Exception("Internal error")
return kv_id

-async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KVId, KVId]:
-kid = await self.add_kvid(key, store_id)
-vid = await self.add_kvid(value, store_id)
+async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KeyId, ValueId]:
+kid = KeyId(await self.add_kvid(key, store_id))
+vid = ValueId(await self.add_kvid(value, store_id))
hash = leaf_hash(key, value)
async with self.db_wrapper.writer() as writer:
await writer.execute(
@@ -484,7 +486,7 @@ async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tu
)
return (kid, vid)

-async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId, KVId]:
+async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KeyId, ValueId]:
async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
"SELECT * FROM hashes WHERE hash = ? AND store_id = ?",
@@ -499,8 +501,8 @@ async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId
if row is None:
raise Exception(f"Cannot find node by hash {hash.hex()}")

-kid = KVId(row["kid"])
-vid = KVId(row["vid"])
+kid = KeyId(row["kid"])
+vid = ValueId(row["vid"])
return (kid, vid)

async def get_terminal_node_by_hash(self, node_hash: bytes32, store_id: bytes32) -> TerminalNode:
@@ -1057,19 +1059,19 @@ async def get_keys(
for kid in kv_ids.keys():
key = await self.get_blob_from_kvid(kid, store_id)
if key is None:
raise Exception(f"Unknown key corresponding to KVId: {kid}")
raise Exception(f"Unknown key corresponding to KeyId: {kid}")
keys.append(key)

return keys

-def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KVId, Side]:
+def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KeyId, Side]:
side_seed = bytes(seed)[0]
side = Side.LEFT if side_seed < 128 else Side.RIGHT
reference_node = merkle_blob.get_random_leaf_node(seed)
kid = reference_node.key
return (kid, side)

-async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KVId, store_id: bytes32) -> TerminalNode:
+async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KeyId, store_id: bytes32) -> TerminalNode:
index = merkle_blob.key_to_index[kid]
raw_node = merkle_blob.get_raw_node(index)
assert isinstance(raw_node, RawLeafMerkleNode)
@@ -1139,7 +1141,7 @@ async def delete(

kid = await self.get_kvid(key, store_id)
if kid is not None:
-merkle_blob.delete(kid)
+merkle_blob.delete(KeyId(kid))

new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status)

@@ -1199,7 +1201,7 @@ async def insert_batch(
first_action[hash] = change["action"]
last_action[hash] = change["action"]

-batch_keys_values: list[tuple[KVId, KVId]] = []
+batch_keys_values: list[tuple[KeyId, ValueId]] = []
batch_hashes: list[bytes32] = []

for change in changelist:
@@ -1209,7 +1211,7 @@

reference_node_hash = change.get("reference_node_hash", None)
side = change.get("side", None)
-reference_kid: Optional[KVId] = None
+reference_kid: Optional[KeyId] = None
if reference_node_hash is not None:
reference_kid, _ = await self.get_node_by_hash(reference_node_hash, store_id)

@@ -1236,7 +1238,7 @@
key = change["key"]
deletion_kid = await self.get_kvid(key, store_id)
if deletion_kid is not None:
-merkle_blob.delete(deletion_kid)
+merkle_blob.delete(KeyId(deletion_kid))
elif change["action"] == "upsert":
key = change["key"]
new_value = change["value"]
@@ -1324,9 +1326,10 @@ async def get_node_by_key(
except MerkleBlobNotFoundError:
raise KeyNotFoundError(key=key)

-kid = await self.get_kvid(key, store_id)
-if kid is None:
+kvid = await self.get_kvid(key, store_id)
+if kvid is None:
raise KeyNotFoundError(key=key)
+kid = KeyId(kvid)
if not merkle_blob.key_exists(kid):
raise KeyNotFoundError(key=key)
return await self.get_terminal_node_from_kid(merkle_blob, kid, store_id)
@@ -1389,9 +1392,10 @@
) -> ProofOfInclusion:
root = await self.get_tree_root(store_id=store_id)
merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash)
-kid = await self.get_kvid(key, store_id)
-if kid is None:
+kvid = await self.get_kvid(key, store_id)
+if kvid is None:
raise Exception(f"Cannot find key: {key.hex()}")
+kid = KeyId(kvid)
return merkle_blob.get_proof_of_inclusion(kid)

async def write_tree_to_file(
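Taken together, the `data_store.py` changes keep the database layer on the raw `KeyOrValueId` (keys and values share the `ids` table, so `get_kvid` cannot know which kind it returns) and apply `KeyId` or `ValueId` at the boundary where the caller does know. A condensed sketch of that pattern, assuming a `DataStore` instance `store`, a loaded `merkle_blob`, and an existing `store_id`:

```python
# Illustrative only; condensed from the methods shown in this diff.
raw = await store.get_kvid(key_blob, store_id)  # Optional[KeyOrValueId]
if raw is not None:
    merkle_blob.delete(KeyId(raw))  # the lookup was by key, so wrap as KeyId

# add_key_value() wraps internally and returns the distinct types.
kid, vid = await store.add_key_value(key_blob, value_blob, store_id)  # (KeyId, ValueId)
node = await store.get_terminal_node(kid, vid, store_id)  # argument order now enforced
```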
