# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import gzip
import hashlib
import os
import shutil
import tarfile
import urllib.error
import urllib.request
import zipfile

from torch.utils.model_zoo import tqdm


# Archive suffix -> ``tarfile.open`` mode. Compressed tarballs need the
# ``"r:<codec>"`` form (``"r.gz"`` is rejected by tarfile); plain ``"r"``
# lets tarfile sniff any compression on a bare ``.tar``. Longer suffixes
# come first so ``.tar.gz`` is never mistaken for something else.
_TAR_MODES = (
    ('.tar.gz', 'r:gz'),
    ('.tgz', 'r:gz'),
    ('.tar.xz', 'r:xz'),
    ('.tar.bz2', 'r:bz2'),
    ('.tbz2', 'r:bz2'),
    ('.tar', 'r'),
)


def generate_update_progress_barr(pbar=None):
    """Return a ``urlretrieve`` reporthook that drives a tqdm progress bar.

    Args:
        pbar: Optional progress-bar object (tqdm-like: ``total``, ``n``,
            ``update``) to drive. When omitted, a new bar with an unknown
            total is created; since the caller never sees it, it cannot be
            closed, so prefer passing one in and closing it yourself.

    Returns:
        A callable ``(count, block_size, total_size) -> None`` suitable as
        the ``reporthook`` argument of :func:`urllib.request.urlretrieve`.
    """
    if pbar is None:
        pbar = tqdm(total=None)

    def update_progress_bar(count, block_size, total_size):
        """Advance the bar to the number of bytes received so far."""
        # urlretrieve reports total_size=-1 when the server sends no
        # Content-Length, so only adopt strictly positive totals.
        known_total = total_size if total_size and total_size > 0 else None
        if pbar.total is None and known_total:
            pbar.total = known_total
        progress_bytes = count * block_size
        # The final block is usually partial; don't run past the known size.
        if known_total:
            progress_bytes = min(progress_bytes, known_total)
        pbar.update(progress_bytes - pbar.n)

    return update_progress_bar


