You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
3 weeks ago
|
# Ultralytics YOLOv5 🚀, AGPL-3.0 license
|
||
|
"""Callback utils."""
|
||
|
|
||
|
import threading
|
||
|
|
||
|
|
||
|
class Callbacks:
    """Registry of named YOLOv5 hooks and the actions (callbacks) attached to each of them."""

    def __init__(self):
        """Create the hook table with an empty action list per hook and clear the stop flag."""
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def _check_hook(self, hook):
        """Raise AssertionError when `hook` is not one of the hook names created in `__init__`."""
        # Explicit raise instead of `assert`: assert statements are removed under `python -O`,
        # which would let an unknown hook through and surface later as a bare KeyError.
        # AssertionError and the message text are kept so existing callers see no difference.
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {self._callbacks}")

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire; must be callable (the default None is rejected)

        Raises:
            AssertionError: If `hook` is unknown or `callback` is not callable.
        """
        self._check_hook(hook)
        # Validate at registration time so a bad callback fails here rather than inside run().
        if not callable(callback):
            raise AssertionError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook.

        Args:
            hook: The name of the hook to check, defaults to all

        Returns:
            The list of {"name", "callback"} dicts for `hook`, or the whole hook -> list mapping
            when `hook` is falsy. The internal objects are returned directly, not copies.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Fire every action registered under a hook, in registration order.

        Args:
            hook: The name of the hook whose actions should be fired (required)
            args: Arguments to receive from YOLOv5
            thread: (boolean) If True, start each callback in its own daemon thread (not joined);
                otherwise call it inline in the calling thread
            kwargs: Keyword Arguments to receive from YOLOv5

        Raises:
            AssertionError: If `hook` is unknown.
        """
        self._check_hook(hook)
        for action in self._callbacks[hook]:
            if thread:
                threading.Thread(target=action["callback"], args=args, kwargs=kwargs, daemon=True).start()
            else:
                # Inline call: an exception raised by the callback propagates to the caller.
                action["callback"](*args, **kwargs)
|