diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index dbf3a44736c..45f95485c43 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -94,7 +94,7 @@ """ # Define keys for arg type checks -CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"} +CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"} CFG_FRACTION_KEYS = { "dropout", "iou", @@ -132,7 +132,6 @@ "max_det", "vid_stride", "line_width", - "workspace", "nbs", "save_period", } diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index b73ecf5c100..9c32e0ac5d4 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -675,9 +675,7 @@ def export_engine(self, prefix=colorstr("TensorRT:")): builder = trt.Builder(logger) config = builder.create_builder_config() - config.max_workspace_size = self.args.workspace * 1 << 30 - # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice - + config.max_workspace_size = int(self.args.workspace * (1 << 30)) flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network = builder.create_network(flag) parser = trt.OnnxParser(network, logger)