Source code for birdfsd_yolov5.prediction.auto_prediction_schedule

#!/usr/bin/env python
# coding: utf-8

import argparse
import json
import sys
from pathlib import Path
from typing import Optional, Sequence

from dotenv import load_dotenv
from loguru import logger

from birdfsd_yolov5.model_utils import download_weights, handlers, utils
from birdfsd_yolov5.prediction import predict


def _opts() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Namespace object containing the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--opts-file',
                        type=str,
                        help='JSON file with predict.py options')
    parser.add_argument(
        '--show-opts',
        action='store_true',
        help='Show the current prediction function configuration then exit')
    return parser.parse_args()


def _load_opts(opts_file: Optional[str]) -> dict:
    """Return the prediction options, read from JSON or the built-in defaults.

    Args:
        opts_file (Optional[str]): Path to a JSON file whose top-level object
            is later passed as keyword arguments to ``predict.Predict``. When
            falsy, a built-in default configuration is returned instead.

    Returns:
        dict: The options dictionary.
    """
    if opts_file:
        logger.debug(f'Loading options from file: `{opts_file}`...')
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(opts_file, encoding='utf-8') as j:
            return json.load(j)

    logger.debug('Using default options...')
    return {
        'weights': '',
        'project_ids': None,  # None will return all projects
        'tasks_range': None,
        'predict_all': True,
        'one_task': None,
        'model_version': 'latest',
        'multithreading': True,
        'delete_if_no_predictions': False,
        'if_empty_apply_label': 'no animal',
        'get_tasks_with_api': False,
        'verbose': True
    }


def auto_prediction_pipeline(opts_file: Optional[str] = None,
                             show_opts: bool = False) -> None:
    """A pipeline to run the prediction module.

    A prediction pipeline that is intended to be used as a systemctl service
    or inside a GitHub Actions workflow.

    Args:
        opts_file (Optional[str]): Path to a JSON file with the keyword
            arguments for ``predict.Predict``. When omitted, built-in
            defaults are used (latest model version, all projects).
        show_opts (bool): If True, print the resolved options as JSON and
            exit with status 0 instead of running predictions. The weights
            download is skipped in this mode.

    Raises:
        SystemExit: With status 0 after printing options when ``show_opts``
            is set. With an error message when a weights path is given but
            ``model_version`` is still ``'latest'``.
    """
    logs_file = utils.add_logger(__file__)
    handlers.catch_keyboard_interrupt()

    opts = _load_opts(opts_file)
    logger.debug(f'OPTS: {opts}')

    if not opts['weights'] and opts['model_version'] == 'latest':
        dmw = download_weights.DownloadModelWeights(opts['model_version'])
        # Use the function argument, not the module-level ``args`` that only
        # exists when this file runs as a script. When only showing the
        # options, ask ``get_weights`` to skip the actual download.
        weights, _weights_url, weights_model_ver = dmw.get_weights(
            skip_download=bool(show_opts))
        logger.info(f'Downloaded weights to {weights}')
        opts['weights'] = weights
        opts['model_version'] = weights_model_ver
    elif opts['model_version'] == 'latest':
        # A local weights file carries no version information of its own.
        sys.exit('Need to specify model version if loaded from a file path!')

    if show_opts:
        print(json.dumps(opts, indent=4))
        sys.exit(0)

    p = predict.Predict(**opts)
    p.apply_predictions()

    # Best-effort cleanup of a `best.pt` file in the working directory
    # (presumably a weights file left behind by the run -- confirm); a
    # missing file is fine.
    try:
        Path('best.pt').unlink()
    except FileNotFoundError:
        pass

    utils.upload_logs(logs_file)


if __name__ == '__main__':
    # Pull variables from a local .env file (python-dotenv) into the
    # environment before the pipeline starts.
    load_dotenv()
    args = _opts()
    auto_prediction_pipeline(
        opts_file=args.opts_file,
        show_opts=args.show_opts,
    )