Source code for birdfsd_yolov5.model_utils.update_run_cfg

#!/usr/bin/env python
# coding: utf-8

import argparse
import multiprocessing
import os
import platform
from pathlib import Path

import pynvml as nv  # noqa
import torch
import torchvision
import wandb


def _as_text(value):
    """Return *value* as ``str``, decoding it first when it is ``bytes``."""
    # Depending on the installed pynvml release, NVML string getters hand
    # back either ``bytes`` or ``str``; accept both instead of assuming one.
    return value.decode() if isinstance(value, bytes) else value


def _nvcc_cuda_version():
    """Return the CUDA release printed by ``nvcc``, or ``None`` if absent."""
    # The context manager closes the pipe once the output has been read.
    with os.popen('nvcc --version | grep release') as pipe:
        release_line = pipe.read()
    try:
        # Expected shape: "Cuda compilation tools, release 11.3, V11.3.109"
        return release_line.split(', ')[1].split('release ')[1]
    except IndexError:
        # nvcc is missing or printed something unexpected.
        return None


def _system_hardware() -> dict:
    """Return CPU details, plus GPU details when NVML can be queried."""
    hardware = {'cpu_count': multiprocessing.cpu_count()}
    try:
        nv.nvmlInit()
    except nv.NVMLError:
        # No usable NVIDIA driver/library: report the CPU count only.
        return hardware

    try:
        gpu_count = nv.nvmlDeviceGetCount()
        gpu_names = [
            _as_text(nv.nvmlDeviceGetName(nv.nvmlDeviceGetHandleByIndex(i)))
            for i in range(gpu_count)
        ]
        driver_version = _as_text(nv.nvmlSystemGetDriverVersion())
    except nv.NVMLError:
        return hardware
    finally:
        # Release the NVML handle acquired by nvmlInit(). Cleanup is
        # best-effort: a failing shutdown must not discard the metadata.
        try:
            nv.nvmlShutdown()
        except nv.NVMLError:
            pass

    hardware.update({
        'gpu_count': gpu_count,
        'gpu_type': ', '.join(gpu_names),
        'nvidia_driver_version': driver_version,
    })
    return hardware


def update_run_cfg(run_path: str, dataset_dir: str, dataset_name: str) -> None:
    """Adds additional meta data to a finished wandb run.

    This function uploads the ``classes.txt`` and ``hist.jpg`` files found in
    ``dataset_dir`` to a finished wandb run. It also updates the run's config
    with the dataset name, base ML framework versions, and system hardware,
    and prints the added entries.

    Args:
        run_path: Path to the W&B run (i.e., `<entity>/<project>/<run_id>`).
        dataset_dir: Directory that contains ``classes.txt`` and
            ``hist.jpg``.
        dataset_name: Name (or path) of the dataset file; only its final
            path component is stored in the run config.

    Returns:
        None
    """
    api = wandb.Api()
    run = api.run(run_path)

    # Upload from inside the dataset directory so the files are addressed by
    # their bare names, and always restore the caller's working directory,
    # even when an upload fails.
    previous_cwd = os.getcwd()
    os.chdir(dataset_dir)
    try:
        run.upload_file('classes.txt')
        run.upload_file('hist.jpg')
    finally:
        os.chdir(previous_cwd)

    # Prefer the toolkit version reported by nvcc; when nvcc is unavailable,
    # fall back to the CUDA version torch was built with (``None`` on
    # CPU-only builds) instead of crashing.
    cuda_version = _nvcc_cuda_version() or torch.version.cuda

    version = {
        'Python': platform.python_version(),
        'CUDA': cuda_version,
        'Torch': torch.__version__,
        'Torchvision': torchvision.__version__,
    }

    dict_to_add = {
        'dataset_name': Path(dataset_name).name,
        'base_ml_framework': version,
        'system_hardware': _system_hardware(),
    }
    run.config.update(dict_to_add)
    print(dict_to_add)
    run.update()
def _opts():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` holding ``run_path``, ``dataset_dir`` and
        ``dataset_name``.
    """
    # One (flags, keyword-options) pair per supported command-line argument.
    cli_arguments = (
        (('-p', '--run-path'), {
            'help':
            'Path to the W&B run (i.e., `<entity>/<project>/<run_id>`)',
            'type': str,
            'required': True,
        }),
        (('-d', '--dataset-dir'), {
            'help': 'Path to the dataset directory',
            'type': str,
            'default': 'dataset-YOLO',
        }),
        (('-D', '--dataset-name'), {
            'help': 'Name of the dataset TAR file',
            'type': str,
            'required': True,
        }),
    )

    parser = argparse.ArgumentParser()
    for flags, options in cli_arguments:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


if __name__ == '__main__':
    cli = _opts()
    # The namespace attributes (run_path, dataset_dir, dataset_name) map
    # one-to-one onto the keyword arguments of update_run_cfg.
    update_run_cfg(**vars(cli))