Source code for birdfsd_yolov5.api.model_utils

# -*- coding: utf-8 -*-
"""This module is used to get the latest model weights and information."""

from pathlib import Path
from typing import Optional, Tuple

import torch
from minio import Minio

from birdfsd_yolov5.model_utils.mongodb_helper import mongodb_db


def get_latest_model_weights(s3_client: Minio,
                             skip_download: bool = False
                             ) -> Tuple[str, str, str]:
    """Get the latest model weights from the model collection in mongodb.

    The "latest" model is the document in the ``model`` collection with the
    greatest ``added_on`` value. Its weights are fetched from the ``model``
    bucket, object ``{name}-v{version}.pt``, into a local file of the same
    name (relative to the current working directory).

    Args:
        s3_client (Minio): Minio S3 client object.
        skip_download (bool): If True, skip downloading the model weights
            from S3 and only resolve the version, name and file name.

    Returns:
        model_version: The version of the model weights.
        model_name: The name of the model.
        model_object_name: The name of the model weights file (also the local
            path it is downloaded to).

    Raises:
        ValueError: If no document in the ``model`` collection has an
            ``added_on`` value.
        AssertionError: If the weights file does not exist locally after the
            download.
    """
    db = mongodb_db()
    # Let the server pick the newest document in one query instead of
    # pulling every distinct timestamp and then querying a second time.
    model_document = db['model'].find_one({'added_on': {'$exists': True}},
                                          sort=[('added_on', -1)])
    if model_document is None:
        raise ValueError('No model found in the `model` collection.')

    model_version = model_document['version']
    model_name = model_document['name']
    model_object_name = f'{model_name}-v{model_version}.pt'

    if skip_download:
        return model_version, model_name, model_object_name

    s3_client.fget_object('model', model_object_name, model_object_name)
    if not Path(model_object_name).exists():
        raise AssertionError(
            f'Model weights `{model_object_name}` were not found locally '
            'after downloading them from S3.')
    return model_version, model_name, model_object_name


def init_model(
        s3: Minio,
        use_weights: Optional[str] = None
) -> Tuple[str, str, str, torch.nn.Module]:
    """Resolve which weights file to use and load the YOLOv5 model from it.

    Args:
        s3 (Minio): Minio S3 client object, used to look up and (if needed)
            download the latest weights.
        use_weights (str): Path to a weights file to use instead of the latest
            model. When this file does not exist locally, the latest weights
            are downloaded and used instead.

    Returns:
        model_version: The model version (the file stem when ``use_weights``
            is given).
        model_name: The model name (also the file stem when ``use_weights``
            is given).
        model_weights: The model weights file name.
        model: The model object returned by ``torch.hub.load``.
    """
    if use_weights:
        # NOTE(review): name and version are both just the file stem here;
        # no attempt is made to split a ``name-vVERSION`` file name.
        weights_stem = Path(use_weights).stem
        model_version, model_name, model_weights = (weights_stem,
                                                    weights_stem,
                                                    use_weights)
    else:
        model_version, model_name, model_weights = get_latest_model_weights(
            s3, skip_download=True)

    # Whatever was chosen above, fall back to downloading the latest weights
    # when the file is not available locally.
    if not Path(model_weights).exists():
        model_version, model_name, model_weights = get_latest_model_weights(s3)

    model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_weights)
    return model_version, model_name, model_weights, model


def model_info(version: str) -> dict:
    """Returns the model information for the specified version.

    Args:
        version (str): The version of the model to be returned, or
            ``'latest'`` for the model with the greatest ``added_on`` value.

    Returns:
        dict: The model document without its ``_id`` field, with
        ``added_on`` and ``trained_on`` converted to strings.

    Raises:
        ValueError: If ``version`` is ``'latest'`` and no model document has
            an ``added_on`` value.
        LookupError: If no model document has the requested version.
    """
    col = mongodb_db().model
    if version == 'latest':
        # One sorted query instead of fetching all distinct timestamps and
        # then querying again for the matching document.
        model_obj = col.find_one({'added_on': {'$exists': True}},
                                 sort=[('added_on', -1)])
        if model_obj is None:
            raise ValueError('No model found in the `model` collection.')
    else:
        # find_one returns None instead of leaking StopIteration (as
        # cursor.next() did), which misbehaves in generator/async contexts.
        model_obj = col.find_one({'version': version})
        if model_obj is None:
            raise LookupError(f'No model found with version {version!r}.')

    model_obj.pop('_id', None)
    model_obj['added_on'] = str(model_obj['added_on'])
    model_obj['trained_on'] = str(model_obj['trained_on'])
    return model_obj