diff --git a/downloader.py b/downloader.py
index 8c0fbe05..9fa5020c 100755
--- a/downloader.py
+++ b/downloader.py
@@ -12,16 +12,18 @@
 from requests.exceptions import ConnectionError, ReadTimeout, TooManyRedirects, MissingSchema, InvalidURL
 
+from functools import partial
+
 parser = argparse.ArgumentParser(description='ImageNet image scraper')
 
 parser.add_argument('-scrape_only_flickr', default=True, type=lambda x: (str(x).lower() == 'true'))
-parser.add_argument('-number_of_classes', default = 10, type=int)
-parser.add_argument('-images_per_class', default = 10, type=int)
-parser.add_argument('-data_root', default='' , type=str)
-parser.add_argument('-use_class_list', default=False,type=lambda x: (str(x).lower() == 'true'))
+parser.add_argument('-number_of_classes', default=10, type=int)
+parser.add_argument('-images_per_class', default=10, type=int)
+parser.add_argument('-data_root', default='', type=str)
+parser.add_argument('-use_class_list', default=False, type=lambda x: (str(x).lower() == 'true'))
 parser.add_argument('-class_list', default=[], nargs='*')
-parser.add_argument('-debug', default=False,type=lambda x: (str(x).lower() == 'true'))
+parser.add_argument('-debug', default=False, type=lambda x: (str(x).lower() == 'true'))
 
-parser.add_argument('-multiprocessing_workers', default = 8, type=int)
+parser.add_argument('-multiprocessing_workers', default=8, type=int)
 
 args, args_other = parser.parse_known_args()
 
@@ -36,7 +38,6 @@
     logging.error(f'folder {args.data_root} does not exist! please provide existing folder in -data_root arg!')
     exit()
 
-
 IMAGENET_API_WNID_TO_URLS = lambda wnid: f'http://www.image-net.org/api/imagenet.synset.geturls?wnid={wnid}'
 
 current_folder = os.path.dirname(os.path.realpath(__file__))
@@ -52,11 +53,11 @@
 classes_to_scrape = []
 
 if args.use_class_list == True:
-  for item in args.class_list:
-      classes_to_scrape.append(item)
-      if item not in class_info_dict:
-          logging.error(f'Class {item} not found in ImageNete')
-          exit()
+    for item in args.class_list:
+        classes_to_scrape.append(item)
+        if item not in class_info_dict:
+            logging.error(f'Class {item} not found in ImageNet')
+            exit()
 
 elif args.use_class_list == False:
     potential_class_pool = []
@@ -70,24 +71,23 @@
             potential_class_pool.append(key)
 
     if (len(potential_class_pool) < args.number_of_classes):
-        logging.error(f"With {args.images_per_class} images per class there are {len(potential_class_pool)} to choose from.")
+        logging.error(
+            f"With {args.images_per_class} images per class there are {len(potential_class_pool)} to choose from.")
         logging.error(f"Decrease number of classes or decrease images per class.")
         exit()
 
-    picked_classes_idxes = np.random.choice(len(potential_class_pool), args.number_of_classes, replace = False)
+    picked_classes_idxes = np.random.choice(len(potential_class_pool), args.number_of_classes, replace=False)
 
     for idx in picked_classes_idxes:
         classes_to_scrape.append(potential_class_pool[idx])
 
-
 print("Picked the following clases:")
-print([ class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape ])
+print([class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape])
 
 imagenet_images_folder = os.path.join(args.data_root, 'imagenet_images')
 if not os.path.isdir(imagenet_images_folder):
     os.mkdir(imagenet_images_folder)
 
-
 scraping_stats = dict(
     all=dict(
         tried=0,
@@ -106,26 +106,27 @@
     )
 )
 
+
 def add_debug_csv_row(row):
     with open('stats.csv', "a") as csv_f:
         csv_writer = csv.writer(csv_f, delimiter=",")
         csv_writer.writerow(row)
 
+
 class MultiStats():
     def __init__(self):
-
         self.lock = Lock()
 
         self.stats = dict(
             all=dict(
                 tried=Value('d', 0),
-                success=Value('d',0),
-                time_spent=Value('d',0),
+                success=Value('d', 0),
+                time_spent=Value('d', 0),
             ),
             is_flickr=dict(
                 tried=Value('d', 0),
-                success=Value('d',0),
-                time_spent=Value('d',0),
+                success=Value('d', 0),
+                time_spent=Value('d', 0),
             ),
             not_flickr=dict(
                 tried=Value('d', 0),
@@ -133,6 +134,7 @@ def __init__(self):
                 time_spent=Value('d', 0),
             )
         )
+
     def inc(self, cls, stat, val):
         with self.lock:
             self.stats[cls][stat].value += val
@@ -142,8 +144,8 @@ def get(self, cls, stat):
             ret = self.stats[cls][stat].value
         return ret
 
-multi_stats = MultiStats()
 
+multi_stats = MultiStats()
 if args.debug:
     row = [
@@ -159,6 +161,7 @@ def get(self, cls, stat):
     ]
     add_debug_csv_row(row)
 
+
 def add_stats_to_debug_csv():
     row = [
         multi_stats.get('all', 'tried'),
@@ -173,8 +176,8 @@ def add_stats_to_debug_csv():
     ]
     add_debug_csv_row(row)
 
-def print_stats(cls, print_func):
 
+def print_stats(cls, print_func):
     actual_all_time_spent = time.time() - scraping_t_start.value
     processes_all_time_spent = multi_stats.get('all', 'time_spent')
 
@@ -183,17 +186,18 @@ def print_stats(cls, print_func):
     else:
         actual_processes_ratio = actual_all_time_spent / processes_all_time_spent
 
-    #print(f"actual all time: {actual_all_time_spent} proc all time {processes_all_time_spent}")
+    # print(f"actual all time: {actual_all_time_spent} proc all time {processes_all_time_spent}")
 
     print_func(f'STATS For class {cls}:')
     print_func(f' tried {multi_stats.get(cls, "tried")} urls with'
                f' {multi_stats.get(cls, "success")} successes')
 
     if multi_stats.get(cls, "tried") > 0:
-        print_func(f'{100.0 * multi_stats.get(cls, "success")/multi_stats.get(cls, "tried")}% success rate for {cls} urls ')
+        print_func(
+            f'{100.0 * multi_stats.get(cls, "success") / multi_stats.get(cls, "tried")}% success rate for {cls} urls ')
 
     if multi_stats.get(cls, "success") > 0:
-        print_func(f'{multi_stats.get(cls,"time_spent") * actual_processes_ratio / multi_stats.get(cls,"success")} seconds spent per {cls} succesful image download')
-
+        print_func(
+            f'{multi_stats.get(cls, "time_spent") * actual_processes_ratio / multi_stats.get(cls, "success")} seconds spent per {cls} successful image download')
 
 lock = Lock()
@@ -202,16 +206,15 @@ def print_stats(cls, print_func):
 class_folder = ''
 class_images = Value('d', 0)
 
-def get_image(img_url):
-    #print(f'Processing {img_url}')
+def get_image(class_folder, img_url):
+    # print(f'Processing {img_url}')
 
-    #time.sleep(3)
+    # time.sleep(3)
 
     if len(img_url) <= 1:
         return
 
-
     cls_imgs = 0
     with lock:
         cls_imgs = class_images.value
@@ -237,11 +240,11 @@ def finish(status):
         multi_stats.inc(cls, 'time_spent', t_spent)
         multi_stats.inc('all', 'time_spent', t_spent)
 
-        multi_stats.inc(cls,'tried', 1)
+        multi_stats.inc(cls, 'tried', 1)
         multi_stats.inc('all', 'tried', 1)
 
         if status == 'success':
-            multi_stats.inc(cls,'success', 1)
+            multi_stats.inc(cls, 'success', 1)
             multi_stats.inc('all', 'success', 1)
 
         elif status == 'failure':
@@ -251,7 +254,6 @@ def finish(status):
            exit()
         return
 
-
    with lock:
        url_tries.value += 1
        if url_tries.value % 250 == 0:
@@ -263,7 +265,7 @@ def finish(status):
            add_stats_to_debug_csv()
 
    try:
-        img_resp = requests.get(img_url, timeout = 1)
+        img_resp = requests.get(img_url, timeout=1)
    except ConnectionError:
        logging.debug(f"Connection Error for url {img_url}")
        return finish('failure')
@@ -314,26 +316,28 @@ def finish(status):
 
     return finish('success')
 
-for class_wnid in classes_to_scrape:
+if __name__ == '__main__':
 
-    class_name = class_info_dict[class_wnid]["class_name"]
-    print(f'Scraping images for class \"{class_name}\"')
-    url_urls = IMAGENET_API_WNID_TO_URLS(class_wnid)
+    for class_wnid in classes_to_scrape:
 
-    time.sleep(0.05)
-    resp = requests.get(url_urls)
+        class_name = class_info_dict[class_wnid]["class_name"]
+        print(f'Scraping images for class \"{class_name}\"')
+        url_urls = IMAGENET_API_WNID_TO_URLS(class_wnid)
 
-    class_folder = os.path.join(imagenet_images_folder, class_name)
-    if not os.path.exists(class_folder):
-        os.mkdir(class_folder)
+        time.sleep(0.05)
+        resp = requests.get(url_urls)
 
-    class_images.value = 0
+        class_folder = os.path.join(imagenet_images_folder, class_name)
+        if not os.path.exists(class_folder):
+            os.mkdir(class_folder)
 
-    urls = [url.decode('utf-8') for url in resp.content.splitlines()]
+        class_images.value = 0
 
-    #for url in urls:
-    #    get_image(url)
+        urls = [url.decode('utf-8') for url in resp.content.splitlines()]
 
-    print(f"Multiprocessing workers: {args.multiprocessing_workers}")
-    with Pool(processes=args.multiprocessing_workers) as p:
-        p.map(get_image,urls)
+        # for url in urls:
+        #     get_image(url)
+        part = partial(get_image, class_folder)
+        print(f"Multiprocessing workers: {args.multiprocessing_workers}")
+        with Pool(processes=args.multiprocessing_workers) as p:
+            p.map(part, urls)
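
Note on the functools.partial change: Pool.map hands the worker callable exactly one
argument per item (the URL), so the per-class output folder can no longer travel
through the module-level class_folder global once the work is fanned out to
subprocesses. Binding it up front with partial(get_image, class_folder) yields a
one-argument callable, which is why get_image now takes class_folder as its first
parameter. A minimal standalone sketch of the pattern, with illustrative names
(fetch, out_dir) rather than the script's own:

    from functools import partial
    from multiprocessing import Pool

    def fetch(out_dir, url):
        # out_dir was pre-bound by partial(); Pool.map supplies url
        print(f'would download {url} into {out_dir}')

    if __name__ == '__main__':
        urls = ['http://example.com/a.jpg', 'http://example.com/b.jpg']
        with Pool(processes=2) as p:
            # partial(fetch, ...) pickles cleanly because fetch is a
            # top-level function; a lambda would not
            p.map(partial(fetch, 'images/cat'), urls)

partial is the right tool here rather than a lambda, since arguments to Pool.map must
be picklable for transfer to worker processes and lambdas are not.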
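The new if __name__ == '__main__': guard supports the same multiprocessing move:
under start methods that spawn a fresh interpreter (Windows, and macOS since Python
3.8), every worker re-imports the module, and without the guard each child would
re-enter the scraping loop and try to create its own pool. The MultiStats counters
use the multiprocessing Value-plus-Lock idiom; a reduced sketch with illustrative
names (counter, work), noting that inheriting the globals like this only shares them
under the 'fork' start method, which is what the script itself assumes:

    from multiprocessing import Lock, Pool, Value

    counter = Value('d', 0)  # a double in shared memory
    lock = Lock()

    def work(item):
        with lock:  # += is a read-modify-write, so take the lock
            counter.value += 1

    if __name__ == '__main__':
        with Pool(processes=4) as p:
            p.map(work, range(100))
        # prints 100 under fork; under spawn each child re-creates its
        # own counter and the parent would still see 0
        print(int(counter.value))

Under spawn, the shared objects would instead have to be handed to workers
explicitly, for example through Pool's initializer argument.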