# Revision 4ec0a23e73
# Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
# 174 lines, 6.6 KiB, Python
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import gzip
import hashlib
import os
import shutil
import tarfile
import urllib.error
import urllib.request
import zipfile

from torch.utils.model_zoo import tqdm


def generate_update_progress_barr():
    """Create a ``reporthook`` callback for ``urllib.request.urlretrieve``.

    The returned callable drives a ``tqdm`` progress bar and is invoked as
    ``hook(count, block_size, total_size)``.

    Returns:
        The callback ``update_progress_bar(count, block_size, total_size)``.
    """
    pbar = tqdm(total=None)

    def update_progress_bar(count, block_size, total_size):
        """Advance the bar to ``count * block_size`` bytes.

        Args:
            count: Number of blocks transferred so far.
            block_size: Size of one block in bytes.
            total_size: Total size in bytes; a non-positive value (``urlretrieve``
                uses ``-1``) means the size is unknown.
        """
        # Adopt only a real size: -1 is truthy and must not become the bar total.
        if pbar.total is None and total_size and total_size > 0:
            pbar.total = total_size

        progress_bytes = count * block_size
        if pbar.total is not None:
            # The final block is usually partial; clamp so the bar never overshoots.
            progress_bytes = min(progress_bytes, pbar.total)
        pbar.update(progress_bytes - pbar.n)

        # Finish the bar once the known total is reached so its line is flushed.
        if pbar.total is not None and pbar.n >= pbar.total:
            pbar.close()

    return update_progress_bar


def compute_md5_checksum(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is opened in binary mode and fed to the hash ``chunk_size`` bytes
    at a time, so arbitrarily large files are hashed in bounded memory.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()


def verify_md5_checksum(fpath, md5):
    """Return ``True`` when the MD5 hex digest of ``fpath`` equals ``md5``."""
    actual_digest = compute_md5_checksum(fpath)
    return md5 == actual_digest


def validate_integrity(fpath, md5=None):
    """Report whether ``fpath`` is an existing regular file matching ``md5``.

    A path that does not exist (or is not a regular file) is never valid. When
    ``md5`` is ``None`` existence alone suffices; otherwise the file's MD5 hex
    digest must equal ``md5``.
    """
    if os.path.isfile(fpath):
        if md5 is None:
            return True
        return verify_md5_checksum(fpath, md5)
    return False


def download_url(url, root, filename=None, md5=None):
    """Download ``url`` into ``root`` unless a verified copy already exists.

    Args:
        url: Address to fetch.
        root: Target directory; ``~`` is expanded and the directory is created
            if missing.
        filename: Name to save under; defaults to the last path component of ``url``.
        md5: Optional expected MD5 hex digest used to validate the file.

    Raises:
        urllib.error.URLError / OSError: If the download fails (an ``https`` URL
            is retried once over ``http`` first).
        RuntimeError: If the resulting file is missing or fails the MD5 check.
    """
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # Write to a sibling temp path and move it into place only after the
    # transfer succeeds. Otherwise an interrupted download would leave a
    # truncated file at ``fpath`` that passes validation when ``md5`` is None.
    tmp_path = fpath + '.part'
    try:
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url, tmp_path, reporthook=generate_update_progress_barr())
        except (urllib.error.URLError, IOError):
            if not url.startswith('https'):
                raise
            # NOTE(review): the plain-http fallback drops transport security;
            # the md5 check below is the only integrity guard in that case.
            # Replace only the scheme, not any 'https:' embedded later in the URL.
            url = url.replace('https:', 'http:', 1)
            print('Failed download. Retrying with http.')
            urllib.request.urlretrieve(url, tmp_path, reporthook=generate_update_progress_barr())
        os.replace(tmp_path, fpath)
    finally:
        # Only present if the download failed before the move; discard it.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root, prefix=False):
    """Return the sub-directories found directly under ``root``.

    Args:
        root: Directory to scan; ``~`` is expanded.
        prefix: When true, return paths joined with ``root``; otherwise bare names.
    """
    root = os.path.expanduser(root)
    found = []
    for entry in os.listdir(root):
        full_path = os.path.join(root, entry)
        if os.path.isdir(full_path):
            found.append(full_path if prefix else entry)
    return found


def list_files(root, suffix, prefix=False):
    """Return the regular files directly under ``root`` whose names end with ``suffix``.

    Args:
        root: Directory to scan; ``~`` is expanded.
        suffix: Suffix (or tuple of suffixes) accepted by ``str.endswith``.
        prefix: When true, return paths joined with ``root``; otherwise bare names.
    """
    root = os.path.expanduser(root)
    matches = []
    for entry in os.listdir(root):
        full_path = os.path.join(root, entry)
        if os.path.isfile(full_path) and entry.endswith(suffix):
            matches.append(full_path if prefix else entry)
    return matches


