Skip to content

Commit

Permalink
Fix behavior for default-only tagged entity fields (#102)
Browse files Browse the repository at this point in the history
This enables the Java tests for the last non-record entity:
FetchRequest. This was previously failing because on the Java side, the
ReplicaState field tagged field is omitted when all of the nested fields
are equal to defaults.

Because our logic already has the correct behavior in general for tagged
fields, all that was needed here was to make such parent fields have a
default value. We detect when we can set a default by checking if the
nested entity has default values for all of its fields.

Partially addresses #100.
  • Loading branch information
aiven-anton authored Dec 11, 2023
1 parent 28e3ea5 commit cc3ec0a
Show file tree
Hide file tree
Showing 23 changed files with 172 additions and 14 deletions.
31 changes: 31 additions & 0 deletions codegen/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def format_dataclass_field(
custom_type: CustomTypeDef | None,
tag: int | None,
ignorable: bool,
nested_entity_defaults_only: bool = False,
) -> str:
metadata: dict[str, object] = {}
inner_type = (
Expand All @@ -170,6 +171,18 @@ def format_dataclass_field(
field_kwargs["default"] = format_default(
field_type, default, optional, custom_type
)
elif (
tag is not None
and isinstance(field_type, EntityType | CommonStructType)
and nested_entity_defaults_only
):
# As of writing, this caters to a single field in the schema, the v15
# FetchRequest.ReplicaState. When values of the nested entity are all defaults,
# the tagged field is expected to be omitted. By making the default value of the
# parent field equal to instantiating the child with only defaults, this doesn't
# need any special treatment in parsers/serializers and functions as other
# tagged fields in this respect.
field_kwargs["default"] = f"{field_type}()"
elif tag is not None and ignorable:
field_kwargs["default"] = _format_default_for_tagged(field_type)

Expand Down Expand Up @@ -358,6 +371,23 @@ def entity_annotation(field: EntityField | CommonStructField, optional: bool) ->
return f"{field.type} | None" if optional else str(field.type)


def nested_entity_has_only_defaults(field: EntityField | CommonStructField) -> bool:
# TODO: This behavior should likely apply to a tagged to a CommonStructField as
# well. For now we don't have the required introspection capabilities of its
# fields, so that's left for when it becomes required.
return isinstance(field, EntityField) and all(
not isinstance(
field,
PrimitiveArrayField
| EntityArrayField
| CommonStructArrayField
| CommonStructField,
)
and field.default is not None
for field in field.fields
)


def generate_entity_field(
field: EntityField | CommonStructField,
version: int,
Expand All @@ -372,6 +402,7 @@ def generate_entity_field(
custom_type=None,
tag=field.get_tag(version),
ignorable=field.ignorable,
nested_entity_defaults_only=nested_entity_has_only_defaults(field),
)
annotation = entity_annotation(field, optional)
return f" {to_snake_case(field.name)}: {annotation}{field_call}\n"
Expand Down
1 change: 0 additions & 1 deletion codegen/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def main() -> None:
"ProduceRequest", # Records
"FetchResponse", # Records
"FetchSnapshotResponse", # Records
"FetchRequest", # Should not output tagged field if its value equals to default (presumably)
}:
module_code[module_path].append(
test_code_java.format(
Expand Down
10 changes: 7 additions & 3 deletions src/kio/schema/fetch/v12/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ class PartitionData:
"""The last stable offset (or LSO) of the partition. This is the last offset such that the state of all transactional records prior to this offset have been decided (ABORTED or COMMITTED)"""
log_start_offset: i64 = field(metadata={"kafka_type": "int64"}, default=i64(-1))
"""The current log start offset."""
diverging_epoch: EpochEndOffset = field(metadata={"tag": 0})
diverging_epoch: EpochEndOffset = field(
metadata={"tag": 0}, default=EpochEndOffset()
)
"""In case divergence is detected based on the `LastFetchedEpoch` and `FetchOffset` in the request, this field indicates the largest epoch and its end offset such that subsequent records are known to diverge"""
current_leader: LeaderIdAndEpoch = field(metadata={"tag": 1})
snapshot_id: SnapshotId = field(metadata={"tag": 2})
current_leader: LeaderIdAndEpoch = field(
metadata={"tag": 1}, default=LeaderIdAndEpoch()
)
snapshot_id: SnapshotId = field(metadata={"tag": 2}, default=SnapshotId())
"""In the case of fetching an offset less than the LogStartOffset, this is the end offset and epoch that should be used in the FetchSnapshot request."""
aborted_transactions: tuple[AbortedTransaction, ...]
"""The aborted transactions."""
Expand Down
10 changes: 7 additions & 3 deletions src/kio/schema/fetch/v13/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ class PartitionData:
"""The last stable offset (or LSO) of the partition. This is the last offset such that the state of all transactional records prior to this offset have been decided (ABORTED or COMMITTED)"""
log_start_offset: i64 = field(metadata={"kafka_type": "int64"}, default=i64(-1))
"""The current log start offset."""
diverging_epoch: EpochEndOffset = field(metadata={"tag": 0})
diverging_epoch: EpochEndOffset = field(
metadata={"tag": 0}, default=EpochEndOffset()
)
"""In case divergence is detected based on the `LastFetchedEpoch` and `FetchOffset` in the request, this field indicates the largest epoch and its end offset such that subsequent records are known to diverge"""
current_leader: LeaderIdAndEpoch = field(metadata={"tag": 1})
snapshot_id: SnapshotId = field(metadata={"tag": 2})
current_leader: LeaderIdAndEpoch = field(
metadata={"tag": 1}, default=LeaderIdAndEpoch()
)
snapshot_id: SnapshotId = field(metadata={"tag": 2}, default=SnapshotId())
"""In the case of fetching an offset less than the LogStartOffset, this is the end offset and epoch that should be used in the FetchSnapshot request."""
aborted_transactions: tuple[AbortedTransaction, ...]
"""The aborted transactions."""
Expand Down
10 changes: 7 additions & 3 deletions src/kio/schema/fetch/v14/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ class PartitionData:
"""The last stable offset (or LSO) of the partition. This is the last offset such that the state of all transactional records prior to this offset have been decided (ABORTED or COMMITTED)"""
log_start_offset: i64 = field(metadata={"kafka_type": "int64"}, default=i64(-1))
"""The current log start offset."""
diverging_epoch: EpochEndOffset = field(metadata={"tag": 0})
diverging_epoch: EpochEndOffset = field(
metadata={"tag": 0}, default=EpochEndOffset()
)
"""In case divergence is detected based on the `LastFetchedEpoch` and `FetchOffset` in the request, this field indicates the largest epoch and its end offset such that subsequent records are known to diverge"""
current_leader: LeaderIdAndEpoch = field(metadata={"tag": 1})
snapshot_id: SnapshotId = field(metadata={"tag": 2})
current_leader: LeaderIdAndEpoch = field(
metadata={"tag": 1}, default=LeaderIdAndEpoch()
)
snapshot_id: SnapshotId = field(metadata={"tag": 2}, default=SnapshotId())
"""In the case of fetching an offset less than the LogStartOffset, this is the end offset and epoch that should be used in the FetchSnapshot request."""
aborted_transactions: tuple[AbortedTransaction, ...]
"""The aborted transactions."""
Expand Down
2 changes: 1 addition & 1 deletion src/kio/schema/fetch/v15/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class FetchRequest(ApiMessage):
metadata={"kafka_type": "string", "tag": 0}, default=None
)
"""The clusterId if known. This is used to validate metadata fetches prior to broker registration."""
replica_state: ReplicaState = field(metadata={"tag": 1})
replica_state: ReplicaState = field(metadata={"tag": 1}, default=ReplicaState())
max_wait: i32Timedelta = field(metadata={"kafka_type": "timedelta_i32"})
"""The maximum time in milliseconds to wait for the response."""
min_bytes: i32 = field(metadata={"kafka_type": "int32"})
Expand Down
10 changes: 7 additions & 3 deletions src/kio/schema/fetch/v15/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ class PartitionData:
"""The last stable offset (or LSO) of the partition. This is the last offset such that the state of all transactional records prior to this offset have been decided (ABORTED or COMMITTED)"""
log_start_offset: i64 = field(metadata={"kafka_type": "int64"}, default=i64(-1))
"""The current log start offset."""
diverging_epoch: EpochEndOffset = field(metadata={"tag": 0})
diverging_epoch: EpochEndOffset = field(
metadata={"tag": 0}, default=EpochEndOffset()
)
"""In case divergence is detected based on the `LastFetchedEpoch` and `FetchOffset` in the request, this field indicates the largest epoch and its end offset such that subsequent records are known to diverge"""
current_leader: LeaderIdAndEpoch = field(metadata={"tag": 1})
snapshot_id: SnapshotId = field(metadata={"tag": 2})
current_leader: LeaderIdAndEpoch = field(
metadata={"tag": 1}, default=LeaderIdAndEpoch()
)
snapshot_id: SnapshotId = field(metadata={"tag": 2}, default=SnapshotId())
"""In the case of fetching an offset less than the LogStartOffset, this is the end offset and epoch that should be used in the FetchSnapshot request."""
aborted_transactions: tuple[AbortedTransaction, ...]
"""The aborted transactions."""
Expand Down
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v0_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v0.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v10_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v10.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v11_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v11.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v12_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v12.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v13_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v13.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v14_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v14.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v15_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from kio.schema.fetch.v15.request import ReplicaState
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_replica_state: Final = entity_reader(ReplicaState)
Expand Down Expand Up @@ -89,3 +90,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v1_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v1.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v2_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v2.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v3_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v3.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v4_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v4.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v5_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v5.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v6_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kio.schema.fetch.v6.request import FetchTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -57,3 +58,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
7 changes: 7 additions & 0 deletions tests/generated/test_fetch_v7_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kio.schema.fetch.v7.request import ForgottenTopic
from kio.serial import entity_reader
from kio.serial import entity_writer
from tests.conftest import JavaTester
from tests.conftest import setup_buffer

read_fetch_partition: Final = entity_reader(FetchPartition)
Expand Down Expand Up @@ -73,3 +74,9 @@ def test_fetch_request_roundtrip(instance: FetchRequest) -> None:
buffer.seek(0)
result = read_fetch_request(buffer)
assert instance == result


@pytest.mark.java
@given(instance=from_type(FetchRequest))
def test_fetch_request_java(instance: FetchRequest, java_tester: JavaTester) -> None:
java_tester.test(instance)
Loading

0 comments on commit cc3ec0a

Please sign in to comment.