Skip to content

Commit

Permalink
Clean up unused imgsz (ultralytics#7771)
Browse files Browse the repository at this point in the history
  • Loading branch information
Laughing-q authored Jan 23, 2024
1 parent f56dd0f commit 67ae86f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 19 deletions.
11 changes: 5 additions & 6 deletions ultralytics/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,12 @@ def check_source(source):
return source, webcam, screenshot, from_img, in_memory, tensor


def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
def load_inference_source(source=None, vid_stride=1, buffer=False):
"""
Loads an inference source for object detection and applies necessary transformations.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
Expand All @@ -172,13 +171,13 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
elif in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer)
dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz)
dataset = LoadScreenshots(source)
elif from_img:
dataset = LoadPilAndNumpy(source, imgsz=imgsz)
dataset = LoadPilAndNumpy(source)
else:
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
dataset = LoadImages(source, vid_stride=vid_stride)

# Attach source types to the dataset
setattr(dataset, "source_type", source_type)
Expand Down
16 changes: 4 additions & 12 deletions ultralytics/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class LoadStreams:
Attributes:
sources (str): The source input paths or URLs for the video streams.
imgsz (int): The image size for processing, defaults to 640.
vid_stride (int): Video frame-rate stride, defaults to 1.
buffer (bool): Whether to buffer input streams, defaults to False.
running (bool): Flag to indicate if the streaming thread is running.
Expand All @@ -60,13 +59,12 @@ class LoadStreams:
__len__: Return the length of the sources object.
"""

def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False):
def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
self.mode = "stream"
self.imgsz = imgsz
self.vid_stride = vid_stride # video frame-rate stride

sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
Expand Down Expand Up @@ -193,7 +191,6 @@ class LoadScreenshots:
Attributes:
source (str): The source input indicating which screen to capture.
imgsz (int): The image size for processing, defaults to 640.
screen (int): The screen number to capture.
left (int): The left coordinate for screen capture area.
top (int): The top coordinate for screen capture area.
Expand All @@ -210,7 +207,7 @@ class LoadScreenshots:
__next__: Captures the next screenshot and returns it.
"""

def __init__(self, source, imgsz=640):
def __init__(self, source):
"""Source = [screen_number left top width height] (pixels)."""
check_requirements("mss")
import mss # noqa
Expand All @@ -223,7 +220,6 @@ def __init__(self, source, imgsz=640):
left, top, width, height = (int(x) for x in params)
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
self.imgsz = imgsz
self.mode = "stream"
self.frame = 0
self.sct = mss.mss()
Expand Down Expand Up @@ -258,7 +254,6 @@ class LoadImages:
various formats, including single image files, video files, and lists of image and video paths.
Attributes:
imgsz (int): Image size, defaults to 640.
files (list): List of image and video file paths.
nf (int): Total number of files (images and videos).
video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
Expand All @@ -274,7 +269,7 @@ class LoadImages:
_new_video(path): Create a new cv2.VideoCapture object for a given video path.
"""

def __init__(self, path, imgsz=640, vid_stride=1):
def __init__(self, path, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
parent = None
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
Expand All @@ -298,7 +293,6 @@ def __init__(self, path, imgsz=640, vid_stride=1):
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)

self.imgsz = imgsz
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
Expand Down Expand Up @@ -377,7 +371,6 @@ class LoadPilAndNumpy:
Attributes:
paths (list): List of image paths or autogenerated filenames.
im0 (list): List of images stored as Numpy arrays.
imgsz (int): Image size, defaults to 640.
mode (str): Type of data being processed, defaults to 'image'.
bs (int): Batch size, equivalent to the length of `im0`.
count (int): Counter for iteration, initialized at 0 during `__iter__()`.
Expand All @@ -386,13 +379,12 @@ class LoadPilAndNumpy:
_single_check(im): Validate and format a single image to a Numpy array.
"""

def __init__(self, im0, imgsz=640):
def __init__(self, im0):
"""Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list):
im0 = [im0]
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
self.im0 = [self._single_check(im) for im in im0]
self.imgsz = imgsz
self.mode = "image"
# Generate fake paths
self.bs = len(self.im0)
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/engine/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def setup_source(self, source):
else None
)
self.dataset = load_inference_source(
source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
)
self.source_type = self.dataset.source_type
if not getattr(self, "stream", True) and (
Expand Down

0 comments on commit 67ae86f

Please sign in to comment.