Skip to content

Commit

Permalink
py: stream rework
Browse files Browse the repository at this point in the history
  • Loading branch information
jordens committed Jan 20, 2025
1 parent a36e294 commit a89e174
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 49 deletions.
36 changes: 17 additions & 19 deletions hitl/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os

import miniconf
from stabilizer.stream import measure, StabilizerStream, get_local_ip
from stabilizer.stream import measure, Stream, get_local_ip

logger = logging.getLogger(__name__)

Expand All @@ -22,12 +22,13 @@
async def _main():
parser = argparse.ArgumentParser(description="Stabilizer Stream HITL test")
parser.add_argument(
"prefix", type=str, nargs="?", help="The MQTT topic prefix of the target"
"prefix",
help="The MQTT topic prefix of the target",
)
parser.add_argument(
"--broker", "-b", default="mqtt", type=str, help="The MQTT broker address"
)
parser.add_argument("--ip", default="0.0.0.0", help="The IP address to listen on")
parser.add_argument("--addr", default="0.0.0.0", help="The IP address to listen on")
parser.add_argument(
"--port", type=int, default=9293, help="Local port to listen on"
)
Expand All @@ -36,39 +37,36 @@ async def _main():
"--max-loss", type=float, default=5e-2, help="Maximum loss for success"
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)

async with miniconf.Client(
args.broker,
protocol=miniconf.MQTTv5,
logger=logging.getLogger("aiomqtt-client"),
) as client:
prefix = args.prefix
if not args.prefix:
prefix, _alive = miniconf.one(
await miniconf.discover(client, "dt/sinara/dual-iir/+")
)

logging.basicConfig(level=logging.INFO)

prefix, _alive = miniconf.one(await miniconf.discover(client, args.prefix))
conf = miniconf.Miniconf(client, prefix)

if ipaddress.ip_address(args.ip).is_unspecified:
args.ip = get_local_ip(args.broker)
if ipaddress.ip_address(args.addr).is_unspecified:
args.addr = get_local_ip(args.broker)
if ipaddress.ip_address(args.addr).is_multicast:
local = get_local_ip(args.broker)
else:
local = "0.0.0.0"

logger.info("Starting stream")
await conf.set("/stream", f"{args.ip}:{args.port}", retain=False)
await conf.set("/stream", f"{args.addr}:{args.port}")

try:
logger.info("Testing stream reception")
_transport, stream = await StabilizerStream.open(
args.port, addr=args.ip, bind=get_local_ip(args.broker)
_transport, stream = await Stream.open(
args.port, addr=args.addr, local=local
)
logger.info("Testing stream reception")
loss = await measure(stream, args.duration)
if loss > args.max_loss:
raise RuntimeError("High frame loss", loss)
finally:
logger.info("Stopping stream")
await conf.set("/stream", "0.0.0.0:0", retain=False)
await conf.set("/stream", "0.0.0.0:0")

logger.info("Draining queue")
await asyncio.sleep(0.1)
Expand Down
72 changes: 42 additions & 30 deletions py/stabilizer/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_local_ip(remote):
Returns a list of four octets."""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
sock.connect((remote, 1883))
sock.connect((remote, 9)) # discard
return sock.getsockname()[0]
finally:
sock.close()
Expand Down Expand Up @@ -123,39 +123,41 @@ def parse(cls, data):
return parser(header, data[cls.header_fmt.size :])


class StabilizerStream(asyncio.DatagramProtocol):
class Stream(asyncio.DatagramProtocol):
"""Stabilizer streaming receiver protocol"""

@classmethod
async def open(cls, port=9293, addr="0.0.0.0", bind=None, maxsize=1):
async def open(cls, port=9293, addr="0.0.0.0", local="0.0.0.0", maxsize=1):
"""Open a UDP socket and start receiving frames"""
loop = asyncio.get_running_loop()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Increase the OS UDP receive buffer size to 16 MiB so that latency
# spikes don't impact much. Achieving 16 MiB may require increasing
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
# Increase the OS UDP receive buffer size so that latency
# spikes don't impact much. Achieving this may require increasing
# the max allowed buffer size, e.g. via
# `sudo sysctl net.core.rmem_max=26214400` but nowadays the default
# max appears to be ~ 50 MiB already.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 16 << 20)

# max appears to be ~ 50 MiB already, at least on Linux.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 8 << 20)
# We need to specify which interface to receive multicasts from, or Windows may choose the
# wrong one. Thus, use the broker address to figure out our local address for the interface
# of interest.
# wrong one. Thus, use a bind address to figure out our local address for the interface
# of interest. There's also an interface index, at least on linux, but apparently windows
# sockets don't do that.
if ipaddress.ip_address(addr).is_multicast:
multiaddr = socket.inet_aton(addr)
local = socket.inet_aton(local)
sock.setsockopt(
socket.IPPROTO_IP,
socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(addr) + socket.inet_aton(bind),
multiaddr + local,
)
sock.bind(("", port))
else:
sock.bind((addr, port))

print(f"bind to {addr}")
sock.bind((addr, port))
transport, protocol = await loop.create_datagram_endpoint(
lambda: cls(maxsize), sock=sock
lambda: cls(maxsize),
sock=sock,
)

return transport, protocol

def __init__(self, maxsize):
Expand Down Expand Up @@ -199,8 +201,6 @@ async def _record():
stat.received += frame.header.batches
stat.expect = wrap(frame.header.sequence + frame.header.batches)
stat.bytes += frame.size()
# test conversion
# frame.to_si()

try:
await asyncio.wait_for(_record(), timeout=duration)
Expand All @@ -212,10 +212,7 @@ async def _record():
)

sent = stat.received + stat.lost
if sent:
loss = stat.lost / sent
else:
loss = 1
loss = stat.lost / sent if sent else 1
logger.info("Loss: %s/%s batches (%g %%)", stat.lost, sent, loss * 1e2)
return loss

Expand All @@ -224,17 +221,32 @@ async def main():
"""Test CLI"""
parser = argparse.ArgumentParser(description="Stabilizer streaming demo")
parser.add_argument(
"--port", type=int, default=9293, help="Local port to listen on"
"--port", type=int, default=9293, help="Local port to listen on [%(default)s]"
)
parser.add_argument(
"--host", default="0.0.0.0", help="Local address to listen on [%(default)s]"
)
parser.add_argument(
"--local",
default="0.0.0.0",
help="The local IP address to receive multicast frames on [%(default)s]",
)
parser.add_argument(
"--broker", help="The MQTT broker address for local IP lookup [%(default)s]"
)
parser.add_argument(
"--maxsize", type=int, default=1, help="Frame queue size [%(default)s]"
)
parser.add_argument(
"--duration", type=float, default=1.0, help="Test duration [%(default)s]"
)
parser.add_argument("--host", default="0.0.0.0", help="Local address to listen on")
parser.add_argument("--broker", default="mqtt", help="The MQTT broker address")
parser.add_argument("--maxsize", type=int, default=1, help="Frame queue size")
parser.add_argument("--duration", type=float, default=1.0, help="Test duration")
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
_transport, stream = await StabilizerStream.open(
args.port, args.host, get_local_ip(args.broker), args.maxsize
if args.broker is not None:
args.local = get_local_ip(args.broker)
_transport, stream = await Stream.open(
args.port, args.host, args.local, args.maxsize
)
await measure(stream, args.duration)

Expand Down

0 comments on commit a89e174

Please sign in to comment.