#!/usr/bin/env python
# coding: utf-8
import argparse
import collections
import imghdr
import json
import os
import random
import shutil
import tarfile
from datetime import datetime, timedelta
from glob import glob
from pathlib import Path
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
import requests
import seaborn as sns
from dotenv import load_dotenv
from loguru import logger
from requests.structures import CaseInsensitiveDict
from tqdm import tqdm
from birdfsd_yolov5.model_utils import (handlers, mongodb_helper, s3_helper,
utils)
[docs]class FailedToParseImageURL(Exception):
pass
[docs]class JSON2YOLO:
"""Converts the output of a Label-studio project to a YOLO dataset.
The output is a folder with the following structure:
dataset-YOLO
├── classes.txt
├── dataset_config.yml
├── images
│ ├── train
│ └── val
└── labels
....├── train
....└── val
The output will also be stored in a tarball with the same name as the
output folder.
The tasks that failed to export for any reason, will be logged at the
end of the run.
"""
def __init__(self,
projects: str,
output_dir: str = 'dataset-YOLO',
only_tar_file: bool = False,
enable_s3: bool = True,
copy_data_from: str = None,
filter_underrepresented_cls: bool = False,
filter_cls_with_instances_under: Optional[int] = None,
get_tasks_with_api: bool = False,
force_update: bool = False):
self.projects = projects
self.output_dir = str(Path(output_dir).absolute())
self.only_tar_file = only_tar_file
self.enable_s3 = enable_s3
self.copy_data_from = copy_data_from
self.filter_underrepresented_cls = filter_underrepresented_cls
self.filter_cls_with_instances_under = filter_cls_with_instances_under
self.get_tasks_with_api = get_tasks_with_api
self.force_update = force_update
self.imgs_dir = f'{self.output_dir}/ls_images'
self.labels_dir = f'{self.output_dir}/ls_labels'
self.classes = None
self.tasks_not_exported = []
[docs] @staticmethod
def bbox_ls_to_yolo(x: float, y: float, width: float,
height: float) -> tuple:
"""From label-studio's xywh to yolov5's xywh.
Converts a bounding box from the format used by the labelme tool to
the format used by the yolo tool.
Args:
x: The x coordinate of the top left corner of the bounding box.
y: The y coordinate of the top left corner of the bounding box.
width: The width of the bounding box.
height: The height of the bounding box.
Returns:
tuple: A tuple containing the x, y, width and height of the
bounding box in the format used by the yolov5.
"""
x = (x + width / 2) / 100
y = (y + height / 2) / 100
w = width / 100
h = height / 100
return x, y, w, h
[docs] def get_data(self, excluded_labels) -> list:
"""This function is used to get data from the database.
Returns:
list: A list of data.
"""
@ray.remote
def iter_projects(proj_id):
return mongodb_helper.get_tasks_from_mongodb(proj_id,
dump=False,
json_min=True)
@ray.remote
def iter_projects_api(proj_id):
headers = CaseInsensitiveDict()
headers['Content-type'] = 'application/json'
headers['Authorization'] = f'Token {os.environ["TOKEN"]}'
ls_host = os.environ["LS_HOST"]
q = 'exportType=JSON&download_all_tasks=true'
proj_tasks = []
url = f'{ls_host}/api/projects/{proj_id}/export?{q}'
resp = requests.get(url, headers=headers)
proj_tasks.append(resp.json())
return proj_tasks
if self.projects:
project_ids = self.projects.split(',')
else:
project_ids = utils.get_project_ids_str().split(',')
futures = []
for project_id in project_ids:
if self.get_tasks_with_api:
futures.append(iter_projects_api.remote(project_id))
else:
futures.append(iter_projects.remote(project_id))
data = []
for future in tqdm(futures, desc='Projects'):
data.append(ray.get(future))
data = sum(data, [])
labels = []
for entry in data:
try:
labels.append([
label['rectanglelabels'][0] for label in entry['label']
][0])
except KeyError as e:
logger.warning(f'Current entry raised KeyError {e}! '
f'Ignoring entry: {entry}')
unique, counts = np.unique(labels, return_counts=True)
if (self.filter_underrepresented_cls
or self.filter_cls_with_instances_under):
if self.filter_underrepresented_cls:
min_instances = np.median(counts)
else:
min_instances = self.filter_cls_with_instances_under
self.classes = sorted([
label for label, count in zip(unique, counts)
if label not in excluded_labels and count >= min_instances
])
else:
self.classes = sorted(
[label for label in unique if label not in excluded_labels])
logger.debug(f'Number of classes: {len(self.classes)}')
logger.debug(f'Classes: {self.classes}')
Path(self.output_dir).mkdir(exist_ok=True)
with open(f'{self.output_dir}/classes.txt', 'w') as f:
for class_ in self.classes:
f.write(f'{class_}\n')
return data
[docs] def convert_to_yolo(self, task: dict) -> Optional[list]:
"""Convert the task to YOLO format.
Args:
task (dict): The task to be converted.
Returns:
list: The labels in the task.
Raises:
FailedToParseImageURL: If the image URL is not valid.
TypeError: If the image URL is not valid.
"""
if self.copy_data_from or self.enable_s3:
img = task['image']
if img.startswith('s3://'):
object_name = img.split('s3://data/')[-1]
elif img.startswith('http'):
object_url = img.split('?')[0]
object_name = '/'.join(Path(object_url).parts[-2:])
else:
raise FailedToParseImageURL(img)
cur_img_name = Path(object_name).name
else:
if 's3://' in task['image'] and not self.enable_s3:
raise TypeError('You need to pass the flag `--enable-s3` '
'for S3 objects!')
object_name = None
cur_img_name = Path(task['image']).name
cur_img_path = f'{self.imgs_dir}/{cur_img_name}'
cur_label_path = f'{self.labels_dir}/{Path(cur_img_name).stem}.txt'
if self.copy_data_from:
shutil.copy(f'{self.copy_data_from}/{object_name}', cur_img_path)
else:
if self.enable_s3:
img_url = s3_helper.S3().client.presigned_get_object(
'data', object_name, expires=timedelta(hours=6))
else:
img_url = task['image']
r = requests.get(img_url)
if '<Error>' in r.text:
logger.error(
f'Could not download the image `{img_url}`! Skipping...')
return
with open(cur_img_path, 'wb') as f:
f.write(r.content)
try:
valid_image = imghdr.what(cur_img_path)
if not valid_image:
logger.error(f'{cur_img_path} is not valid (task'
f' id: {task["id"]})! Skipping...')
Path(cur_img_path).unlink()
return
except FileNotFoundError:
logger.error(
f'Could not validate {cur_img_path} from {task["id"]}! '
'Skipping...')
return
label_names = []
with open(cur_label_path, 'w') as f:
try:
labels = task['label']
except KeyError:
self.tasks_not_exported.append(task)
logger.error(f'>>>>>>>>>> CORRUPTED TASK: {task}')
f.close()
Path(cur_label_path).unlink()
try:
Path(cur_img_path).unlink()
except FileNotFoundError:
pass
return
for label in labels:
if label['rectanglelabels'][0] not in self.classes:
f.close()
Path(cur_label_path).unlink()
Path(cur_img_path).unlink()
return
label_names.append(label['rectanglelabels'][0])
x, y, width, height = [
v for k, v in label.items()
if k in ['x', 'y', 'width', 'height']
]
x, y, width, height = self.bbox_ls_to_yolo(x, y, width, height)
categories = list(enumerate(self.classes)) # noqa
label_idx = [
k[0] for k in categories
if k[1] == label['rectanglelabels'][0]
][0]
f.write(f'{label_idx} {x} {y} {width} {height}')
f.write('\n')
return label_names
[docs] def split_data(self) -> None:
"""Split the data into train and validation sets.
Returns:
None
"""
for subdir in [
'images/train', 'labels/train', 'images/val', 'labels/val'
]:
Path(f'{self.output_dir}/{subdir}').mkdir(parents=True,
exist_ok=True)
images = sorted(glob(f'{self.output_dir}/ls_images/*'))
labels = sorted(glob(f'{self.output_dir}/ls_labels/*'))
pairs = [(x, y) for x, y in zip(images, labels)]
len_ = len(images)
train_len = round(len_ * 0.8)
random.shuffle(pairs)
train, val = pairs[:train_len], pairs[train_len:]
for split, split_str in zip([train, val], ['train', 'val']):
for n, dtype in zip([0, 1], ['images', 'labels']):
base_subdir = f'{self.output_dir}/{dtype}/{split_str}'
_ = [
shutil.copy2(x[n], f'{base_subdir}/{Path(x[n]).name}')
for x in split
]
return
[docs] def plot_results(self, results: list) -> None:
"""Plots the results of the classification.
Args:
results (list): The results of the classification.
Returns:
None
"""
matplotlib.use('Agg')
plt.subplots(figsize=(12, 8), dpi=300)
plt.xticks(rotation=90)
sns.histplot(data=results, kde=True)
plt.title('Distribution Of Classes With A Kernel Density Estimate')
plt.savefig(f'{self.output_dir}/bar.jpg', bbox_inches='tight')
plt.cla()
data_count = dict(collections.Counter(results))
df = pd.DataFrame(data_count.items(), columns=['label', 'count'])
df = df.groupby('label').sum().reset_index()
ax = sns.barplot(data=df,
x='label',
y='count',
hue='label',
dodge=False)
plt.xticks(rotation=90)
plt.title('Instances Per Class')
ax.get_legend().remove()
plt.savefig(f'{self.output_dir}/hist.jpg', bbox_inches='tight')
return
[docs] def run(self) -> None:
"""Runs the preprocessing pipeline.
This method is used to run main preprocessing pipeline and convert
the data to the yolov5 format.
Returns:
None
Raises:
BucketDoesNotExist: If the dataset S3 bucket does not exist.
"""
@ray.remote
def iter_convert_to_yolo(t):
return self.convert_to_yolo(t)
s3_client = s3_helper.S3().client
handlers.catch_keyboard_interrupt()
random.seed(8)
excluded_labels = os.getenv('EXCLUDE_LABELS')
if excluded_labels:
excluded_labels = excluded_labels.split(',')
else:
excluded_labels = []
tasks = self.get_data(excluded_labels)
Path(self.imgs_dir).mkdir(parents=True, exist_ok=True)
Path(self.labels_dir).mkdir(parents=True, exist_ok=True)
futures = []
for task in tasks:
futures.append(iter_convert_to_yolo.remote(task))
results = []
for future in tqdm(futures, desc='Tasks'):
results.append(ray.get(future))
if results:
results = [x for x in results if x]
results = sum(results, [])
self.plot_results(results)
if self.tasks_not_exported:
logger.error(f'Corrupted tasks: {self.tasks_not_exported}')
assert len(glob(f'{self.output_dir}/images/*')) == len(
glob(f'{self.output_dir}/labels/*'))
self.split_data()
shutil.rmtree(f'{self.output_dir}/ls_images', ignore_errors=True)
shutil.rmtree(f'{self.output_dir}/ls_labels', ignore_errors=True)
d = {
'path': f'{self.output_dir}',
'train': 'images/train',
'val': 'images/val',
'test': '',
'nc': len(self.classes),
'names': self.classes
}
with open(f'{self.output_dir}/dataset_config.yml', 'w') as f:
for k, v in d.items():
f.write(f'{k}: {v}\n')
utils._tasks_data(f'{self.output_dir}/tasks.json')
with open(f'{self.output_dir}/classes.json', 'w') as f:
classes_json = {
k: v
for k, v in utils.get_labels_count().items()
if k not in excluded_labels
}
if self.filter_cls_with_instances_under:
classes_json = {
k: v
for k, v in classes_json.items()
if v > self.filter_cls_with_instances_under
}
json.dump(classes_json, f, indent=4)
folder_name = Path(self.output_dir).name
ts = datetime.now().strftime('%m-%d-%Y_%H.%M.%S')
dataset_name = f'{folder_name}-{ts}.tar'
with tarfile.open(dataset_name, 'w') as tar:
tar.add(self.output_dir, folder_name)
if self.only_tar_file:
shutil.rmtree(self.output_dir, ignore_errors=True)
if self.enable_s3:
if not s3_client.bucket_exists('dataset'):
raise s3_helper.BucketDoesNotExist(
'Bucket `dataset` does not exist!')
upload_dataset = False
objs = list(s3_client.list_objects('dataset'))
if objs:
latest_ts = max(
[o.last_modified for o in objs if o.last_modified])
latest_obj = [o for o in objs
if o.last_modified == latest_ts][0]
if latest_obj.size != Path(
dataset_name).stat().st_size or self.force_update:
upload_dataset = True
else:
upload_dataset = True
if upload_dataset:
if self.copy_data_from:
logger.debug('Copying the dataset to the bucket...')
ds_path = f'{Path(self.copy_data_from).parent}/dataset'
shutil.copy(dataset_name, f'{ds_path}/{dataset_name}')
else:
logger.info('Uploading the dataset...')
s3_client.fput_object('dataset', dataset_name,
dataset_name)
return
if __name__ == '__main__':
load_dotenv()
parser = argparse.ArgumentParser()
parser.add_argument(
'-p',
'--projects',
help='Comma-seperated projects ID. If empty, it will select all '
'projects',
type=str)
parser.add_argument('-o',
'--output-dir',
help='Path to the output directory',
type=str,
default='dataset-YOLO')
parser.add_argument('--only-tar-file',
help='Only output a TAR file',
action="store_true")
parser.add_argument('--enable-s3',
help='Upload the output to an S3 bucket',
action="store_true")
parser.add_argument(
'--copy-data-from',
help='If running on the same host serving the S3 objects, you can '
'use this option to specify a path to copy the data from '
'(i.e., the local path to the S3 bucket where the data is '
'stored) instead of downloading it',
type=str)
parser.add_argument('--filter-underrepresented-cls',
help='Only include classes with instances equal or '
'above the overall median',
action="store_true")
parser.add_argument(
'--filter-cls-with-instances-under',
help='Remove the class from the dataset if the annotation instances '
'is lower than n',
type=int)
parser.add_argument('--get-tasks-with-api',
help='Use label-studio API to get tasks data',
action="store_true")
parser.add_argument(
'--force-update',
help='Update the dataset even when it appears to be identical to the '
'latest dataset',
action="store_true")
args = parser.parse_args()
json2yolo = JSON2YOLO(
projects=args.projects,
output_dir=args.output_dir,
only_tar_file=args.only_tar_file,
enable_s3=args.enable_s3,
copy_data_from=args.copy_data_from,
filter_underrepresented_cls=args.filter_underrepresented_cls,
filter_cls_with_instances_under=args.filter_cls_with_instances_under,
get_tasks_with_api=args.get_tasks_with_api,
force_update=args.force_update)
json2yolo.run()