Skip to content

Commit

Permalink
ultralytics 8.0.206 engine Trainer updates (ultralytics#6111)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jamjamjon <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2023
1 parent 25bd3b9 commit f2f5ed2
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 34 deletions.
13 changes: 9 additions & 4 deletions .github/workflows/links.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO Continuous Integration (CI) GitHub Actions tests broken link checker
# Accept 429(Instagram, 'too many requests'), 999(LinkedIn, 'unknown status code'), Timeout(Twitter)
# Continuous Integration (CI) GitHub Actions tests broken link checker using https://github.com/lycheeverse/lychee
# Ignores the following status codes to reduce false positives:
# - 403(OpenVINO, 'forbidden')
# - 429(Instagram, 'too many requests')
# - 500(Zenodo, 'cached')
# - 502(Zenodo, 'bad gateway')
# - 999(LinkedIn, 'unknown status code')

name: Check Broken links

Expand Down Expand Up @@ -28,7 +33,7 @@ jobs:
timeout_minutes: 5
retry_wait_seconds: 60
max_attempts: 3
command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com|fonts\.gstatic\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html'
command: lychee --accept 403,429,500,502,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com|fonts\.gstatic\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html'

- name: Test Markdown, HTML, YAML, Python and Notebook links with retry
if: github.event_name == 'workflow_dispatch'
Expand All @@ -37,4 +42,4 @@ jobs:
timeout_minutes: 5
retry_wait_seconds: 60
max_attempts: 3
command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com|url\.com|fonts\.gstatic\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html' './**/*.yml' './**/*.yaml' './**/*.py' './**/*.ipynb'
command: lychee --accept 429,999 --exclude-loopback --exclude 'https?://(www\.)?(linkedin\.com|twitter\.com|instagram\.com|kaggle\.com|fonts\.gstatic\.com|url\.com)' --exclude-path '**/ci.yaml' --exclude-mail --github-token ${{ secrets.GITHUB_TOKEN }} './**/*.md' './**/*.html' './**/*.yml' './**/*.yaml' './**/*.py' './**/*.ipynb'
4 changes: 2 additions & 2 deletions docs/help/privacy.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ keywords: Ultralytics, Data Collection, User Privacy, Google Analytics, Sentry,

## Overview

Ultralytics is dedicated to the continuous enhancement of the user experience and the capabilities of our Python package, including the advanced YOLO models we develop. Our approach involves the gathering of anonymized usage statistics and crash reports, helping us identify opportunities for improvement and ensuring the reliability of our software. This transparency document outlines what data we collect, its purpose, and the choice you have regarding this data collection.
[Ultralytics](https://ultralytics.com) is dedicated to the continuous enhancement of the user experience and the capabilities of our Python package, including the advanced YOLO models we develop. Our approach involves the gathering of anonymized usage statistics and crash reports, helping us identify opportunities for improvement and ensuring the reliability of our software. This transparency document outlines what data we collect, its purpose, and the choice you have regarding this data collection.

## Anonymized Google Analytics

Expand Down Expand Up @@ -50,7 +50,7 @@ If the `sentry-sdk` Python package is pre-installed on your system a crash event
- **Crash Logs**: Detailed reports on the application's condition at the time of a crash, which are vital for our debugging efforts.
- **Error Messages**: We record error messages generated during the operation of our package to understand and resolve potential issues quickly.

To learn more about how Sentry handles data, please visit [Sentry Privacy Policy](https://sentry.io/privacy/).
To learn more about how Sentry handles data, please visit [Sentry's Privacy Policy](https://sentry.io/privacy/).

### How We Use This Data

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = '8.0.205'
__version__ = '8.0.206'

from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
Expand Down
7 changes: 2 additions & 5 deletions ultralytics/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS


class Model(nn.Module):
Expand Down Expand Up @@ -88,10 +87,8 @@ def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
return

# Load or create new YOLO model
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
if suffix in ('.yaml', '.yml'):
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(model).suffix in ('.yaml', '.yml'):
self._new(model, task)
else:
self._load(model, task)
Expand Down
34 changes: 16 additions & 18 deletions ultralytics/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
Expand All @@ -43,7 +41,6 @@ class BaseTrainer:
Attributes:
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
Expand All @@ -62,6 +59,7 @@ class BaseTrainer:
trainset (torch.utils.data.Dataset): Training dataset.
testset (torch.utils.data.Dataset): Testing dataset.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
resume (bool): Resume training from a checkpoint.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
Expand All @@ -84,7 +82,6 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
self.check_resume(overrides)
self.device = select_device(self.args.device, self.args.batch)
self.validator = None
self.model = None
self.metrics = None
self.plots = {}
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
Expand All @@ -111,7 +108,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading

# Model and Dataset
self.model = self.args.model
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
Expand All @@ -124,6 +121,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None
self.resume = False

# Optimization utils init
self.lf = None
Expand Down Expand Up @@ -236,9 +234,9 @@ def _setup_train(self, world_size):
if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean
self.scaler = amp.GradScaler(enabled=self.amp)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[RANK])
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
Expand Down Expand Up @@ -311,11 +309,7 @@ def _do_train(self, world_size=1):
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
self._close_dataloader_mosaic()
self.train_loader.reset()

if RANK in (-1, 0):
Expand Down Expand Up @@ -395,7 +389,7 @@ def _do_train(self, world_size=1):
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks('on_fit_epoch_end')
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

# Early Stopping
if RANK != -1: # if DDP training
Expand Down Expand Up @@ -613,11 +607,15 @@ def resume_training(self, ckpt):
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self._close_dataloader_mosaic()

def _close_dataloader_mosaic(self):
"""Update dataloaders to stop using mosaic augmentation."""
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
self.train_loader.dataset.close_mosaic(hyp=self.args)

def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""
Expand Down
8 changes: 8 additions & 0 deletions ultralytics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
return file


def check_model_file_from_stem(model='yolov8n'):
"""Return a model filename from a valid model stem."""
if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt
else:
return model


def check_file(file, suffix='', download=True, hard=True):
"""Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional
Expand Down
8 changes: 4 additions & 4 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ def scale_image(masks, im0_shape, ratio_pad=None):
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) # y, x
bottom, right = (int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1)))

if len(masks.shape) < 2:
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
Expand Down Expand Up @@ -704,8 +704,8 @@ def scale_masks(masks, shape, padding=True):
if padding:
pad[0] /= 2
pad[1] /= 2
top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) if padding else (0, 0) # y, x
bottom, right = (int(round(mh - pad[1] + 0.1)), int(round(mw - pad[0] + 0.1)))
masks = masks[..., top:bottom, left:right]

masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
Expand Down

0 comments on commit f2f5ed2

Please sign in to comment.