Source code for birdfsd_yolov5.model_utils.s3_helper

#!/usr/bin/env python
# coding: utf-8

import mimetypes
import os
import sys
import textwrap
from datetime import timedelta
from pathlib import Path
from typing import Union

from dotenv import load_dotenv
from minio import Minio
from minio.datatypes import Object
from minio.helpers import ObjectWriteResult

from birdfsd_yolov5.model_utils import utils


class BucketDoesNotExist(Exception):
    """Error signalling that a requested S3 bucket could not be found.

    A plain ``Exception`` subclass with no extra attributes; callers can
    catch it specifically to distinguish a missing bucket from other
    failures.
    """
class S3:
    """S3 client that uses MinIO as a backend."""

    def __init__(self):
        """Initialize the MinIO client from environment variables.

        Reads ``S3_ENDPOINT``, ``S3_ACCESS_KEY``, ``S3_SECRET_KEY`` and
        ``S3_REGION`` from ``os.environ``; a missing variable raises
        ``KeyError``.

        Returns:
            None
        """
        self.client = Minio(os.environ['S3_ENDPOINT'],
                            access_key=os.environ['S3_ACCESS_KEY'],
                            secret_key=os.environ['S3_SECRET_KEY'],
                            region=os.environ['S3_REGION'])

    def upload(self,
               bucket_name: str,
               file_path: str,
               public: bool = False,
               scheme: str = 'https',
               dest: str = None) -> Union[ObjectWriteResult, str]:
        """Uploads a file to an S3 bucket.

        Args:
            bucket_name (str): The name of the bucket to upload to.
            file_path (str): The path to the file to upload.
            public (bool): True if the file is uploaded to a publicly
                accessible bucket; in that case a URL is returned instead
                of the MinIO write result.
            scheme (str): The scheme prepended to ``S3_ENDPOINT`` when
                building the public URL, used only if the endpoint does not
                already start with ``http://`` or ``https://``.
            dest (str): The destination object name. Defaults to the file's
                base name.

        Returns:
            Union[ObjectWriteResult, str]: The URL of the uploaded file when
            ``public`` is True, otherwise the result of
            ``Minio.fput_object``.
        """
        file = Path(file_path)
        if not dest:
            dest = file.name

        # Fall back to a generic binary type when the extension is unknown.
        content_type = (mimetypes.guess_type(file_path)[0]
                        or 'application/octet-stream')

        res = self.client.fput_object(bucket_name=bucket_name,
                                      object_name=dest,
                                      file_path=file,
                                      content_type=content_type)
        if not public:
            return res

        endpoint = os.environ['S3_ENDPOINT']
        # Prepend a scheme only when the endpoint does not already carry one;
        # a substring test on 'http'/'https' would double-prefix endpoints
        # such as 'http://host'.
        if endpoint.startswith(('http://', 'https://')):
            domain = endpoint
        else:
            domain = f'{scheme}://{endpoint}'
        return f'{domain}/{bucket_name}/{dest}'

    def download(self,
                 bucket_name: str,
                 object_name: str,
                 dest: str = None) -> Object:
        """Downloads an object from the bucket to a local file.

        Args:
            bucket_name (str): The name of the bucket.
            object_name (str): The name of the object.
            dest (str): The local destination path. Defaults to
                ``object_name``.

        Returns:
            Object: The value returned by ``Minio.fget_object`` for the
            downloaded object.
        """
        if not dest:
            dest = object_name
        return self.client.fget_object(bucket_name, object_name, dest)

    def get_model_weights(self,
                          model_version: str = 'latest'
                          ) -> Union[str, None]:
        """Get the name of a model-weights object in the ``model`` bucket.

        Args:
            model_version: The version of the model to get. If 'latest' or
                None, the most recently modified object is returned.
                Otherwise, the first object whose name ends with
                ``model_version`` is returned.

        Returns:
            str: The name of the model object in the model bucket, or None
            if no object name ends with ``model_version``.

        Raises:
            ValueError: If the latest version is requested but the bucket
                lists no objects with a modification time.
        """
        objects = self.client.list_objects('model')

        if model_version is None or model_version == 'latest':
            # Non-recursive listings can include prefix entries whose
            # last_modified is None; they cannot be ordered by time.
            dated = [obj for obj in objects if obj.last_modified]
            if not dated:
                raise ValueError(
                    "No model objects found in the 'model' bucket.")
            # max() keeps the first object among equal timestamps.
            return max(dated, key=lambda obj: obj.last_modified).object_name

        for obj in objects:
            if obj.object_name.endswith(model_version):
                return obj.object_name
        return None

    def get_dataset(self, object_name: str = None) -> str:
        """Downloads a dataset object from the ``dataset`` bucket.

        A presigned GET URL (valid for 6 hours) is created for the object
        and passed to ``utils.requests_download`` together with
        ``object_name`` (presumably used as the local destination path —
        confirm against ``utils``).

        Args:
            object_name (str): The name of the object to retrieve. If not
                provided, the most recently modified object is used.

        Returns:
            str: The name of the object retrieved.

        Raises:
            ValueError: If no object name is given and the bucket lists no
                objects with a modification time.
        """
        if not object_name:
            # Skip entries without a timestamp (e.g. prefixes).
            dated = [obj for obj in self.client.list_objects('dataset')
                     if obj.last_modified]
            if not dated:
                raise ValueError(
                    "No dataset objects found in the 'dataset' bucket.")
            object_name = max(dated,
                              key=lambda obj: obj.last_modified).object_name

        presigned_url = self.client.presigned_get_object(
            'dataset', object_name, expires=timedelta(hours=6))
        utils.requests_download(presigned_url, object_name)
        return object_name
if __name__ == '__main__':
    if '--help' in sys.argv or '-h' in sys.argv:
        # The S3 client is configured only through these environment
        # variables (loaded from a `.env` file for --download-dataset);
        # AWS CLI configuration is not read.
        message = textwrap.dedent('''\
            - Configuration (environment variables, or a `.env` file):
                S3_ENDPOINT      S3/MinIO server endpoint, e.g. s3.example.com
                S3_ACCESS_KEY    access key
                S3_SECRET_KEY    secret key
                S3_REGION        region, e.g. us-east-1
            - Options:
                --download-dataset   download the latest object from the
                                     `dataset` bucket''')
        print(message)

    if '--download-dataset' in sys.argv:
        load_dotenv()
        _ = S3().get_dataset()