def compute_md5_checksum(fpath, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of the file at ``fpath``.

    The file is read in ``chunk_size``-byte pieces so large files are not
    loaded into memory at once.
    """
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def verify_md5_checksum(fpath, md5):
    """Return True if the MD5 digest of ``fpath`` equals ``md5``."""
    return md5 == compute_md5_checksum(fpath)


def validate_integrity(fpath, md5=None):
    """Return True if ``fpath`` is an existing file matching ``md5``.

    When ``md5`` is None, only existence is checked.
    """
    if not os.path.isfile(fpath):
        return False
    return md5 is None or verify_md5_checksum(fpath, md5)


def _urlretrieve_with_progress(url, fpath):
    """Download ``url`` to ``fpath`` with a progress bar that is always closed."""
    pbar = tqdm(total=None)
    try:
        urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr(pbar))
    finally:
        pbar.close()


def download_url(url, root, filename=None, md5=None):
    """Download ``url`` into directory ``root`` and verify it.

    Args:
        url: URL to download.
        root: Target directory (``~`` is expanded; created if missing).
        filename: Name to save as; defaults to the URL's basename.
        md5: Optional expected MD5 hex digest.

    Raises:
        urllib.error.URLError / OSError: if the download fails (for https
            URLs, after one retry over plain http).
        RuntimeError: if the file is missing or its MD5 does not match
            after downloading.
    """
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        _urlretrieve_with_progress(url, fpath)
    except (urllib.error.URLError, OSError):
        if not url.startswith('https:'):
            raise
        # Swap only the scheme; str.replace would also rewrite any
        # "https:" appearing later in the URL (e.g. in a query string).
        url = 'http:' + url[len('https:'):]
        print('Failed download. Retrying with http.')
        _urlretrieve_with_progress(url, fpath)

    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root, prefix=False):
    """List the sub-directories of ``root``.

    Returns bare names, or full paths joined with ``root`` when ``prefix``
    is True.
    """
    root = os.path.expanduser(root)
    directories = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
    return [os.path.join(root, d) for d in directories] if prefix else directories


def list_files(root, suffix, prefix=False):
    """List regular files in ``root`` whose names end with ``suffix``.

    ``suffix`` may be a string or a tuple of strings (as accepted by
    ``str.endswith``). Returns bare names, or full paths when ``prefix``
    is True.
    """
    root = os.path.expanduser(root)
    files = [f for f in os.listdir(root)
             if os.path.isfile(os.path.join(root, f)) and f.endswith(suffix)]
    return [os.path.join(root, f) for f in files] if prefix else files


def fetch_file_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file by id into ``root`` and verify it.

    Args:
        file_id: Google Drive file id.
        root: Target directory (``~`` is expanded; created if missing).
        filename: Name to save as; defaults to ``file_id``.
        md5: Optional expected MD5 hex digest.

    Raises:
        requests.HTTPError: if Google Drive answers with an HTTP error.
        RuntimeError: if the file is missing or its MD5 does not match
            after downloading.
    """
    url = "https://docs.google.com/uc?export=download"
    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # Imported lazily: only this function needs requests, and a cached,
    # verified file can be reused without it installed.
    import requests

    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)
        if token:
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)
        # Don't persist an HTML error page as if it were the payload.
        response.raise_for_status()
        _store_response_content(response, fpath)

    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def _get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, or None."""
    return next((value for key, value in response.cookies.items()
                 if key.startswith('download_warning')), None)


def _store_response_content(response, destination, chunk_size=32768):
    """Stream ``response`` into ``destination`` in ``chunk_size`` pieces."""
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        try:
            progress = 0
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    progress += len(chunk)
                    pbar.update(progress - pbar.n)
        finally:
            pbar.close()


def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract a tar (optionally gz/xz/bz2), gzip, or zip archive.

    Args:
        from_path: Archive path; its suffix selects the format.
        to_path: Destination directory; defaults to the archive's directory.
            For a bare ``.gz`` file the output is written there under the
            archive name minus ``.gz``.
        remove_finished: Delete the archive after successful extraction.

    Raises:
        ValueError: if the suffix is not a supported archive type.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    tar_mode = next((mode for suffix, mode in _TAR_MODES if from_path.endswith(suffix)), None)
    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            # NOTE(review): extractall on an untrusted archive can write
            # outside to_path via "../" or absolute member names; consider
            # extractall(..., filter='data') where the Python version allows.
            tar.extractall(path=to_path)
    elif from_path.endswith(".gz"):
        out_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        # Open the input first so a missing archive leaves no empty output
        # behind, and stream instead of reading everything into memory.
        with gzip.open(from_path, 'rb') as zip_f, open(out_path, "wb") as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)


def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None,
                              remove_finished=False):
    """Download an archive with :func:`download_url`, then extract it.

    ``extract_root`` defaults to ``download_root``; ``filename`` defaults to
    the URL's basename. See :func:`download_url` and :func:`extract_archive`
    for the exceptions raised.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = extract_root or download_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable):
    """Render ``iterable`` as comma-separated, single-quoted items."""
    return "'" + "', '".join(map(str, iterable)) + "'"


def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Validate that ``value`` is a str and, optionally, one of ``valid_values``.

    Args:
        value: The value to check.
        arg: Argument name, used in error messages.
        valid_values: Optional collection of allowed values.
        custom_msg: Message used instead of the default when ``value`` is
            not in ``valid_values``.

    Returns:
        ``value`` unchanged.

    Raises:
        ValueError: if ``value`` is not a str or not an allowed value.
    """
    if not isinstance(value, str):
        msg = f"Expected type str" + (f" for argument {arg}" if arg else "") + f", but got type {type(value)}."
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        msg = custom_msg or f"Unknown value '{value}' for argument {arg}. Valid values are {{{iterable_to_str(valid_values)}}}."
        raise ValueError(msg)

    return value