Skip to content

Commit

Permalink
Improved float workspace arg for TRT exports (ultralytics#9407)
Browse files: browse the repository at this point in the history
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
zldrobit and glenn-jocher authored Mar 29, 2024
1 parent 4a7ccba commit 03d0ffd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
3 changes: 1 addition & 2 deletions ultralytics/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
"""

# Define keys for arg type checks
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
CFG_FRACTION_KEYS = {
"dropout",
"iou",
Expand Down Expand Up @@ -132,7 +132,6 @@
"max_det",
"vid_stride",
"line_width",
"workspace",
"nbs",
"save_period",
}
Expand Down
4 changes: 1 addition & 3 deletions ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,7 @@ def export_engine(self, prefix=colorstr("TensorRT:")):

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = self.args.workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice

config.max_workspace_size = int(self.args.workspace * (1 << 30))
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
Expand Down

0 comments on commit 03d0ffd

Please sign in to comment.