def fetch_file_google_drive(file_id, root, filename=None, md5=None):
    """Download a file from Google Drive into ``root``.

    Args:
        file_id: Google Drive file id, sent as the ``id`` query parameter.
        root: Target directory; ``~`` is expanded and the directory is created
            if missing.
        filename: Name to save under; defaults to ``file_id``.
        md5: Optional expected MD5 hex digest used to validate the file.

    Raises:
        RuntimeError: If the saved file fails validation against ``md5``.
    """
    # Imported lazily: only this function needs the third-party ``requests``
    # package, and the module never imported it at top level.
    import requests

    url = "https://docs.google.com/uc?export=download"
    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # validate_integrity already returns False when fpath is not a file.
    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # Use the session as a context manager so its connections are released.
    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # Re-request, passing the ``download_warning*`` cookie value as
            # ``confirm`` (presumably Drive's download confirmation step).
            params = {'id': file_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        _store_response_content(response, fpath)

    # Reject a file that does not match the expected checksum, if one was given.
    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def _get_confirm_token(response):
|
|
"""Extract the download token from Google Drive cookies."""
|
|
return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None)
|
|
|
|
def _store_response_content(response, destination, chunk_size=32768):
|
|
"""Save the response content to a file in chunks."""
|
|
with open(destination, "wb") as f:
|
|
pbar = tqdm(total=None)
|
|
progress = 0
|
|
for chunk in response.iter_content(chunk_size):
|
|
if chunk: # filter out keep-alive new chunks
|
|
f.write(chunk)
|
|
progress += len(chunk)
|
|
pbar.update(progress - pbar.n)
|
|
pbar.close()
|
|
|
|
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive, choosing the format from the file suffix.

    Supported suffixes: ``.tar``, ``.tar.gz``/``.tgz``, ``.tar.xz``,
    ``.tar.bz2``/``.tbz2``/``.tbz``, bare ``.gz`` and ``.zip``.

    Args:
        from_path: Path to the archive.
        to_path: Destination directory; defaults to the archive's directory.
            For a bare ``.gz`` file the decompressed output is written into this
            directory under the archive's name minus ``.gz``.
        remove_finished: Delete ``from_path`` after a successful extraction.

    Raises:
        ValueError: If the suffix is not supported.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # tarfile needs "r:<codec>" modes; "r" alone transparently detects compression.
    if from_path.endswith((".tar.gz", ".tgz")):
        tar_mode = "r:gz"
    elif from_path.endswith(".tar.xz"):
        tar_mode = "r:xz"
    elif from_path.endswith((".tar.bz2", ".tbz2", ".tbz")):
        tar_mode = "r:bz2"
    elif from_path.endswith(".tar"):
        tar_mode = "r"
    else:
        tar_mode = None

    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            # Archives are typically downloaded; when the interpreter supports
            # extraction filters, reject members that would escape to_path.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=to_path, filter="data")
            else:
                tar.extractall(path=to_path)
    elif from_path.endswith(".gz"):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        # Open the input first so a missing archive does not leave an empty output file;
        # copy in chunks instead of reading the whole payload into memory.
        with gzip.GzipFile(from_path) as zip_f, open(to_path, "wb") as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)


def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False):
    """Download an archive from ``url`` and extract it.

    Args:
        url: Address of the archive.
        download_root: Directory the archive is saved into (``~`` expanded).
        extract_root: Directory to extract into (``~`` expanded); defaults to
            ``download_root``.
        filename: Name to save the archive under; defaults to the last path
            component of ``url``.
        md5: Optional expected MD5 hex digest of the archive.
        remove_finished: Delete the archive after extraction.
    """
    download_root = os.path.expanduser(download_root)
    # Expand "~" in an explicit extract_root too, matching download_root;
    # otherwise a literal "~" directory would be created.
    extract_root = os.path.expanduser(extract_root) if extract_root else download_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable):
    """Render the items of ``iterable`` as single-quoted, comma-separated text.

    For example ``['a', 1]`` becomes ``"'a', '1'"``; an empty iterable yields ``"''"``.
    """
    separator = "', '"
    body = separator.join(str(item) for item in iterable)
    return f"'{body}'"


def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Check that ``value`` is a string and, optionally, one of ``valid_values``.

    Args:
        value: The value to check.
        arg: Argument name used in error messages.
        valid_values: Optional container of accepted strings; ``None`` skips the check.
        custom_msg: Message to use instead of the default when ``value`` is not
            in ``valid_values`` (ignored if falsy).

    Returns:
        ``value`` unchanged when every check passes.

    Raises:
        ValueError: If ``value`` is not a ``str`` or is not in ``valid_values``.
    """
    if isinstance(value, str):
        if valid_values is not None and value not in valid_values:
            error_text = custom_msg or f"Unknown value '{value}' for argument {arg}. Valid values are {{{iterable_to_str(valid_values)}}}."
            raise ValueError(error_text)
        return value

    error_text = f"Expected type str" + (f" for argument {arg}" if arg else "") + f", but got type {type(value)}."
    raise ValueError(error_text)