Source code for birdfsd_yolov5.model_utils.wandb_helpers

#!/usr/bin/env python
# coding: utf-8

import argparse
from pathlib import Path
from typing import Optional

import wandb


[docs]def upload_artifact(run_path: str, artifact_type: str, file_path: str, artifact_name: Optional[str] = None, aliases: Optional[list] = None) -> None: """Upload an artifact to an existing wandb run. Args: run_path (str): The wandb run path. artifact_type (str): The artifact type (e.g., dataset, model, etc.). file_path (str): The path to the file on the local disk. artifact_name (str): A human-readable name for the artifact. If None, the file's base name is used. """ if not artifact_name: artifact_name = Path(file_path).name entity, project, _id = run_path.split('/') with wandb.init(entity=entity, project=project, id=_id, resume='allow') as run: artifact = wandb.Artifact(artifact_name, artifact_type) artifact.add_file(file_path) if aliases: run.log_artifact(artifact, aliases=aliases) else: run.log_artifact(artifact)
[docs]def add_f1_score(run_path: str) -> None: """Adds the F1 score to the run. Args: run_path (str): The wandb run path. """ api = wandb.Api() run = api.run(run_path) p = run.summary['best/precision'] r = run.summary['best/recall'] f1_score = 2 * ((p * r) / (p + r)) run.summary.update({'best/f1-score': f1_score}) run.update()
def _get_all_runs_path(entity: str, project: str) -> list: """Returns a list of all runs for the given entity and project. Args: entity (str): The wandb entity name. project (str): The wandb project name in the selected entity. Returns: list: A list of all runs for the given project. """ api = wandb.Api() runs = api.runs(f'{entity}/{project}') return ['/'.join(x.path) for x in runs] def _opts() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( '--add-f1-score-by-run-path', help='The wandb run path to update with the calculated F1-score', type=str) return parser.parse_args() if __name__ == '__main__': args = _opts() if args.add_f1_score_by_run_path: add_f1_score(args.add_f1_score_by_run_path)