# Source: use-case-and-architecture/EdgeFLite/data_collection/helper_utils.py
# Author: Weisen Pan <4ec0a23e73> — Edge Federated Learning for Improved Training Efficiency
# Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
# Date: 2024-09-18 18:39:43 -07:00
# (174 lines, 6.6 KiB, Python)
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import gzip
import hashlib
import os
import tarfile
import urllib.error
import urllib.request
import zipfile

from torch.utils.model_zoo import tqdm
def generate_update_progress_barr():
    """Create a tqdm-based reporthook for ``urllib.request.urlretrieve``.

    Returns:
        A callable ``(count, block_size, total_size)`` that advances a
        shared progress bar as download chunks arrive.
    """
    bar = tqdm(total=None)

    def _report(count, block_size, total_size):
        # The total is only known once the server reports Content-Length.
        if bar.total is None and total_size:
            bar.total = total_size
        downloaded = count * block_size
        # tqdm.update expects a delta, so subtract what was already counted.
        bar.update(downloaded - bar.n)

    return _report
def compute_md5_checksum(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at *fpath*.

    The file is streamed in *chunk_size*-byte pieces so arbitrarily large
    files can be hashed without loading them fully into memory.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def verify_md5_checksum(fpath, md5):
    """Return True when the MD5 digest of *fpath* equals *md5*."""
    return compute_md5_checksum(fpath) == md5
def validate_integrity(fpath, md5=None):
    """Return True if *fpath* exists and, when *md5* is given, matches it."""
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        # No checksum supplied: mere existence is enough.
        return True
    return verify_md5_checksum(fpath, md5)
def download_url(url, root, filename=None, md5=None):
    """Download a file from *url* into directory *root*.

    Args:
        url (str): URL to fetch.
        root (str): Directory to place the file in (created if needed).
        filename (str, optional): Name to save under; defaults to the URL's
            basename.
        md5 (str, optional): Expected MD5 checksum. A pre-existing file that
            matches is reused and the download is skipped.

    Raises:
        RuntimeError: If the downloaded file is missing or fails the
            checksum test.
        urllib.error.URLError, IOError: If the download fails and the
            http fallback does not apply.
    """
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    # Skip the network round trip when a verified copy is already present.
    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr())
    except (urllib.error.URLError, IOError):
        if not url.startswith('https'):
            # Bare `raise` preserves the original traceback (the old code
            # re-raised via `raise e`, discarding context under Python 2
            # semantics; under Python 3 this form is the idiom).
            raise
        # Some mirrors only serve plain http; retry once after downgrading.
        url = url.replace('https:', 'http:')
        print('Failed download. Retrying with http.')
        urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr())

    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
def list_dir(root, prefix=False):
    """Return the names of all directories directly under *root*.

    When *prefix* is True each entry is joined with *root* to form a path.
    """
    root = os.path.expanduser(root)
    names = [entry for entry in os.listdir(root)
             if os.path.isdir(os.path.join(root, entry))]
    if prefix:
        return [os.path.join(root, entry) for entry in names]
    return names
def list_files(root, suffix, prefix=False):
    """Return the names of files under *root* that end with *suffix*.

    *suffix* may be a string or a tuple of strings (``str.endswith``
    semantics). When *prefix* is True each entry is joined with *root*.
    """
    root = os.path.expanduser(root)
    matches = []
    for entry in os.listdir(root):
        full = os.path.join(root, entry)
        if os.path.isfile(full) and entry.endswith(suffix):
            matches.append(full if prefix else entry)
    return matches
def fetch_file_google_drive(file_id, root, filename=None, md5=None):
    """Download a file from Google Drive into directory *root*.

    Args:
        file_id (str): Google Drive identifier of the file.
        root (str): Target directory (created if needed).
        filename (str, optional): Name to save under; defaults to *file_id*.
        md5 (str, optional): Expected MD5; a matching existing file is reused.
    """
    # BUG FIX: the module used ``requests`` without importing it anywhere,
    # so the first call raised NameError. Import it explicitly here
    # (function-local: requests is a third-party, network-only dependency).
    import requests

    url = "https://docs.google.com/uc?export=download"
    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)
    # Large files trigger a "can't scan for viruses" interstitial page;
    # the confirmation token from the cookie lets us proceed anyway.
    token = _get_confirm_token(response)
    if token:
        params = {'id': file_id, 'confirm': token}
        response = session.get(url, params=params, stream=True)
    _store_response_content(response, fpath)
def _get_confirm_token(response):
"""Extract the download token from Google Drive cookies."""
return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None)
def _store_response_content(response, destination, chunk_size=32768):
"""Save the response content to a file in chunks."""
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract a ``.tar``/``.tar.gz``/``.tgz``/``.tar.xz``/``.gz``/``.zip`` archive.

    Args:
        from_path (str): Path of the archive to extract.
        to_path (str, optional): Destination directory; defaults to the
            archive's own directory. For bare ``.gz`` files the output file
            is written inside this directory with the ``.gz`` suffix removed.
        remove_finished (bool): When True, delete the archive afterwards.

    Raises:
        ValueError: If the file extension is not a supported archive type.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar", ".tar.gz", ".tgz", ".tar.xz")):
        # BUG FIX: tarfile modes join the compression suffix with ':'
        # (e.g. 'r:gz'); the previous "'r' + '.gz'" construction produced
        # 'r.gz', which tarfile.open rejects with ValueError.
        if from_path.endswith((".tar.gz", ".tgz")):
            mode = "r:gz"
        elif from_path.endswith(".tar.xz"):
            mode = "r:xz"
        else:
            mode = "r"
        with tarfile.open(from_path, mode) as tar:
            tar.extractall(path=to_path)
    elif from_path.endswith(".gz"):
        # A lone .gz wraps a single file: decompress it into to_path.
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)
def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False):
    """Download an archive from *url*, then unpack it.

    Args:
        url (str): URL of the archive.
        download_root (str): Directory to download into.
        extract_root (str, optional): Directory to extract into; defaults
            to *download_root*.
        filename (str, optional): Name to save the archive under; defaults
            to the URL's basename.
        md5 (str, optional): Expected MD5 checksum of the archive.
        remove_finished (bool): When True, delete the archive after
            extraction.
    """
    download_root = os.path.expanduser(download_root)
    if not extract_root:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable):
    """Render *iterable* as comma-separated, single-quoted items.

    An empty iterable renders as ``''`` (one empty quoted item), matching
    the historical behavior of this helper.
    """
    joined = "', '".join(str(item) for item in iterable)
    return "'{}'".format(joined)
def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Validate that *value* is a string and, optionally, an allowed one.

    Args:
        value: The value to check.
        arg (str, optional): Argument name used in error messages.
        valid_values (iterable, optional): Allowed values; when given,
            *value* must be one of them.
        custom_msg (str, optional): Overrides the default invalid-value
            message.

    Returns:
        The validated value, unchanged.

    Raises:
        ValueError: If *value* is not a str or not among *valid_values*.
    """
    if not isinstance(value, str):
        suffix = f" for argument {arg}" if arg else ""
        raise ValueError(f"Expected type str{suffix}, but got type {type(value)}.")
    if valid_values is not None and value not in valid_values:
        if custom_msg:
            raise ValueError(custom_msg)
        raise ValueError(
            f"Unknown value '{value}' for argument {arg}. "
            f"Valid values are {{{iterable_to_str(valid_values)}}}."
        )
    return value