From 4ec0a23e73f05053c54eee81ea85019f12f7012b Mon Sep 17 00:00:00 2001 From: Weisen Pan Date: Wed, 18 Sep 2024 18:38:28 -0700 Subject: [PATCH] Edge Federated Learning for Improved Training Efficiency Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c --- EdgeFLite/.DS_Store | Bin 0 -> 10244 bytes EdgeFLite/README.md | 41 +++ EdgeFLite/architecture/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/architecture/coremodel.py | 245 +++++++++++++ EdgeFLite/architecture/mixup.py | 60 ++++ EdgeFLite/architecture/resnet.py | 237 +++++++++++++ EdgeFLite/architecture/resnet_sl.py | 312 +++++++++++++++++ EdgeFLite/architecture/splitnet.py | 212 ++++++++++++ EdgeFLite/configurations/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/configurations/training_config.py | 81 +++++ EdgeFLite/data_collection/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/data_collection/augment_auto.py | 103 ++++++ EdgeFLite/data_collection/augment_rand.py | 199 +++++++++++ EdgeFLite/data_collection/cifar100_noniid.py | 220 ++++++++++++ EdgeFLite/data_collection/cifar10_noniid.py | 179 ++++++++++ EdgeFLite/data_collection/data_cutout.py | 71 ++++ EdgeFLite/data_collection/dataset_cifar.py | 178 ++++++++++ EdgeFLite/data_collection/dataset_factory.py | 152 ++++++++ EdgeFLite/data_collection/dataset_imagenet.py | 194 +++++++++++ EdgeFLite/data_collection/directory_utils.py | 229 ++++++++++++ EdgeFLite/data_collection/helper_utils.py | 173 ++++++++++ EdgeFLite/data_collection/pill_data_base.py | 84 +++++ EdgeFLite/data_collection/pill_data_large.py | 83 +++++ EdgeFLite/data_collection/skin_dataset.py | 46 +++ EdgeFLite/data_collection/vision_utils.py | 94 +++++ EdgeFLite/debug_tool.py | 47 +++ EdgeFLite/fedml_service/.DS_Store | Bin 0 -> 6148 bytes .../fedml_service/architecture/.DS_Store | Bin 0 -> 6148 bytes .../fedml_service/architecture/cv/.DS_Store | Bin 0 -> 6148 bytes .../cv/models_pretrained/.DS_Store | Bin 0 -> 6148 bytes .../cv/models_pretrained/CIFAR10/.DS_Store | Bin 0 -> 6148 bytes .../CIFAR10/resnet56/test_metrics | Bin 0 -> 15273 bytes .../CIFAR10/resnet56/train_metrics | Bin 0 -> 15058 bytes .../cv/models_pretrained/CIFAR100/.DS_Store | Bin 0 -> 6148 bytes .../CIFAR100/resnet56/test_metrics | Bin 0 -> 15269 bytes .../CIFAR100/resnet56/train_metrics | Bin 0 -> 15058 bytes .../cv/models_pretrained/CINIC10/.DS_Store | Bin 0 -> 6148 bytes .../CINIC10/resnet56/test_metrics | Bin 0 -> 15256 bytes .../CINIC10/resnet56/train_metrics | Bin 0 -> 15058 bytes .../cv/resnet56_federated/.DS_Store | Bin 0 -> 6148 bytes .../cv/resnet56_federated/net_server.py | 196 +++++++++++ .../resnet56_federated/pretrained_weights.py | 326 ++++++++++++++++++ .../cv/resnet56_federated/resnet_client.py | 231 +++++++++++++ .../cv/resnet_federated/.DS_Store | Bin 0 -> 6148 bytes .../architecture/cv/resnet_federated/net.py | 211 ++++++++++++ .../fedml_service/data_cleaning/.DS_Store | Bin 0 -> 6148 bytes .../data_cleaning/cifar10/.DS_Store | Bin 0 -> 6148 bytes .../data_cleaning/cifar10/bulk_data_import.py | 230 ++++++++++++ .../data_cleaning/cifar10/dataset_hub.py | 109 ++++++ .../data_cleaning/cifar100/.DS_Store | Bin 0 -> 6148 bytes .../cifar100/bulk_data_import.py | 182 ++++++++++ .../data_cleaning/cifar100/dataset_hub.py | 157 +++++++++ .../data_cleaning/pillbase/.DS_Store | Bin 0 -> 6148 bytes .../pillbase/bulk_data_import.py | 82 +++++ .../data_cleaning/skin_dataset/.DS_Store | Bin 0 -> 6148 bytes .../skin_dataset/bulk_data_import.py | 97 ++++++ .../fedml_service/decentralized/.DS_Store | Bin 0 -> 6148 bytes .../decentralized/federated_gkt/.DS_Store | Bin 0 
-> 6148 bytes .../federated_gkt/client_coach.py | 120 +++++++ .../federated_gkt/helper_utils.py | 108 ++++++ .../federated_gkt/server_coach.py | 274 +++++++++++++++ EdgeFLite/helpers/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/helpers/evaluation_metrics.py | 190 ++++++++++ EdgeFLite/helpers/normalization.py | 131 +++++++ EdgeFLite/helpers/optimizer_rmsprop.py | 129 +++++++ EdgeFLite/helpers/pace_controller.py | 146 ++++++++ EdgeFLite/helpers/preloader_module.py | 39 +++ EdgeFLite/helpers/report_summary.py | 186 ++++++++++ EdgeFLite/helpers/smoothing_labels.py | 48 +++ EdgeFLite/info_map.csv | 68 ++++ EdgeFLite/process_data.py | 47 +++ EdgeFLite/resnet_federated.py | 161 +++++++++ EdgeFLite/run_federated.py | 158 +++++++++ EdgeFLite/run_local.py | 279 +++++++++++++++ EdgeFLite/run_prox.py | 223 ++++++++++++ EdgeFLite/run_splitfed.py | 210 +++++++++++ EdgeFLite/scripts/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh | 48 +++ EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh | 60 ++++ EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh | 49 +++ EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh | 44 +++ .../scripts/EdgeFLite_W168_96c_650r32.sh | 51 +++ EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh | 43 +++ EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh | 44 +++ EdgeFLite/scripts/FGKT_R110_20c_650r.sh | 43 +++ EdgeFLite/scripts/FGKT_R110_20c_skew.sh | 56 +++ EdgeFLite/scripts/FGKT_W168_20c_300r.sh | 61 ++++ EdgeFLite/scripts/FGKT_W168_20c_skew.sh | 44 +++ EdgeFLite/scripts/FGKT_W502_20c_350r.sh | 60 ++++ EdgeFLite/settings.py | 7 + EdgeFLite/thop/.DS_Store | Bin 0 -> 6148 bytes EdgeFLite/thop/helper_utils.py | 38 ++ EdgeFLite/thop/hooks_basic.py | 91 +++++ EdgeFLite/thop/hooks_rnn.py | 195 +++++++++++ EdgeFLite/thop/profiling.py | 168 +++++++++ EdgeFLite/train_EdgeFLite.py | 205 +++++++++++ 96 files changed, 8885 insertions(+) create mode 100644 EdgeFLite/.DS_Store create mode 100644 EdgeFLite/README.md create mode 100644 EdgeFLite/architecture/.DS_Store create mode 100644 EdgeFLite/architecture/coremodel.py create mode 100644 EdgeFLite/architecture/mixup.py create mode 100644 EdgeFLite/architecture/resnet.py create mode 100644 EdgeFLite/architecture/resnet_sl.py create mode 100644 EdgeFLite/architecture/splitnet.py create mode 100644 EdgeFLite/configurations/.DS_Store create mode 100644 EdgeFLite/configurations/training_config.py create mode 100644 EdgeFLite/data_collection/.DS_Store create mode 100644 EdgeFLite/data_collection/augment_auto.py create mode 100644 EdgeFLite/data_collection/augment_rand.py create mode 100644 EdgeFLite/data_collection/cifar100_noniid.py create mode 100644 EdgeFLite/data_collection/cifar10_noniid.py create mode 100644 EdgeFLite/data_collection/data_cutout.py create mode 100644 EdgeFLite/data_collection/dataset_cifar.py create mode 100644 EdgeFLite/data_collection/dataset_factory.py create mode 100644 EdgeFLite/data_collection/dataset_imagenet.py create mode 100644 EdgeFLite/data_collection/directory_utils.py create mode 100644 EdgeFLite/data_collection/helper_utils.py create mode 100644 EdgeFLite/data_collection/pill_data_base.py create mode 100644 EdgeFLite/data_collection/pill_data_large.py create mode 100644 EdgeFLite/data_collection/skin_dataset.py create mode 100644 EdgeFLite/data_collection/vision_utils.py create mode 100644 EdgeFLite/debug_tool.py create mode 100644 EdgeFLite/fedml_service/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/.DS_Store create mode 
100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/train_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet56_federated/net_server.py create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store create mode 100644 EdgeFLite/fedml_service/architecture/cv/resnet_federated/net.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/.DS_Store create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar10/bulk_data_import.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar10/dataset_hub.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar100/.DS_Store create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar100/bulk_data_import.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/cifar100/dataset_hub.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/pillbase/.DS_Store create mode 100644 EdgeFLite/fedml_service/data_cleaning/pillbase/bulk_data_import.py create mode 100644 EdgeFLite/fedml_service/data_cleaning/skin_dataset/.DS_Store create mode 100644 EdgeFLite/fedml_service/data_cleaning/skin_dataset/bulk_data_import.py create mode 100644 EdgeFLite/fedml_service/decentralized/.DS_Store create mode 100644 EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store create mode 100644 EdgeFLite/fedml_service/decentralized/federated_gkt/client_coach.py create mode 100644 EdgeFLite/fedml_service/decentralized/federated_gkt/helper_utils.py create mode 100644 EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py create mode 100644 EdgeFLite/helpers/.DS_Store create mode 100644 EdgeFLite/helpers/evaluation_metrics.py create mode 100644 EdgeFLite/helpers/normalization.py create mode 100644 EdgeFLite/helpers/optimizer_rmsprop.py create mode 100644 EdgeFLite/helpers/pace_controller.py create mode 100644 EdgeFLite/helpers/preloader_module.py create mode 100644 EdgeFLite/helpers/report_summary.py create mode 100644 EdgeFLite/helpers/smoothing_labels.py create mode 100644 EdgeFLite/info_map.csv create mode 100644 EdgeFLite/process_data.py create mode 100644 EdgeFLite/resnet_federated.py create mode 100644 EdgeFLite/run_federated.py create mode 100644 EdgeFLite/run_local.py create mode 100644 EdgeFLite/run_prox.py 
create mode 100644 EdgeFLite/run_splitfed.py create mode 100644 EdgeFLite/scripts/.DS_Store create mode 100644 EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_W168_96c_650r32.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh create mode 100644 EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh create mode 100644 EdgeFLite/scripts/FGKT_R110_20c_650r.sh create mode 100644 EdgeFLite/scripts/FGKT_R110_20c_skew.sh create mode 100644 EdgeFLite/scripts/FGKT_W168_20c_300r.sh create mode 100644 EdgeFLite/scripts/FGKT_W168_20c_skew.sh create mode 100644 EdgeFLite/scripts/FGKT_W502_20c_350r.sh create mode 100644 EdgeFLite/settings.py create mode 100644 EdgeFLite/thop/.DS_Store create mode 100644 EdgeFLite/thop/helper_utils.py create mode 100644 EdgeFLite/thop/hooks_basic.py create mode 100644 EdgeFLite/thop/hooks_rnn.py create mode 100644 EdgeFLite/thop/profiling.py create mode 100644 EdgeFLite/train_EdgeFLite.py diff --git a/EdgeFLite/.DS_Store b/EdgeFLite/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fd8c7022aa7e084233ab858fb2437a7f4f385e18 GIT binary patch literal 10244 zcmeHMU2GIp6uxJ=(AlA-)AHM$02>-9mR7oGTYeO_e^3IX$hNc;aGl+mwi9M&)|uH| zupu@QeL(z)FFvUF^J+BlK?6jM`XIcDjgbdU)WjD}L?4Zbo;!E8>`v(tjUY5L$=oyd z-gEZe^UZhexp$Tj0zEk;Mo5$p!sAV)ScAn)F4|A;ITz=(NGJjA3DJl{RFcLrH%B{Y z2oMMm2oMMm2oMMmxEm0_oJ}u&36nt?AP^uBATWynT_2)&QyKTv-PsTkt&7^CJzsXJy;F*GV!~kZF{ZXopGVaM~CYd<_W=_Bx8N3q;xTB+8 z)Kw=CXEG=Q1Of!+B0$gXHAE&Bp4unN&hHjcQw2Ml(#;g=Tu5bg&34+`zXq--yH1$;RSxz0+)Iw5R-fJ3;tQo518fc1&D#u1NLoo|I zX3kJ-*EYdA_%I(%imRul+B(`>qn!_)ZjDaGI$~X|(XAcZPM;3*E1EX%IFvl0XAJW! za|ygF6xP_OcB)EmD558KA92FtVAZ$ruzV_Hjcl=yGg(}DOX0Y9$CFg!)YO(3 zCaG9gN2M%9vk+eMW~p=7S&hL^;#N5xWnH?R?_~>RRb$cPZ;EL-EG&_R4-2D7u_2GS zcQET{rePmeEgNN0tg|sk?@gP=K})6mBWY95Wehth)~97XozrDU?bh{#_M~c$I_gP> zK8@L?<@CCWNm0z%>LC^Nw7lAf+V-f1KS$M;yH4IHApBm(6N|=q)*(FFLd4{dH&uiS#AV{N__%i6C z?X~TSOVVR>*!{BY3}#gWr*y>|yAS)u>`B>D2Mf4E@>&P5)qP#qPL28%%`x3hL}NdC zafJR{x1zW{tX{dQR(e!Q@Zwb-Uy)|yHQbfq8)us{zA4_q-b)!O;ih=I_quq}NFKnm z)&Mz%t4x-hA?L{Jw(xF1?z z6FdZQ=!PEH2M6I0Bwz@RLkbj7VH`}zK><#}DR>&r!Z~;zUVvBNO?V65hIil+d;}lE zCvX|A!dEZ@-@>nO9e#s9xDdC9Tg=sQBG3Ts!v&x0B7AEAIP_+@KwO;`U(S zizo~mZ#T<1Hw&9c|88O6t_YiJH7rn;Enl&=X=B@#&K)gGvmsswC_ThehI6BHfW*9s@(<|2<>hCzr^laZhiSe5aLiJjy)& IPtN~;1NozGbN~PV literal 0 HcmV?d00001 diff --git a/EdgeFLite/README.md b/EdgeFLite/README.md new file mode 100644 index 0000000..7e9e246 --- /dev/null +++ b/EdgeFLite/README.md @@ -0,0 +1,41 @@ +# EdgeFLite:Edge Federated Learning for Improved Training Efficiency + + +- EdgeFLite is a cutting-edge framework developed to tackle the memory limitations of federated learning (FL) on edge devices with restricted resources. By partitioning large convolutional neural networks (CNNs) into smaller sub-models and distributing the training across local clients, EdgeFLite ensures efficient learning while maintaining data privacy. Clients in clusters collaborate by sharing learned representations, which are then aggregated by a central server to refine the global model. Experimental results on medical imaging and natural datasets demonstrate that EdgeFLite consistently outperforms other FL frameworks, setting new benchmarks for performance. 
+ +- Within 6G-enabled mobile edge computing (MEC) networks, EdgeFLite addresses the challenges posed by client diversity and resource constraints. It optimizes local models and resource allocation to improve overall efficiency. Through a detailed convergence analysis, this research establishes a clear relationship between training loss and resource usage. The innovative Intelligent Frequency Band Allocation (IFBA) algorithm minimizes latency and enhances training efficiency by 5-10%, making EdgeFLite a robust solution for improving federated learning across a wide range of edge environments. + +## Preparation +### Dataset Setup +- The CIFAR-10 and CIFAR-100 datasets, both derived from the Tiny Images dataset, will be automatically downloaded. CIFAR-10 includes 60,000 32x32 color images across 10 categories: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 6,000 images per category, split into 5,000 for training and 1,000 for testing. + +- CIFAR-100 is a more complex dataset, featuring 100 categories with fewer images per class compared to CIFAR-10. These datasets serve as standard benchmarks for image classification tasks and provide a robust evaluation environment for machine learning models. + +### Dependency Installation + +```bash +Pytorch 1.10.2 +OpenCV 4.5.5 +``` + +## Running Experiments +*Top-1 accuracy (%) of FedDCT compared to state-of-the-art FL methods on the CIFAR-10 and CIFAR-100 test datasets.* + +1. **Specify Experiment Name:** + Add `--spid` to specify the experiment name in each training script, like this: + ```bash + python run_gkt.py --is_fed=1 --fixed_cluster=0 --split_factor=1 --num_clusters=20 --num_selected=20 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1 + ``` + +2. 
**Training Scripts for CIFAR-10:** + + - **Centralized Training:** + ```bash + python run_local.py --is_fed=0 --split_factor=1 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --epochs=300 + ``` + + - **FedDCT:** + ```bash + python train_EdgeFLite.py --is_fed=1 --fixed_cluster=0 --split_factor=4 --num_clusters=5 --num_selected=5 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1 + ``` +--- \ No newline at end of file diff --git a/EdgeFLite/architecture/.DS_Store b/EdgeFLite/architecture/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..3dfe47013416b6b4a4f800ed361f066efc266db6 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8O(;SS3LZQxc&*q!X~9dV^#zRRK_wpTZ@zH8AR#e83=jib!GO8ktjbpGhBiSA5Cgx-0NxJ*6wxu5 zX;fDS45|eHEP+`G%(0ii9AVHgm}vwL2-m5AI+dFigX?s#3lrxU%rxqB#!d6V&6~OD zP`G+KtS?kJm%W2L6 1 else False + self.ensembled_loss_weight = args.ensembled_loss_weight + self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False + self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False + self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False + self.cot_weight = args.cot_weight + self.is_cot_weight_warm_up = args.is_cot_weight_warm_up + self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs + self.cot_loss_choose = args.cot_loss_choose + + # Model arguments for the proxy client + model_kwargs = { + 'num_classes': self.num_classes, + 'norm_layer': norm_layer, + 'dataset': args.dataset, + 'split_factor': self.split_factor, + 'output_stride': args.output_stride + } + + # Initialize multiple instances of the network architecture for the proxy client + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + self.proxy_clients_models = nn.ModuleList( + [_retrieve_networkwork(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[1] + for _ in range(self.loop_factor)] + ) + else: + raise NotImplementedError(f"Architecture '{self.arch}' not implemented.") + + # Identical initialization of the model if specified + if args.is_identical_init: + print("INFO:PyTorch: Using identical initialization.") + self._identical_init() + + def forward(self, main_client_outputs, y_a=None, y_b=None, lam=None, target=None, mode='train', epoch=0, streams=None): + """Forward pass for the proxy client. 
Manages multiple sub-networks and ensemble outputs.""" + outputs = [] + ce_losses = [] + + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + if mode == 'train': + # Calculate loss and forward pass during training + for i in range(self.loop_factor): + output = self.proxy_clients_models[i](main_client_outputs[i]) + loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target) + outputs.append(output) + ce_losses.append(loss) + + ensemble_output = self._collect_ensemble_output(outputs) + ce_loss = torch.sum(torch.stack(ce_losses, dim=0)) + + # Calculate co-training loss if enabled + if self.is_cot_loss: + cot_loss = self._calculate_co_training_loss(outputs, epoch) + else: + cot_loss = torch.zeros_like(ce_loss) + + return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss + + elif mode in ['val', 'test']: + # Forward pass during evaluation or testing + for i in range(self.loop_factor): + output = self.proxy_clients_models[i](main_client_outputs[i]) + loss = self.criterion(output, target) if self.criterion else torch.zeros(1) + outputs.append(output) + ce_losses.append(loss) + + ensemble_output = self._collect_ensemble_output(outputs) + ce_loss = torch.sum(torch.stack(ce_losses, dim=0)) + return ensemble_output, torch.stack(outputs, dim=0), ce_loss + else: + # Return a dummy tensor if the mode is unsupported + return torch.ones(1) + else: + raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.") + + def _collect_ensemble_output(self, outputs): + """Calculate the ensemble output from multiple sub-networks.""" + stacked_outputs = torch.stack(outputs, dim=0) + + # Apply softmax to the outputs before ensembling if specified + if self.is_ensembled_after_softmax: + if self.is_max_ensemble: + ensemble_output, _ = torch.max(F.softmax(stacked_outputs, dim=-1), dim=0) + else: + ensemble_output = torch.mean(F.softmax(stacked_outputs, dim=-1), dim=0) + else: + if self.is_max_ensemble: + ensemble_output, _ = torch.max(stacked_outputs, dim=0) + else: + ensemble_output = torch.mean(stacked_outputs, dim=0) + + return ensemble_output + + def _calculate_co_training_loss(self, outputs, epoch): + """Calculate the co-training loss between outputs of different sub-networks.""" + # Adjust the weight of the co-training loss during warm-up epochs + weight_now = self.cot_weight if not self.is_cot_weight_warm_up or epoch >= self.cot_weight_warm_up_epochs else max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005) + + # Different methods of calculating co-training loss + if self.cot_loss_choose == 'js_divergence': + outputs_all = torch.stack(outputs, dim=0) + p_all = F.softmax(outputs_all, dim=-1) + p_mean = torch.mean(p_all, dim=0) + H_mean = (-p_mean * torch.log(p_mean)).sum(-1).mean() + H_sep = (-p_all * F.log_softmax(outputs_all, dim=-1)).sum(-1).mean() + return weight_now * (H_mean - H_sep) + elif self.cot_loss_choose == 'kl_separate': + outputs_all = torch.stack(outputs, dim=0) + outputs_r1 = torch.repeat_interleave(outputs_all, self.split_factor - 1, dim=0) + index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i] + outputs_r2 = torch.index_select(outputs_all, dim=0, index=torch.tensor(index_list, dtype=torch.long).cuda()) + kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2, dim=-1).detach(), reduction='none') + return weight_now * kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1) + else: + 
raise NotImplementedError(f"Co-training loss '{self.cot_loss_choose}' not implemented.") + + def _identical_init(self): + """Ensure identical initialization of weights for sub-networks.""" + with torch.no_grad(): + # Copy weights from the first model to all subsequent models + for i in range(1, self.split_factor): + for (name1, param1), (name2, param2) in zip(self.proxy_clients_models[i].named_parameters(), + self.proxy_clients_models[0].named_parameters()): + if 'weight' in name1: + param1.data.copy_(param2.data) diff --git a/EdgeFLite/architecture/mixup.py b/EdgeFLite/architecture/mixup.py new file mode 100644 index 0000000..6ce96f3 --- /dev/null +++ b/EdgeFLite/architecture/mixup.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import numpy as np + +@torch.no_grad() +def combine_mixup_data(x, y, alpha=1.0, use_cuda=True): + """ + Perform the mixup operation on input data. + + Args: + x (Tensor): Input features, typically from the dataset. + y (Tensor): Input labels corresponding to the features. + alpha (float): Mixup interpolation coefficient. The default value is 1.0. + A higher value results in more mixing between samples. + use_cuda (bool): Boolean flag to indicate whether CUDA should be used if available. + + Returns: + mixed_x (Tensor): Mixed input features, a linear combination of x and a permuted version of x. + y_a (Tensor): Original input labels corresponding to x. + y_b (Tensor): Permuted input labels corresponding to the mixed samples. + lam (float): The lambda value used for interpolation between samples. + """ + # Draw lambda value from the Beta distribution if alpha > 0, otherwise set lam to 1 (no mixup) + lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 + + # Get the batch size from the input tensor + batch_size = x.size(0) + + # Generate a random permutation of indices for mixing + # Use CUDA if available, otherwise stick with CPU + index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size) + + # Mix the features of the original and permuted samples using the lambda value + mixed_x = lam * x + (1 - lam) * x[index, :] + + # Assign original and permuted labels to y_a and y_b, respectively + y_a, y_b = y, y[index] + + # Return mixed features, original and permuted labels, and the lambda value + return mixed_x, y_a, y_b, lam + + +def mixup_loss_criterion(criterion, pred, y_a, y_b, lam): + """ + Compute the mixup loss using the provided criterion. + + Args: + criterion (function): The loss function used to compute the error (e.g., CrossEntropyLoss). + pred (Tensor): The model predictions, typically the output of a neural network. + y_a (Tensor): The original labels corresponding to the original input features. + y_b (Tensor): The permuted labels corresponding to the mixed input features. + lam (float): The lambda value for mixup, used to interpolate between the two losses. + + Returns: + loss (Tensor): The final mixup loss, computed as a weighted sum of the two losses. 
+ """ + # Compute the mixup loss by combining the loss from the original and permuted labels + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) diff --git a/EdgeFLite/architecture/resnet.py b/EdgeFLite/architecture/resnet.py new file mode 100644 index 0000000..99ab705 --- /dev/null +++ b/EdgeFLite/architecture/resnet.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +# Try to import the method to load model weights from a URL, with a fallback in case of ImportError +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# List of available ResNet architectures +__all__ = ['resnet_model_18', 'resnet_model_34', 'resnet_model_50', + 'resnet_model_101', 'resnet_model_152', 'resnet_model_200', + 'resnet110', 'resnet164', + 'resnext29_8x64d', 'resnext29_16x64d', + 'resnext50_32x4d', 'resnext101_32x4d', + 'resnext101_32x8d', 'resnext101_64x4d', + 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2', + 'wide_resnet16_8', 'wide_resnet52_8', 'wide_resnet16_12', + 'wide_resnet28_10', 'wide_resnet40_10'] + +# Pre-trained model URLs for various ResNet variants +model_urls = { + 'resnet_model_18': 'https://download.pytorch.org/models/resnet_model_18-5c106cde.pth', + 'resnet_model_34': 'https://download.pytorch.org/models/resnet_model_34-333f7ec4.pth', + 'resnet_model_50': 'https://download.pytorch.org/models/resnet_model_50-19c8e357.pth', + 'resnet_model_101': 'https://download.pytorch.org/models/resnet_model_101-5d3b4d8f.pth', + 'resnet_model_152': 'https://download.pytorch.org/models/resnet_model_152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet_model_50_2': 'https://download.pytorch.org/models/wide_resnet_model_50_2-95faca4d.pth', + 'wide_resnet_model_101_2': 'https://download.pytorch.org/models/wide_resnet_model_101_2-32ee1156.pth', +} + +# Function for a 3x3 convolution with padding +def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function for a 1x1 convolution +def apply_1x1_convolution(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +# BasicBlock class for the ResNet architecture +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor for the output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + # If norm_layer is not provided, use BatchNorm2d as the default + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Ensure BasicBlock is restricted to specific parameters + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock is restricted to groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("BasicBlock does not support dilation greater than 1") + + # Define the layers for the BasicBlock + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) # First 3x3 convolution + self.bn1 = norm_layer(planes) # First BatchNorm layer + self.relu = nn.ReLU(inplace=True) # 
ReLU activation + self.conv2 = apply_3x3_convolution(planes, planes) # Second 3x3 convolution + self.bn2 = norm_layer(planes) # Second BatchNorm layer + self.downsample = downsample # Optional downsample layer + self.stride = stride + + # Define the forward pass for BasicBlock + def forward(self, x): + identity = x # Save the input for the skip connection + + out = self.conv1(x) # First convolution + out = self.bn1(out) # BatchNorm after first convolution + out = self.relu(out) # ReLU activation + + out = self.conv2(out) # Second convolution + out = self.bn2(out) # BatchNorm after second convolution + + # Apply downsample if defined + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add the skip connection + out = self.relu(out) # Apply ReLU activation again + + return out + +# Bottleneck class for the ResNet architecture, a more complex block used in deeper ResNet models +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor for the output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups # Calculate width based on base width and groups + + # Define the layers for the Bottleneck block + self.conv1 = apply_1x1_convolution(inplanes, width) # 1x1 convolution to reduce the dimensions + self.bn1 = norm_layer(width) # BatchNorm after 1x1 convolution + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) # 3x3 convolution + self.bn2 = norm_layer(width) # BatchNorm after 3x3 convolution + self.conv3 = apply_1x1_convolution(width, planes * self.expansion) # 1x1 convolution to expand the dimensions + self.bn3 = norm_layer(planes * self.expansion) # BatchNorm after final 1x1 convolution + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Optional downsample layer + self.stride = stride + + # Define the forward pass for Bottleneck + def forward(self, x): + identity = x # Save the input for the skip connection + + out = self.conv1(x) # First convolution + out = self.bn1(out) # BatchNorm after first convolution + out = self.relu(out) # ReLU activation + + out = self.conv2(out) # Second convolution + out = self.bn2(out) # BatchNorm after second convolution + out = self.relu(out) # ReLU activation + + out = self.conv3(out) # Third convolution + out = self.bn3(out) # BatchNorm after third convolution + + # Apply downsample if defined + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add the skip connection + out = self.relu(out) # Apply ReLU activation again + + return out + +# Main ResNet class, a customizable deep learning model architecture +class ResNet(nn.Module): + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer + self._norm_layer = norm_layer + + self.groups = groups # Number of groups in convolutions + self.inplanes = 16 if dataset in ['cifar10', 'cifar100'] else 64 # Adjust initial planes for CIFAR + + # First layer: a combination of convolution, normalization, and ReLU + self.layer0 = nn.Sequential( + nn.Conv2d(3, self.inplanes, 
kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(self.inplanes), + nn.ReLU(inplace=True), + ) + + # Subsequent ResNet layers using the _create_model_layer method + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling + self.fc = nn.Linear(64 * block.expansion, num_classes) # Fully connected layer for classification + + # Initialization for model weights + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 1e-3) + + # Zero-initialize the last BatchNorm in residual connections if required + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Helper function to create layers in ResNet + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer # Set normalization layer + downsample = None + # If the stride is not 1 or input/output planes do not match, create a downsample layer + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [block(self.inplanes, planes, stride, downsample)] # Create the first block with downsampling + self.inplanes = planes * block.expansion # Update inplanes for next blocks + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) # Add subsequent blocks without downsampling + return nn.Sequential(*layers) + + # Forward pass through the ResNet architecture + def forward(self, x): + x = self.layer0(x) # Pass input through the first layer + x = self.layer1(x) # First ResNet layer + x = self.layer2(x) # Second ResNet layer + x = self.layer3(x) # Third ResNet layer + x = self.avgpool(x) # Global average pooling + x = torch.flatten(x, 1) # Flatten the output for the fully connected layer + x = self.fc(x) # Pass through the fully connected layer + return x + +# Helper function to instantiate ResNet with pretrained weights if available +def _resnet(arch, block, layers, models_pretrained, progress, **kwargs): + model = ResNet(arch, block, layers, **kwargs) # Create a ResNet model + if models_pretrained: # Load pretrained weights if requested + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +# Functions to create specific ResNet variants +def resnet_model_18(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_18', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs) + +def resnet_model_34(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_34', BasicBlock, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +def resnet_model_50(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_50', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +def resnet_model_101(models_pretrained=False, progress=True, **kwargs): + return 
_resnet('resnet_model_101', Bottleneck, [3, 4, 23, 3], models_pretrained, progress, **kwargs) + +def resnet_model_152(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_152', Bottleneck, [3, 8, 36, 3], models_pretrained, progress, **kwargs) + +def resnet_model_200(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_200', Bottleneck, [3, 24, 36, 3], models_pretrained, progress, **kwargs) diff --git a/EdgeFLite/architecture/resnet_sl.py b/EdgeFLite/architecture/resnet_sl.py new file mode 100644 index 0000000..2f9bdd9 --- /dev/null +++ b/EdgeFLite/architecture/resnet_sl.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Importing necessary PyTorch libraries +import torch +import torch.nn as nn + +# Attempt to import model loading utilities from torch.hub; fall back to torch.utils.model_zoo if unavailable +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Specify all the modules and functions to export +__all__ = ['resnet110_sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8'] + +# Function for 3x3 convolution with padding +def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function for 1x1 convolution, typically used to change the number of channels +def apply_1x1_convolution(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +# Basic Block class for ResNet (used in smaller networks like resnet_model_18/resnet_model_34) +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor for output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + # BasicBlock only supports groups=1 and base_width=64 + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("BasicBlock does not support dilation greater than 1") + + # Define two 3x3 convolution layers with batch normalization and ReLU activation + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(planes, planes) + self.bn2 = norm_layer(planes) + # Optional downsample layer for changing the dimensions + self.downsample = downsample + self.stride = stride + + # Forward function defining the data flow through the block + def forward(self, x): + identity = x # Save the input for residual connection + + # First convolution, batch norm, and ReLU + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # Second convolution, batch norm + out = self.conv2(out) + out = self.bn2(out) + + # Apply downsample if needed to match dimensions for residual addition + if self.downsample is not None: + identity = self.downsample(x) + + # Residual connection (add identity to output) + out += identity + out = self.relu(out) + + return out + +# Bottleneck block class for deeper ResNet architectures (e.g., resnet_model_50/resnet_model_101) +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor 
for output channels (output = input * 4)
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        norm_layer = norm_layer or nn.BatchNorm2d
+        # Width of the block based on base_width and groups
+        width = int(planes * (base_width / 64.)) * groups
+
+        # Define 1x1, 3x3, and 1x1 convolutions with batch norm and ReLU activation
+        self.conv1 = apply_1x1_convolution(inplanes, width)  # First 1x1 convolution
+        self.bn1 = norm_layer(width)
+        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Main 3x3 convolution
+        self.bn2 = norm_layer(width)
+        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Final 1x1 convolution
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample  # Downsample layer for dimension adjustment
+        self.stride = stride
+
+    # Forward function defining the data flow through the bottleneck block
+    def forward(self, x):
+        identity = x  # Save the input for residual connection
+
+        # First 1x1 convolution, batch norm, and ReLU
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        # Second 3x3 convolution, batch norm, and ReLU (operates on the previous output, not on x)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        # Third 1x1 convolution, batch norm
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        # Apply downsample if needed for residual connection
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        # Residual connection (add identity to output)
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+# ResNet model for the main client (usually the primary model)
+class PrimaryResNetClient(nn.Module):
+
+    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
+        super(PrimaryResNetClient, self).__init__()
+        norm_layer = norm_layer or nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        # Initialize the number of input channels based on the dataset and split factor
+        inplanes_dict = {
+            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
+            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
+            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
+            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
+            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
+        }
+        self.inplanes = inplanes_dict[dataset][split_factor]
+
+        # Adjust input planes if using a wide ResNet
+        if 'wide_resnet' in arch:
+            widen_factor = int(arch.split('_')[-1])
+            self.inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))
+
+        self.base_width = width_per_group
+        self.dilation = 1
+        replace_stride_with_dilation = replace_stride_with_dilation or [False, False, False]
+
+        # Check if replace_stride_with_dilation is properly defined
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")
+
+        # Initialize input layer depending on the dataset (small or large)
+        if dataset in ['skin_dataset', 'pill_base', 'medical_images']:
+            self.layer0 = self._initialize_primary_layer_large()
+        else:
+            self.layer0 = self._init_layer0_small()
+
+        # Initialize model weights
+        self._init_model_weights(zero_init_residual)
+
+    # Define the large initial convolution layer for large datasets
+    def _initialize_primary_layer_large(self):
+        return nn.Sequential(
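+            # Stem for the higher-resolution datasets (skin_dataset, pill_base, medical_images):
+            # a stride-2 3x3 convolution followed by 3x3 max pooling downsamples early,
+            # unlike the stride-1 CIFAR stem defined in _init_layer0_small() below.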
nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False), + self._norm_layer(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) + + # Define the small initial convolution layer for smaller datasets like CIFAR + def _init_layer0_small(self): + return nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + self._norm_layer(self.inplanes), + nn.ReLU(inplace=True), + ) + + # Function to initialize weights in the network + def _init_model_weights(self, zero_init_residual): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Initialize residual weights for Bottleneck and BasicBlock if specified + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Define forward pass for the model + def forward(self, x): + x = self.layer0(x) + return x + +# ResNet model for proxy clients (usually assisting the main model) +class ResNetProxies(nn.Module): + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(ResNetProxies, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self._norm_layer = norm_layer + + # Set input channels based on architecture, dataset, and split factor + self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group) + self.base_width = width_per_group + + # Define layers of the network (layer1, layer2, layer3) + self.layer1 = self._create_model_layer(block, self.inplanes, layers[0], stride=1) + self.layer2 = self._create_model_layer(block, self.inplanes * 2, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, self.inplanes * 4, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive average pooling layer + self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes) + + # Initialize model weights + self._init_model_weights(zero_init_residual) + + # Set input channels based on dataset and split factor + def _set_input_planes(self, arch, dataset, split_factor, width_per_group): + inplanes_dict = { + 'cifar10': {1: 16, 2: 12, 4: 8, 8: 6}, + 'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24}, + } + inplanes = inplanes_dict[dataset][split_factor] + + # Adjust input planes for wide ResNet + if 'wide_resnet' in arch: + widen_factor = float(arch.split('_')[-1]) + inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0)) + + return inplanes + + # Function to create layers of the network (consisting of blocks) + def _create_model_layer(self, block, planes, blocks, stride=1): + layers = [block(self.inplanes, planes, stride)] # First block + self.inplanes = planes * block.expansion # Update input planes + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) # Additional blocks + return nn.Sequential(*layers) + + # Initialize weights in the network + def _init_model_weights(self, zero_init_residual): + for m in 
self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Initialize residual weights for Bottleneck and BasicBlock if specified + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Define forward pass for the model + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + +# Helper function to create the main ResNet client +def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs): + return PrimaryResNetClient(arch, block, layers, **kwargs) + +# Helper function to create the proxy ResNet client +def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs): + return ResNetProxies(arch, block, layers, **kwargs) + +# Function to define a ResNet-110 model for main and proxy clients +def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs): + assert 'cifar' in kwargs['dataset'] # Ensure that CIFAR dataset is used + return _resnetsl_primary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs) + +# Function to define a Wide ResNet-50-2 model for main and proxy clients +def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs): + kwargs['width_per_group'] = 64 * 2 # Adjust width for Wide ResNet + return _resnetsl_primary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +# Function to define a Wide ResNet-16-8 model for main and proxy clients +def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs): + kwargs['width_per_group'] = 64 # Adjust width for Wide ResNet + return _resnetsl_primary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs) diff --git a/EdgeFLite/architecture/splitnet.py b/EdgeFLite/architecture/splitnet.py new file mode 100644 index 0000000..d1f8821 --- /dev/null +++ b/EdgeFLite/architecture/splitnet.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from sklearn import ensemble +from .mixup import mixup_loss_criterion, combine_mixup_data +from . import resnet, resnet_sl + +__all__ = ['coremodel'] + +def _retrieve_network(arch='wide_resnet28_10'): + """ + Get the network architecture based on the provided name. + + Args: + arch (str): Name of the architecture. + + Returns: + Callable: The network class or function corresponding to the given architecture. 
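+
+    Example (illustrative; the keyword arguments mirror the `model_args` dict built in
+    `coremodel.__init__` below and are assumptions rather than a fixed API):
+        >>> net_fn = _retrieve_network('wide_resnet16_8')
+        >>> # model = net_fn(num_classes=100, dataset='cifar100', split_factor=1,
+        >>> #                output_stride=8, models_pretrained=False)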
+ """ + networks = { + 'wide_resnet28_10': resnet.wide_resnet28_10, + 'wide_resnet16_8': resnet.wide_resnet16_8, + 'resnet110': resnet.resnet110, + 'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2 + } + if arch not in networks: + raise ValueError(f"Architecture {arch} is not supported.") + return networks[arch] + +class coremodel(nn.Module): + def __init__(self, args, norm_layer=None, criterion=None, progress=True): + """ + Initialize the coremodel model with multiple sub-networks. + + Args: + args (argparse.Namespace): Configuration arguments. + norm_layer (callable, optional): Normalization layer. + criterion (callable, optional): Loss function. + progress (bool): Whether to show progress. + """ + super(coremodel, self).__init__() + + # Configuration parameters + self.split_factor = args.split_factor + self.arch = args.arch + self.loop_factor = args.loop_factor + self.is_train_sep = args.is_train_sep + self.epochs = args.epochs + self.criterion = criterion + self.is_diff_data_train = args.is_diff_data_train + self.is_mixup = args.is_mixup + self.mix_alpha = args.mix_alpha + + # Define model architectures + valid_archs = [ + 'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200', + 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', + 'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164', + 'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10', 'wide_resnet40_10', + 'wide_resnet52_8', 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2' + ] + + if self.arch not in valid_archs: + raise NotImplementedError(f"Architecture {self.arch} is not implemented.") + + model_args = { + 'num_classes': args.num_classes, + 'norm_layer': norm_layer, + 'dataset': args.dataset, + 'split_factor': self.split_factor, + 'output_stride': args.output_stride + } + # Initialize multiple sub-models based on the loop factor + self.models = nn.ModuleList([_retrieve_network(self.arch)(models_models_pretrained=args.models_models_pretrained, **model_args) for _ in range(self.loop_factor)]) + + if args.is_identical_init: + print("INFO: Using identical initialization.") + self._identical_init() + + # Ensemble settings + self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False + self.ensembled_loss_weight = args.ensembled_loss_weight + self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False + self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False + + # Co-training settings + self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False + self.cot_weight = args.cot_weight + self.is_cot_weight_warm_up = args.is_cot_weight_warm_up + self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs + self.cot_loss_choose = args.cot_loss_choose + print(f"INFO: The co-training loss is {self.cot_loss_choose}.") + self.num_classes = args.num_classes + + def forward(self, x, target=None, mode='train', epoch=0, streams=None): + """ + Forward pass through the model with optional mixup and co-training loss. + + Args: + x (Tensor): Input tensor. + target (Tensor, optional): Target tensor for loss computation. + mode (str): Mode of operation ('train', 'val', or 'test'). + epoch (int): Current epoch. + streams (optional): Additional data streams. + + Returns: + Tuple: + - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes]. + - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes]. 
+                - ce_loss (Tensor): Sum of cross-entropy losses for each model.
+                - cot_loss (Tensor): Co-training loss if applicable.
+        """
+        outputs, ce_losses = [], []
+
+        if 'train' in mode:
+            if self.is_mixup:
+                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)
+
+            # Split input data based on the loop factor; when different-data training is
+            # disabled, every sub-network sees the same full input
+            all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor
+
+            for i in range(self.loop_factor):
+                x_input = all_x[i]
+                output = self.models[i](x_input)
+                loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
+                outputs.append(output)
+                ce_losses.append(loss)
+
+        elif mode in ['val', 'test']:
+            for i in range(self.loop_factor):
+                output = self.models[i](x)
+                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
+                outputs.append(output)
+                ce_losses.append(loss)
+
+        else:
+            return torch.ones(1), None, None, None
+
+        # Calculate ensemble output and losses
+        ensemble_output = self._collect_ensemble_output(outputs)
+        ce_loss = torch.sum(torch.stack(ce_losses))
+
+        if mode in ['val', 'test']:
+            return ensemble_output, torch.stack(outputs, dim=0), ce_loss
+
+        if self.is_cot_loss:
+            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
+        else:
+            cot_loss = torch.zeros_like(ce_loss)
+
+        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss
+
+    def _collect_ensemble_output(self, outputs):
+        """
+        Calculate the ensemble output from a list of tensors.
+
+        Args:
+            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
+
+        Returns:
+            Tensor: The ensemble output with shape [batch_size, num_classes].
+        """
+        stacked_outputs = torch.stack(outputs, dim=0)
+
+        if self.is_ensembled_after_softmax:
+            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
+            if self.is_max_ensemble:
+                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
+            else:
+                ensemble_output = torch.mean(softmax_outputs, dim=0)
+        else:
+            if self.is_max_ensemble:
+                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
+            else:
+                ensemble_output = torch.mean(stacked_outputs, dim=0)
+
+        return ensemble_output
+
+    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
+        """
+        Calculate the co-training loss between outputs of different networks.
+
+        Args:
+            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
+            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_seperate').
+            epoch (int): Current epoch.
+
+        Returns:
+            Tensor: The computed co-training loss.
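+
+        Note (summary of the computation below): for 'js_divergence' the returned value is
+        weight_now * (H(p_mean) - mean_k H(p_k)), where p_k = softmax(output_k) and p_mean is
+        the average of the p_k, a Jensen-Shannon-style term that rewards agreement between the
+        sub-networks. For 'kl_seperate', a KL divergence between each sub-network's
+        log-probabilities and every other sub-network's (detached) probabilities is accumulated
+        instead. weight_now equals cot_weight, linearly warmed up (with a 0.005 floor) over
+        cot_weight_warm_up_epochs when is_cot_weight_warm_up is set.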
+ """ + weight_now = self.cot_weight + if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs: + weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005) + + stacked_outputs = torch.stack(outputs, dim=0) + + if loss_choose == 'js_divergence': + p_all = F.softmax(stacked_outputs, dim=-1) + p_mean = torch.mean(p_all, dim=0) + H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean() + H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean() + cot_loss = weight_now * (H_mean - H_sep) + + elif loss_choose == 'kl_seperate': + outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0) + index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i] + outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device)) + kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2,” diff --git a/EdgeFLite/configurations/.DS_Store b/EdgeFLite/configurations/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..32fb7410fcefd288d53cab8291e21caf24976639 GIT binary patch literal 6148 zcmeHKOG*Pl5UtV(5fQR<*;mL7CSg25wnB7f#4ILsK(g->WBBT0FxY}_M5GG3 zUw3`_^M>i_CL&($hdI%Vh^8olERBkYx@tOe=L;a~7(H$1k=C>xo+|_WMU|}mQdUeo zJ)j)d-;Z6+qb#{%9-ro`K5uuvqwBiYHn(jz<_=cyS6&ZyFN?1{hi`iBHXmn}&lVC4 z1Ovf9Fc1v&8&`UC(@pjJU!FQvdF3Si?{ z455Lrg#s;<{fWUA4t;XJ#<3V$II+%*aeU_Y#|!(Fn3K8_H-=FM1Hr&41LxMAN&dgZ zuT*{HpN2#)7zhUb83R0NSM3s;^4$7qdy;Dt$~}sR_!VhD(AHE8)KrX)oFg}CG<%XZ Y?2ThFlu<;l;lQ{EC?Qb=1HZt)8=4(8A^-pY literal 0 HcmV?d00001 diff --git a/EdgeFLite/configurations/training_config.py b/EdgeFLite/configurations/training_config.py new file mode 100644 index 0000000..28ff6c0 --- /dev/null +++ b/EdgeFLite/configurations/training_config.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +from __future__ import absolute_import, division, print_function +import json +import torch +from config import * + +# Function to save hyperparameters into a JSON file +def store_hyperparameters_json(args): + """Save hyperparameters to a JSON file.""" + # Create the model directory if it does not exist + os.makedirs(args.model_dir, exist_ok=True) + # Determine the filename based on whether it's evaluation or training mode + filename = os.path.join(args.model_dir, 'hparams_eval.json' if args.evaluate else 'hparams_train.json') + # Convert the arguments to a dictionary + hparams = vars(args) + # Write the hyperparameters to a JSON file with indentation and sorted keys + with open(filename, 'w') as f: + json.dump(hparams, f, indent=4, sort_keys=True) + +# Function to add parser arguments for command-line interface +def add_parser_arguments(parser): + # Dataset and model settings + parser.add_argument('--data', type=str, default=f"{data_dir}/dataset_hub/", help='Path to dataset') # Path to the dataset + parser.add_argument('--model_dir', type=str, default="EdgeFLite", help='Directory to save the model') # Directory where the model is saved + parser.add_argument('--arch', type=str, default='wide_resnet16_8', choices=[ + 'resnet110', 'resnet_model_110sl', 'wide_resnet16_8', 'wide_resnetsl16_8', + 'wide_resnet_model_50_2', 'wide_resnetsl50_2'], help='Neural architecture name') # Neural architecture options + + # Normalization and training settings + parser.add_argument('--norm_mode', type=str, default='batch', choices=['batch', 'group', 'layer', 'instance', 'none'], help='Batch normalization style') # Type of 
normalization used + parser.add_argument('--is_syncbn', default=0, type=int, help='Use nn.SyncBatchNorm or not') # Whether to use synchronized batch normalization + parser.add_argument('--workers', default=16, type=int, help='Number of data loading workers') # Number of workers for data loading + parser.add_argument('--epochs', default=650, type=int, help='Total epochs to run') # Total number of training epochs + parser.add_argument('--start_epoch', default=0, type=int, help='Manual epoch number for restarts') # Starting epoch number for restarting training + parser.add_argument('--eval_per_epoch', default=1, type=int, help='Evaluation frequency per epoch') # Frequency of evaluation during training + parser.add_argument('--spid', default="EdgeFLite", type=str, help='Experiment name') # Name of the experiment + parser.add_argument('--save_weight', default=False, type=bool, help='Save model weights') # Whether to save model weights + + # Data augmentation settings + parser.add_argument('--batch_size', default=128, type=int, help='Mini-batch size for training') # Batch size for training + parser.add_argument('--eval_batch_size', default=100, type=int, help='Mini-batch size for evaluation') # Batch size for evaluation + parser.add_argument('--crop_size', default=32, type=int, help='Crop size for images') # Size of the image crops + parser.add_argument('--output_stride', default=8, type=int, help='Output stride for model') # Output stride for the model + parser.add_argument('--padding', default=4, type=int, help='Padding size for images') # Padding size for image processing + + # Learning rate settings + parser.add_argument('--lr_mode', type=str, default='cos', choices=['cos', 'step', 'poly', 'HTD', 'exponential'], help='Learning rate strategy') # Strategy for adjusting learning rate + parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='Initial learning rate') # Initial learning rate value + parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW', 'RMSprop', 'RMSpropTF'], help='Optimizer choice') # Choice of optimizer + parser.add_argument('--lr_milestones', nargs='+', type=int, default=[100, 200], help='Epochs for learning rate steps') # Epochs where learning rate adjustments occur + parser.add_argument('--lr_step_multiplier', default=0.1, type=float, help='Multiplier at learning rate milestones') # Multiplier applied at learning rate steps + parser.add_argument('--end_lr', type=float, default=1e-4, help='Ending learning rate') # Final learning rate value + + # Additional hyperparameters + parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for regularization') # Weight decay for L2 regularization + parser.add_argument('--momentum', default=0.9, type=float, help='Optimizer momentum') # Momentum for optimizers like SGD + parser.add_argument('--print_freq', default=20, type=int, help='Print frequency for logging') # Frequency for printing logs during training + + # Federated learning settings + parser.add_argument('--is_fed', default=1, type=int, help='Enable federated learning') # Enable or disable federated learning + parser.add_argument('--num_clusters', default=20, type=int, help='Number of clusters for federated learning') # Number of clusters in federated learning + parser.add_argument('--num_selected', default=20, type=int, help='Number of clients selected for training per round') # Number of clients selected each round + parser.add_argument('--num_rounds', default=300, type=int, help='Total number of 
training rounds') # Total number of federated learning rounds + + # Processing and decentralized training settings + parser.add_argument('--gpu', default=None, type=int, help='GPU ID to use') # GPU ID to be used for training + parser.add_argument('--no_cuda', action='store_true', default=False, help='Disable CUDA training') # Whether to disable CUDA + parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs for training') # Comma-separated GPU IDs for multi-GPU training + + # Parse command-line arguments + args = parser.parse_args() + + # Additional configurations + args.cuda = not args.no_cuda and torch.cuda.is_available() # Enable CUDA if not disabled and available + if args.cuda: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] # Parse GPU IDs from comma-separated string + args.num_gpus = len(args.gpu_ids) # Count number of GPUs being used + + return args diff --git a/EdgeFLite/data_collection/.DS_Store b/EdgeFLite/data_collection/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..43eeefef32d280748b8fc8c7060a4d77c94c68ea GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8O(;SS3LZQxc&(U-t%#RU>kAmsgGx-AV1qF$O=}OOkSEX=@;Q7S zXLh#*X}yU^nSt4FGe4W`w`Hfp7~|b>>@wD5j9H+F6&r>Xf_~H`sTfNYkgGNNY`_v0 zuppW&MU&w_GJt1iSG+`-jKKR5y<}-Ia-2_Isn#|&>sH-rSoi)wPW(|g z8W-Jgc#BV0QfAS79!A$`G6|a7=Q1CKX`T#KLK-I!a(9#Fv7B_}IFAdJ>uZN)TXxXg zn@-)%iQ5tf$4-0J64N8M({2gZbxvn9+uGSbJi82@vgcgBX-YZ#`jl)LoWnazR_rb= z`dI{Di(xC=ul_L4Wc~==`N9}ULJSZC#K77xU@kYSu{JxRjSvIGz;7{t_k#dM^bD37 z)ztxmY5@RiU{(Ti>?JTq81xL58o>j?bt<4v<)+2pIvwo7#CZlwjXIri(|mCAW^OtZ zuHFvo3l+||r;$ctfEf7B0QP<$O|1VHKllGChz4SS7+6mRc&+F6y0AA>TbFi;wN`*W rfTCbrsd1446IF^K7EAFis1mRXoB(= (5, 2): + return img.rotate_image(degrees, **kwargs) # Use rotate_image if PIL version is >= 5.2 + elif _PIL_VER >= (5, 0): + # Manually rotate_image the image for older versions of PIL + w, h = img.size + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), round(math.sin(angle), 15), 0.0, + round(-math.sin(angle), 15), round(math.cos(angle), 15), 0.0, + ] + + def apply_transformation(x, y, matrix): + return matrix[0] * x + matrix[1] * y + matrix[2], matrix[3] * x + matrix[4] * y + matrix[5] + + matrix[2], matrix[5] = apply_transformation(-rotn_center[0], -rotn_center[1], matrix) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.apply_transformation(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate_image(degrees, resample=kwargs['resample']) + +# Auto contrast image +def apply_auto_contrast(img, **kwargs): + return ImageOps.autocontrast(img) + +# Invert image colors +def invert(img, **kwargs): + return ImageOps.invert(img) + +# Equalize image histogram +def equalize(img, **kwargs): + return ImageOps.equalize(img) + +# Apply solarization effect +def apply_solarize(img, thresh, **kwargs): + return ImageOps.apply_solarize(img, thresh) + +# Apply solarization effect with an additional value +def apply_apply_solarize_addition(img, add, thresh=128, **kwargs): + lut = [min(255, i + add) if i < thresh else i for i in range(256)] + if img.mode in ("L", "RGB"): + lut = lut + lut + lut if img.mode == "RGB" else lut + return img.point(lut) + else: + return img + +# apply_posterization image (reduce color depth) +def apply_posterization(img, bits_to_keep, **kwargs): + return img if bits_to_keep >= 8 else ImageOps.apply_posterization(img, bits_to_keep) + +# Adjust image contrast +def contrast(img, factor, **kwargs): + return 
ImageEnhance.Contrast(img).enhance(factor)
+
+# Adjust image color
+def color(img, factor, **kwargs):
+    return ImageEnhance.Color(img).enhance(factor)
+
+# Adjust image brightness
+def brightness(img, factor, **kwargs):
+    return ImageEnhance.Brightness(img).enhance(factor)
+
+# Adjust image sharpness
+def adjust_image_sharpness(img, factor, **kwargs):
+    return ImageEnhance.Sharpness(img).enhance(factor)
+
+# Randomly negate a value with a 50% probability
+def _apply_random_negation(v):
+    """With 50% probability, negate the value."""
+    return -v if random.random() > 0.5 else v
+
+# Convert augmentation level to argument value
+def _map_level_to_argument(level, max_value, hparams):
+    level = (level / _MAX_LEVEL) * max_value
+    return _apply_random_negation(level),
+
+# Convert translation level to argument value
+def _map_absolute_map_level_to_argument(level, hparams):
+    translate_const = hparams['translate_const']
+    level = (level / _MAX_LEVEL) * float(translate_const)
+    return _apply_random_negation(level),
+
+# Convert enhancement level to argument value
+def _enhance_map_level_to_argument(level, _hparams):
+    return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+# Mapping of augmentation levels to argument converters
+map_level_to_argument = {
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'rotate_image': lambda level, _: _map_level_to_argument(level, 30, None),
+    'apply_posterization': lambda level, _: int((level / _MAX_LEVEL) * 4),
+    'apply_solarize': lambda level, _: int((level / _MAX_LEVEL) * 256),
+    'Color': _enhance_map_level_to_argument,
+    'Contrast': _enhance_map_level_to_argument,
+    'Brightness': _enhance_map_level_to_argument,
+    'adjust_image_sharpness': _enhance_map_level_to_argument,
+    'ShearX': lambda level, _: _map_level_to_argument(level, 0.3, None),
+    'ShearY': lambda level, _: _map_level_to_argument(level, 0.3, None),
+    'TranslateX': _map_absolute_map_level_to_argument,
+    'TranslateY': _map_absolute_map_level_to_argument,
+}
+
+# Mapping of augmentation names to functions
+NAME_TO_OP = {
+    'AutoContrast': apply_auto_contrast,
+    'Equalize': equalize,
+    'Invert': invert,
+    'rotate_image': rotate_image,
+    'apply_posterization': apply_posterization,
+    'apply_solarize': apply_solarize,
+    'Color': color,
+    'Contrast': contrast,
+    'Brightness': brightness,
+    'adjust_image_sharpness': adjust_image_sharpness,
+    'ShearX': apply_apply_shear_x_axis_axis,
+    'ShearY': shear_y,
+    'TranslateX': translate_image_x_absolute,
+    'TranslateY': translate_image_y_absolute,
+}
+
+# Class for applying augmentations to an image
+class AugmentOp:
+    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+        hparams = hparams or _HPARAMS_DEFAULT
+        self.aug_fn = NAME_TO_OP[name]  # Get the augmentation function
+        self.level_fn = map_level_to_argument[name]  # Get the level function
+        self.prob = prob  # Probability of applying the augmentation
+        self.magnitude = magnitude  # Magnitude of the augmentation
+        self.hparams = hparams.copy()
+        self.kwargs = {
+            'fillcolor': hparams.get('img_mean', _FILL),  # Set the fill color
+            '
diff --git a/EdgeFLite/data_collection/cifar100_noniid.py b/EdgeFLite/data_collection/cifar100_noniid.py
new file mode 100644
index 0000000..5a64d68
--- /dev/null
+++ b/EdgeFLite/data_collection/cifar100_noniid.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+#### Get CIFAR-100 dataset in X and Y form
+import torchvision
+import numpy as np
+import random
+import torch
+from torchvision import transforms as apply_transformations
+from
torch.utils.data import DataLoader, Dataset +from .cifar10_non_iid import * + +# Set random seeds for reproducibility +np.random.seed(68) +random.seed(68) + +def get_cifar100(data_dir): + ''' + Load and return CIFAR-100 train/test data and labels as numpy arrays. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be downloaded/saved. + + Returns: + x_train (ndarray): Training data. + y_train (ndarray): Training labels. + x_test (ndarray): Test data. + y_test (ndarray): Test labels. + ''' + # Download CIFAR-100 training and test datasets + data_train = torchvision.datasets.CIFAR100(data_dir, train=True, download=True) + data_test = torchvision.datasets.CIFAR100(data_dir, train=False, download=True) + + # Transpose data for proper channel order and convert labels to numpy arrays + x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets) + x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets) + + return x_train, y_train, x_test, y_test + +def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True): + ''' + Splits data and labels among n_clients to simulate a non-IID distribution. + + Parameters: + data (ndarray): Dataset images [n_data x shape]. + labels (ndarray): Dataset labels [n_data]. + n_clients (int): Number of clients to split the data among. + verbose (bool): Print detailed information if True. + + Returns: + clients_split (ndarray): Split data and labels for each client. + ''' + n_labels = np.max(labels) + 1 # Number of unique labels/classes + + def divide_into_sections(n, m): + '''Return m random integers that sum up to n.''' + result = [1] * m + for _ in range(n - m): + result[random.randint(0, m - 1)] += 1 + return result + + # Shuffle and partition classes + n_classes = len(set(labels)) # Number of unique classes + classes = list(range(n_classes)) + np.random.shuffle(classes) # Shuffle class indices + label_indices = [list(np.where(labels == class_)[0]) for class_ in classes] # Indices of each class in labels + + # Define number of classes for each client (randomized) + tmp = [np.random.randint(1, 100) for _ in range(n_clients)] + total_partition = sum(tmp) + class_partition = divide_into_sections(total_partition, len(classes)) # Partition classes randomly + + # Split class indices among clients + class_partition = sorted(class_partition, reverse=True) + class_partition_split = {} + + for idx, class_ in enumerate(classes): + # Split each class' indices according to the partition + class_partition_split[class_] = [list(i) for i in np.array_split(label_indices[idx], class_partition[idx])] + + clients_split = [] + for i in range(n_clients): + n = tmp[i] # Number of classes for this client + indices = [] + j = 0 + + # Assign class data to the client + while n > 0: + class_ = classes[j] + if class_partition_split[class_]: + indices.extend(class_partition_split[class_].pop()) # Add indices of the class to the client + n -= 1 + j += 1 + + clients_split.append([data[indices], labels[indices]]) # Add client's data split + + # Re-sort classes based on available data to balance further splits + classes = sorted(classes, key=lambda x: len(class_partition_split[x]), reverse=True) + + # Raise error if client partition criteria cannot be met + if n > 0: + raise ValueError("Unable to fulfill the client partition criteria.") + + # Verbose option to print split information + if verbose: + display_data_split(clients_split) + + return np.array(clients_split) + +def 
display_data_split(clients_split): + '''Print the split information of the dataset for each client.''' + print("Data split:") + for i, client in enumerate(clients_split): + split = np.sum(client[1].reshape(1, -1) == np.arange(np.max(client[1]) + 1).reshape(-1, 1), axis=1) + print(f" - Client {i}: {split}") + print() + +def get_default_data_apply_transformations_cf100(train=True, verbose=True): + ''' + Return default data apply_transformationations for CIFAR-100. + + Parameters: + train (bool): Whether to apply apply_transformationations for training data. + verbose (bool): Print apply_transformationation details if True. + + Returns: + apply_transformations_train (Compose): Training apply_transformationations. + apply_transformations_eval (Compose): Evaluation (test) apply_transformationations. + ''' + # Define apply_transformationations for training data + apply_transformations_train = { + 'cifar100': apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ]) + } + + # Define apply_transformationations for test data + apply_transformations_eval = { + 'cifar100': apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ]) + } + + # Verbose option to print apply_transformationation steps + if verbose: + print("\nData preprocessing:") + for apply_transformationation in apply_transformations_train['cifar100'].apply_transformations: + print(f' - {apply_transformationation}') + print() + + return apply_transformations_train['cifar100'], apply_transformations_eval['cifar100'] + +def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10, verbose=True, + apply_transformations_train=None, apply_transformations_eval=None, non_iid=None, split_factor=1): + ''' + Return data loaders for training on CIFAR-100. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be saved. + n_clients (int): Number of clients for splitting the dataset. + batch_size (int): Batch size for each data loader. + classes_per_client (int): Number of classes per client. + verbose (bool): Print detailed information if True. + apply_transformations_train (Compose): apply_transformationations for training data. + apply_transformations_eval (Compose): apply_transformationations for evaluation data. + non_iid (str): Strategy to create a non-IID dataset split. + split_factor (float): Factor to control the degree of splitting. + + Returns: + client_loaders (list): Data loaders for each client. 
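+
+    Note:
+        Only the 'quantity_skew' strategy is handled in the body below; for any
+        other value of non_iid the split is left as None before being shuffled.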
+ ''' + x_train, y_train, _, _ = get_cifar100(data_dir) + + # Verbose option to print dataset statistics + if verbose: + print_image_data_stats_train(x_train, y_train) + + # Split data according to non-IID strategy (e.g., quantity_skew) + split = None + if non_iid == 'quantity_skew': + split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose) + + split_tmp = shuffle_list(split) + + # Create DataLoaders for each client + client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor), + batch_size=batch_size, shuffle=True) for x, y in split_tmp] + + return client_loaders + +def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True, apply_transformations_eval=None): + ''' + Return data loaders for testing on CIFAR-100. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be saved. + batch_size (int): Batch size for the test data loader. + verbose (bool): Print detailed information if True. + apply_transformations_eval (Compose): apply_transformationations for evaluation data. + + Returns: + test_loader (DataLoader): Test data loader. + ''' + _, _, x_test, y_test = get_cifar100(data_dir) + + # Verbose option to print dataset statistics + if verbose: + print_image_data_stats_test(x_test, y_test) + + # Create DataLoader for the test dataset + test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1), + batch_size=100, shuffle=False) + + return test_loader diff --git a/EdgeFLite/data_collection/cifar10_noniid.py b/EdgeFLite/data_collection/cifar10_noniid.py new file mode 100644 index 0000000..28da974 --- /dev/null +++ b/EdgeFLite/data_collection/cifar10_noniid.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +#### Load CIFAR-10 dataset and preprocess it +import torchvision +import numpy as np +import random +import torch +from torchvision import apply_transformations +from torch.utils.data import DataLoader, Dataset + +# Set random seed for reproducibility +np.random.seed(68) # Ensures that the random operations have consistent outputs +random.seed(68) + +def get_cifar10(data_dir): + """Return CIFAR-10 train/test data and labels as numpy arrays""" + # Download CIFAR-10 dataset + data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True) + data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True) + + # Preprocess the train and test data to the correct format (channels first) + x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets) + x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets) + + return x_train, y_train, x_test, y_test + +def display_data_statistics(data, labels, dataset_type): + """Print statistics of the dataset""" + print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), Range: [{np.min(data):.3f}, {np.max(data):.3f}], " + f"Labels: {np.min(labels)},..,{np.max(labels)}") + +def randomize_client_distributiony(train_len, n_clients): + """ + Distribute data among clients with a random distribution + Returns a list with the number of samples for each client + """ + # Randomly assign a number of samples to each client, ensuring the total matches the train_len + client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)] + total = sum(client_sizes) + client_sizes = np.array(client_sizes) + client_distributions = ((client_sizes / total) * train_len).astype(int) # Normalize to match the 
train_len + client_distributions = list(client_distributions) + client_distributions.append(train_len - sum(client_distributions)) # Ensure all data is allocated + return client_distributions + +def divide_into_sections(n, m): + """Return 'm' random integers that sum to 'n'""" + # Break the number 'n' into 'm' random parts that sum to 'n' + partitions = [1] * m + for _ in range(n - m): + partitions[random.randint(0, m - 1)] += 1 + return partitions + +def split_data_real_world_scenario(data, labels, n_clients=100): + """Split data among clients simulating real-world non-IID distribution""" + n_classes = len(set(labels)) # Determine number of unique classes + class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)] # Indices for each class + + client_classes = [np.random.randint(1, 10) for _ in range(n_clients)] # Random number of classes per client + total_partitions = sum(client_classes) + + class_partition = divide_into_sections(total_partitions, len(class_indices)) # Partition classes to distribute + class_partition_split = {cls: np.array_split(class_indices[cls], n) for cls, n in enumerate(class_partition)} + + clients_split = [] + for client in client_classes: + selected_indices = [] + for class_ in range(n_classes): + if class_partition_split[class_]: + selected_indices.extend(class_partition_split[class_].pop()) + client -= 1 + if client <= 0: + break + clients_split.append([data[selected_indices], labels[selected_indices]]) + + return np.array(clients_split) + +def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True): + """Split data among clients with IID (Independent and Identically Distributed) distribution""" + data_per_client = randomize_client_distributiony(len(data), n_clients) + label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)] + + if shuffle: + for indices in label_indices: + np.random.shuffle(indices) + + clients_split = [] + for client_data in data_per_client: + client_indices = [] + class_ = np.random.randint(len(label_indices)) + while client_data > 0: + take = min(client_data, len(label_indices[class_])) + client_indices.extend(label_indices[class_][:take]) + label_indices[class_] = label_indices[class_][take:] + client_data -= take + class_ = (class_ + 1) % len(label_indices) + + clients_split.append([data[client_indices], labels[client_indices]]) + + return np.array(clients_split) + +def randomize_data_order(data): + """Shuffle data while maintaining the mapping between inputs and labels""" + for i in range(len(data)): + index = np.arange(len(data[i][0])) + np.random.shuffle(index) + data[i][0], data[i][1] = data[i][0][index], data[i][1][index] + return data + +class CustomImageDataset(Dataset): + """Custom Dataset class for image data""" + def __init__(self, inputs, labels, apply_transformations=None, split_factor=1): + # Convert input data to torch tensors and apply apply_transformationations if provided + self.inputs = torch.Tensor(inputs) + self.labels = labels + self.apply_transformations = apply_transformations + self.split_factor = split_factor + + def __getitem__(self, index): + img, label = self.inputs[index], self.labels[index] + # Apply apply_transformationations to the image multiple times if split_factor > 1 + imgs = [self.apply_transformations(img) for _ in range(self.split_factor)] if self.apply_transformations else [img] + return torch.cat(imgs, dim=0), label + + def __len__(self): + return len(self.inputs) + +def get_default_apply_transformations(verbose=True): + 
"""Return default apply_transformationations for training and evaluation""" + apply_transformations_train = apply_transformations.Compose([ + apply_transformations.ToPILImage(), # Convert numpy array to PIL image + apply_transformations.RandomCrop(32, padding=4), # Randomly crop to 32x32 with padding + apply_transformations.RandomHorizontalFlip(), # Randomly flip images horizontally + apply_transformations.ToTensor(), # Convert image to tensor + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # Normalize with CIFAR-10 mean and std + ]) + + apply_transformations_eval = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # Same normalization for evaluation + ]) + + if verbose: + print("\nData preprocessing steps:") + for apply_transformationation in apply_transformations_train.apply_transformations: + print(f" - {apply_transformationation}") + + return apply_transformations_train, apply_transformations_eval + +def obtain_data_loaders(data_dir, n_clients, batch_size, classes_per_client=10, non_iid=None, split_factor=1): + """Return DataLoader objects for clients with either IID or non-IID data split""" + x_train, y_train, _, _ = get_cifar10(data_dir) + display_data_statistics(x_train, y_train, "Train") + + # Split data based on non-IID method specified (either 'quantity_skew' or 'label_skew') + if non_iid == 'quantity_skew': + clients_data = split_data_real_world_scenario(x_train, y_train, n_clients) + elif non_iid == 'label_skew': + clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client) + + shuffled_clients_data = randomize_data_order(clients_data) + + apply_transformations_train, apply_transformations_eval = get_default_apply_transformations(verbose=False) + client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor), + batch_size=batch_size, shuffle=True) for x, y in shuffled_clients_data] + + return client_loaders + +def get_test_data_loader(data_dir, batch_size): + """Return DataLoader for test data""" + _, _, x_test, y_test = get_cifar10(data_dir) + display_data_statistics(x_test, y_test, "Test") + + _, apply_transformations_eval = get_default_apply_transformations(verbose=False) + test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval), batch_size=batch_size, shuffle=False) + + return test_loader diff --git a/EdgeFLite/data_collection/data_cutout.py b/EdgeFLite/data_collection/data_cutout.py new file mode 100644 index 0000000..3a81219 --- /dev/null +++ b/EdgeFLite/data_collection/data_cutout.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import numpy as np + + +class Cutout: + """Applies random cutout augmentation by masking patches in an image. + + This technique randomly cuts out square patches from the image to + augment the dataset, helping the model become invariant to occlusions. + + Args: + n_holes (int): Number of patches to remove from the image. + length (int): Side length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + """ + Initializes the Cutout class with the number of patches to be removed + and the size of each patch. + + Args: + n_holes (int): Number of patches (holes) to cut out from the image. + length (int): Size of each square patch. + """ + self.n_holes = n_holes # Number of holes (patches) to remove. 
+ self.length = length # Side length of each square patch. + + def __call__(self, img): + """ + Applies the cutout augmentation on the input image. + + Args: + img (Tensor): The input image tensor with shape (C, H, W), + where C is the number of channels, H is the height, + and W is the width of the image. + + Returns: + Tensor: The augmented image tensor with `n_holes` patches of size + `length x length` cut out, filled with zeros. + """ + # Get the height and width of the image (ignoring the channel dimension) + height, width = img.size(1), img.size(2) + + # Create a mask initialized with ones, same height and width as the image + # (each pixel is set to 1, representing no masking initially) + mask = np.ones((height, width), dtype=np.float32) + + # Randomly remove `n_holes` patches from the image + for _ in range(self.n_holes): + # Randomly choose the center of a patch (x_center, y_center) + y_center = np.random.randint(height) + x_center = np.random.randint(width) + + # Define the coordinates of the patch based on the center + # and ensure the patch stays within the image boundaries. + y1 = np.clip(y_center - self.length // 2, 0, height) + y2 = np.clip(y_center + self.length // 2, 0, height) + x1 = np.clip(x_center - self.length // 2, 0, width) + x2 = np.clip(x_center + self.length // 2, 0, width) + + # Set the mask to 0 for the patch (mark the patch as cut out) + mask[y1:y2, x1:x2] = 0.0 + + # Convert the mask from numpy array to a PyTorch tensor + mask_tensor = torch.from_numpy(mask).expand_as(img) + + # Multiply the input image by the mask (cut out the selected patches) + return img * mask_tensor diff --git a/EdgeFLite/data_collection/dataset_cifar.py b/EdgeFLite/data_collection/dataset_cifar.py new file mode 100644 index 0000000..154f2d6 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_cifar.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries +from PIL import Image # For image handling +import os # For file path operations +import numpy as np # For numerical operations +import pickle # For loading serialized data +import torch # For PyTorch operations + +# Import custom classes and functions from the current package +from .vision import VisionDataset +from .utils import validate_integrity, fetch_and_extract_archive + +# CIFAR10 dataset class +class CIFAR10(VisionDataset): + """ + CIFAR10 Dataset class that handles the CIFAR-10 dataset loading, processing, and apply_transformationations. + + Args: + root (str): Directory where the dataset is stored or will be downloaded to. + train (bool, optional): If True, load the training set. Otherwise, load the test set. + apply_transformation (callable, optional): A function/apply_transformation that takes a PIL image and returns a apply_transformationed version. + target_apply_transformation (callable, optional): A function/apply_transformation that takes the target and apply_transformations it. + download (bool, optional): If True, download the dataset if it's not found locally. + split_factor (int, optional): Number of apply_transformationations applied to each image. Default is 1. 
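+
+    Note:
+        __getitem__ applies `apply_transformation` to each image split_factor times and
+        concatenates the resulting tensors along the channel dimension (torch.cat, dim=0).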
+ """ + # Directory and URL details for downloading the CIFAR-10 dataset + base_folder = 'cifar-10-batches-py' + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + filename = "cifar-10-python.tar.gz" + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' # MD5 checksum to verify the file's integrity + + # List of training batches with their corresponding MD5 checksums + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'] + ] + + # List of test batches with their corresponding MD5 checksums + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'] + ] + + # Info map to hold label names and their checksum + info_map = { + 'filename': 'batches.info_map', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888' + } + + # Initialization method + def __init__(self, root, train=True, apply_transformation=None, target_apply_transformation=None, download=False, split_factor=1): + super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation) + self.train = train # Whether to load the training set or test set + self.split_factor = split_factor # Number of apply_transformationations to apply + + # Download dataset if necessary + if download: + self.download() + + # Check if the dataset is already downloaded and valid + if not self._validate_integrity(): + raise RuntimeError('Dataset not found or corrupted. Use download=True to download it.') + + # Load the dataset + self.data, self.targets = self._load_data() + + # Load the label info map (to get class names) + self._load_info_map() + + # Load dataset from the files + def _load_data(self): + data, targets = [], [] # Initialize lists to hold data and labels + files = self.train_list if self.train else self.test_list # Choose train or test files + + # Load each file, deserialize with pickle, and append data and labels + for file_name, _ in files: + file_path = os.path.join(self.root, self.base_folder, file_name) + with open(file_path, 'rb') as f: + entry = pickle.load(f, encoding='latin1') # Load file + data.append(entry['data']) # Append image data + targets.extend(entry.get('labels', entry.get('fine_labels', []))) # Append labels + + # Reshape and format the data to (num_samples, height, width, channels) + data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) # Reshape to HWC format + return data, targets + + # Load label names (info map) + def _load_info_map(self): + info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename']) # Path to info map + if not validate_integrity(info_map_path, self.info_map['md5']): # Check integrity of info map + raise RuntimeError('info_mapdata file not found or corrupted. Use download=True to download it.') + + # Load the label names + with open(info_map_path, 'rb') as info_map_file: + info_map_data = pickle.load(info_map_file, encoding='latin1') # Load label names + self.classes = info_map_data[self.info_map['key']] # Extract class labels + self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)} # Map class names to indices + + # Get item (image and target) by index + def __getitem__(self, index): + """ + Get the item (image, target) at the specified index. + Args: + index (int): Index of the data. 
+ + Returns: + tuple: apply_transformationed image and the target class. + """ + img, target = self.data[index], self.targets[index] # Get image and target label + img = Image.fromarray(img) # Convert numpy array to PIL image + + # Apply the apply_transformation multiple times based on split_factor + imgs = [self.apply_transformation(img) for _ in range(self.split_factor)] if self.apply_transformation else None + if imgs is None: + raise NotImplementedError('apply_transformation must be provided.') + + # Apply target apply_transformationation if available + if self.target_apply_transformation: + target = self.target_apply_transformation(target) + + return torch.cat(imgs, dim=0), target # Return concatenated apply_transformationed images and the target + + # Return the number of items in the dataset + def __len__(self): + return len(self.data) + + # Check if the dataset files are valid and downloaded + def _validate_integrity(self): + files = self.train_list + self.test_list # All files to check + for file_name, md5 in files: + file_path = os.path.join(self.root, self.base_folder, file_name) + if not validate_integrity(file_path, md5): # Verify integrity using MD5 + return False + return True + + # Download the dataset if it's not available + def download(self): + if self._validate_integrity(): + print('Files already downloaded and verified') + else: + fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) + + # Representation string to include the split type (Train/Test) + def extra_repr(self): + return f"Split: {'Train' if self.train else 'Test'}" + + +# CIFAR100 is a subclass of CIFAR10, with minor modifications +class CIFAR100(CIFAR10): + """ + CIFAR100 Dataset, a subclass of CIFAR10. + """ + # Directory and URL details for downloading CIFAR-100 dataset + base_folder = 'cifar-100-vision' + url = "https://www.cs.toronto.edu/~kriz/cifar-100-vision.tar.gz" + filename = "cifar-100-vision.tar.gz" + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' # MD5 checksum + + # Training and test lists with their corresponding MD5 checksums for CIFAR-100 + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'] + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'] + ] + + # Info map to hold fine label names and their checksum + info_map = { + 'filename': 'info_map', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48' + } diff --git a/EdgeFLite/data_collection/dataset_factory.py b/EdgeFLite/data_collection/dataset_factory.py new file mode 100644 index 0000000..c23a396 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_factory.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +from torchvision import apply_transformations +from .cifar import CIFAR10, CIFAR100 # Import CIFAR10 and CIFAR100 datasets +from .autoaugment import CIFAR10Policy # Import CIFAR10 augmentation policy + +__all__ = ['obtain_data_loader'] # Define the public API of this module + +def obtain_data_loader( + data_dir, # Directory where the data is stored + split_factor=1, # Used for data partitioning, especially in federated learning + batch_size=128, # Batch size for loading data + crop_size=32, # Size to crop the input images + dataset='cifar10', # Dataset to use (CIFAR-10 by default) + split="train", # The split type: 'train', 'val', or 'test' + is_decentralized=False, # Whether to use decentralized training + is_autoaugment=1, # Use AutoAugment or not + randaa=None, # Placeholder for randomized augmentations + is_cutout=True, 
# Whether to apply cutout (random erasing) + erase_p=0.5, # Probability of applying random erasing + num_workers=8, # Number of workers to load data + pin_memory=True, # Use pinned memory for better GPU transfer + is_fed=False, # Whether to use federated learning + num_clusters=20, # Number of clients in federated learning + cifar10_non_iid=False, # Non-IID option for CIFAR-10 dataset + cifar100_non_iid=False # Non-IID option for CIFAR-100 dataset +): + """Get the dataset loader""" + assert not (is_autoaugment and randaa is not None) # Autoaugment and randaa cannot be used together + + # Loader settings based on multiprocessing + kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory} + assert split in ['train', 'val', 'test'] # Ensure valid split + + # For CIFAR-10 dataset + if dataset == 'cifar10': + # Handle non-IID 'quantity skew' case for CIFAR-10 + if cifar10_non_iid == 'quantity_skew': + non_iid = 'quantity_skew' + # If in training split + if 'train' in split: + print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.") + traindir = data_dir # Set data directory + # Define data apply_transformationations for training + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + CIFAR10Policy(), # AutoAugment policy + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False), + ]) + train_sampler = None + print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...') + + # For federated learning, create loaders for each client + if is_fed: + train_loader = obtain_data_loaders_train( + traindir, + nclients=num_clusters * split_factor, # Number of clients in federated learning + batch_size=batch_size, + verbose=True, + apply_transformations_train=train_apply_transformation, + non_iid=non_iid, # Specify non-IID type + split_factor=split_factor + ) + else: + assert is_fed # Ensure that is_fed is True + return train_loader, train_sampler + else: + # If in validation or test split + valdir = data_dir # Set validation data directory + # Define data apply_transformationations for validation/testing + val_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + ]) + # Create the test loader + val_loader = obtain_data_loaders_test( + valdir, + nclients=num_clusters * split_factor, # Number of clients in federated learning + batch_size=batch_size, + verbose=True, + apply_transformations_eval=val_apply_transformation, + non_iid=non_iid, + split_factor=1 + ) + return val_loader + else: + # For standard IID CIFAR-10 case + if 'train' in split: + print(f"INFO:PyTorch: Using CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.") + traindir = data_dir # Set training data directory + # Define data apply_transformationations for training + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + CIFAR10Policy(), # AutoAugment policy + apply_transformations.ToTensor(), + 
apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
+                apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
+            ])
+            # Create the CIFAR-10 dataset object
+            train_dataset = CIFAR10(
+                traindir, train=True, apply_transformation=train_apply_transformation, target_apply_transformation=None, download=True, split_factor=split_factor
+            )
+            train_sampler = None  # No sampler by default
+
+            # Decentralized training setup
+            if is_decentralized:
+                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
+
+            print('INFO:PyTorch: creating CIFAR10 train dataloader...')
+            if is_fed:
+                # Federated learning setup
+                images_per_client = int(train_dataset.data.shape[0] / (num_clusters * split_factor))
+                print(f"Images per client: {images_per_client}")
+                data_split = [images_per_client for _ in range(num_clusters * split_factor - 1)]
+                data_split.append(len(train_dataset) - images_per_client * (num_clusters * split_factor - 1))
+                # Split dataset for each client
+                traindata_split = torch.utils.data.random_split(train_dataset, data_split, generator=torch.Generator().manual_seed(68))
+                # Create data loaders for each client
+                train_loader = [torch.utils.data.DataLoader(
+                    x, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
+                ) for x in traindata_split]
+            else:
+                # Standard data loader
+                train_loader = torch.utils.data.DataLoader(
+                    train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
+                )
+            return train_loader, train_sampler
+        else:
+            # For validation or test split
+            valdir = data_dir  # Set validation data directory
+            # Define data apply_transformationations for validation/testing
+            val_apply_transformation = apply_transformations.Compose([
+                apply_transformations.ToTensor(),
+                apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
+            ])
+            # Create CIFAR-10 dataset object for validation
+            val_dataset = CIFAR10(valdir, train=False, apply_transformation=val_apply_transformation, target_apply_transformation=None, download=True, split_factor=1)
+            print('INFO:PyTorch: creating CIFAR10 validation dataloader...')
+            # Create data loader for validation
+            val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
+            return val_loader
+    # Additional dataset logic for CIFAR-100, decentralized setups, or other datasets can be added similarly.
+ else: + raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.") diff --git a/EdgeFLite/data_collection/dataset_imagenet.py b/EdgeFLite/data_collection/dataset_imagenet.py new file mode 100644 index 0000000..f26a4f5 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_imagenet.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import warnings +from contextlib import contextmanager +import os +import shutil +import tempfile +import torch +from .folder import ImageFolder +from .utils import validate_integrity, extract_archive, verify_str_arg + +# Dictionary that maps the dataset split (train/val/devkit) to its corresponding archive filename and checksum (md5 hash) +ARCHIVE_info_map = { + 'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'), + 'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'), + 'devkit': ('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf') +} + +# File name where the information map (class info, wnid, etc.) is stored +info_map_FILE = "info_map.bin" + +class ImageNet(ImageFolder): + """`ImageNet `_ 2012 Classification Dataset. + + Args: + root (str): Root directory of the ImageNet Dataset. + split (str, optional): Dataset split, either ``train`` or ``val``. + apply_transformation (callable, optional): A function/apply_transformation to apply to the PIL image. + target_apply_transformation (callable, optional): A function/apply_transformation to apply to the target. + loader (callable, optional): Function to load an image from its path. + + Attributes: + classes (list): List of class name tuples. + class_to_idx (dict): Mapping of class names to indices. + wnids (list): List of WordNet IDs. + wnid_to_idx (dict): Mapping of WordNet IDs to class indices. + imgs (list): List of image path and class index tuples. + targets (list): Class index values for each image in the dataset. + """ + + def __init__(self, root, split='train', download=None, **kwargs): + # Check if download flag is used, raise warnings since dataset is no longer publicly accessible + if download is True: + raise RuntimeError("The dataset is no longer publicly accessible. 
Please download archives externally and place them in the root directory.") + elif download is False: + warnings.warn("The download flag is deprecated, as the dataset is no longer publicly accessible.", RuntimeWarning) + + # Expand the root directory path + root = self.root = os.path.expanduser(root) + + # Validate the dataset split (should be either 'train' or 'val') + self.split = verify_str_arg(split, "split", ("train", "val")) + + # Parse dataset archives (train/val/devkit) and prepare the dataset + self.extract_archives() + + # Load WordNet ID to class mappings from the info_map file + wnid_to_classes = load_information_map_file(self.root)[0] + + # Initialize the ImageFolder with the split folder (train/val directory) + super().__init__(self.divide_folder_contents, **kwargs) + + # Set class-related attributes + self.root = root + self.wnids = self.classes + self.wnid_to_idx = self.class_to_idx + + # Update classes to human-readable names and adjust the class_to_idx mapping + self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] + self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss} + + def extract_archives(self): + # Check if the info_map file exists and is valid, otherwise parse the devkit archive + if not validate_integrity(os.path.join(self.root, info_map_FILE)): + extract_devkit_archive(self.root) + + # If the dataset folder (train/val) does not exist, extract the respective archive + if not os.path.isdir(self.divide_folder_contents): + if self.split == 'train': + process_train_archive(self.root) + elif self.split == 'val': + process_validation_archive(self.root) + + @property + def divide_folder_contents(self): + # Return the path of the folder containing the images (train/val) + return os.path.join(self.root, self.split) + + def extra_repr(self): + # Additional representation for the dataset object (showing the split) + return f"Split: {self.split}" + +def load_information_map_file(root, file=None): + # Load the info_map file from the root directory + file = os.path.join(root, file or info_map_FILE) + if validate_integrity(file): + return torch.load(file) + else: + raise RuntimeError(f"The info_map file {file} is either missing or corrupted. Please ensure it exists in the root directory.") + +def _validate_archive_file(root, file, md5): + # Verify if the archive file is present and its checksum matches + if not validate_integrity(os.path.join(root, file), md5): + raise RuntimeError(f"The archive {file} is either missing or corrupted. Please download it and place it in {root}.") + +def extract_devkit_archive(root, file=None): + """Extract and process the ImageNet 2012 devkit archive to generate info_map information. + + Args: + root (str): Root directory with the devkit archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_devkit_t12.tar'. 
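+
+    Note:
+        The parsed (wnid_to_classes, val_wnids) mappings are saved with torch.save
+        to the info_map file in `root` for later reuse by the dataset class.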
+ """ + import scipy.io as sio + + # Parse info_map.mat from the devkit, containing class and WordNet ID information + def read_info_map_mat_file(devkit_root): + info_map_path = os.path.join(devkit_root, "data", "info_map.mat") + info_map = sio.loadmat(info_map_path, squeeze_me=True)['synsets'] + info_map = [info_map[idx] for idx, num_children in enumerate(info_map[4]) if num_children == 0] + idcs, wnids, classes = zip(*info_map)[:3] + classes = [tuple(clss.split(', ')) for clss in classes] + return {idx: wnid for idx, wnid in zip(idcs, wnids)}, {wnid: clss for wnid, clss in zip(wnids, classes)} + + # Parse the validation ground truth file for image class labels + def process_val_groundtruth_txt(devkit_root): + file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") + with open(file) as f: + return [int(line.strip()) for line in f] + + # Context manager to handle temporary directories for archive extraction + @contextmanager + def get_tmp_dir(): + tmp_dir = tempfile.mkdtemp() + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) + + # Extract and process the devkit archive + file, md5 = ARCHIVE_info_map["devkit"] + _validate_archive_file(root, file, md5) + + with get_tmp_dir() as tmp_dir: + extract_archive(os.path.join(root, file), tmp_dir) + devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12") + idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root) + val_idcs = process_val_groundtruth_txt(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + + # Save the mappings to the info_map file + torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE)) + +def process_train_archive(root, file=None, folder="train"): + """Extract and organize the ImageNet 2012 train dataset. + + Args: + root (str): Root directory containing the train dataset archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_train.tar'. + folder (str, optional): Destination folder. Defaults to 'train'. + """ + file, md5 = ARCHIVE_info_map["train"] + _validate_archive_file(root, file, md5) + + train_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), train_root) + + # Extract each class-specific archive in the train dataset + for archive in os.listdir(train_root): + extract_archive(os.path.join(train_root, archive), os.path.splitext(archive)[0], remove_finished=True) + +def process_validation_archive(root, file=None, wnids=None, folder="val"): + """Extract and organize the ImageNet 2012 validation dataset. + + Args: + root (str): Root directory containing the validation dataset archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'. + wnids (list, optional): WordNet IDs for validation images. Defaults to None (loaded from info_map file). + folder (str, optional): Destination folder. Defaults to 'val'. 
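+
+    Note:
+        A subdirectory is created for each WordNet ID and the validation images
+        are then placed into them, matching the ImageFolder layout used for training.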
+ """ + file, md5 = ARCHIVE_info_map["val"] + if wnids is None: + wnids = load_information_map_file(root)[1] + + _validate_archive_file(root, file, md5) + + val_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), val_root) + + # Create directories for each WordNet ID (class) and move validation images into their respective folders + for wnid in set(wnids): + os.mkdir(os.path.join(val_root, wnid)) + + for wnid, img in zip(wnids, sorted(os diff --git a/EdgeFLite/data_collection/directory_utils.py b/EdgeFLite/data_collection/directory_utils.py new file mode 100644 index 0000000..e9d45cf --- /dev/null +++ b/EdgeFLite/data_collection/directory_utils.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary modules +from .vision import VisionDataset # Import the base VisionDataset class +from PIL import Image # Import PIL for image loading and processing +import os # For interacting with the file system +import torch # PyTorch for tensor operations + +# Function to check if a file has an allowed extension +def validate_file_extension(filename, extensions): + """ + Check if a file has an allowed extension. + + Args: + filename (str): Path to the file. + extensions (tuple of str): Extensions to consider (in lowercase). + + Returns: + bool: True if the filename ends with one of the given extensions. + """ + return filename.lower().endswith(extensions) + +# Function to check if a file is an image +def is_image_file(filename): + """ + Check if a file is an image based on its extension. + + Args: + filename (str): Path to the file. + + Returns: + bool: True if the filename is a known image format. + """ + return validate_file_extension(filename, IMG_EXTENSIONS) + +# Function to create a dataset of file paths and their corresponding class indices +def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None): + """ + Creates a list of file paths and their corresponding class indices. + + Args: + directory (str): Root directory. + class_to_idx (dict): Mapping of class names to class indices. + extensions (tuple, optional): Allowed file extensions. + is_valid_file (callable, optional): Function to validate files. + + Returns: + list: A list of (file_path, class_index) tuples. 
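+
+        For example (with hypothetical paths): [('root/cat/001.png', 0), ('root/dog/002.png', 1)].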
+ """ + instances = [] + directory = os.path.expanduser(directory) # Expand user directory path if needed + + # Ensure only one of extensions or is_valid_file is specified + if (extensions is None and is_valid_file is None) or (extensions is not None and is_valid_file is not None): + raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.") + + # Define the validation function if extensions are provided + if extensions is not None: + def is_valid_file(x): + return validate_file_extension(x, extensions) + + # Iterate through the directory, searching for valid image files + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] # Get the class index + target_dir = os.path.join(directory, target_class) # Define the target class folder + if not os.path.isdir(target_dir): # Skip if it's not a directory + continue + # Walk through the directory and subdirectories + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) # Full file path + if is_valid_file(path): # Check if it's a valid file + instances.append((path, class_index)) # Append file path and class index to the list + + return instances # Return the dataset + +# DatasetFolder class: Generic data loader for samples arranged in subdirectories by class +class DatasetFolder(VisionDataset): + """ + A generic data loader where samples are arranged in subdirectories by class. + + Args: + root (str): Root directory path. + loader (callable): Function to load a sample from its file path. + extensions (tuple[str]): Allowed file extensions. + apply_transformation (callable, optional): apply_transformation applied to each sample. + target_apply_transformation (callable, optional): apply_transformation applied to each target. + is_valid_file (callable, optional): Function to validate files. + split_factor (int, optional): Number of times to apply the apply_transformation. + + Attributes: + classes (list): Sorted list of class names. + class_to_idx (dict): Mapping of class names to class indices. + samples (list): List of (sample_path, class_index) tuples. + targets (list): List of class indices corresponding to each sample. + """ + + def __init__(self, root, loader, extensions=None, apply_transformation=None, + target_apply_transformation=None, is_valid_file=None, split_factor=1): + super().__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation) + self.classes, self.class_to_idx = self._discover_classes(self.root) # Discover classes in the root directory + self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file) # Create dataset from files + + # Raise an error if no valid files are found + if len(self.samples) == 0: + raise RuntimeError(f"Found 0 files in subfolders of: {self.root}. " + f"Supported extensions are: {','.join(extensions)}") + + self.loader = loader # Function to load a sample + self.extensions = extensions # Allowed file extensions + self.targets = [s[1] for s in self.samples] # List of target class indices + self.split_factor = split_factor # Number of apply_transformationations to apply + + # Function to find class subdirectories in the root directory + def _discover_classes(self, dir): + """ + Discover class subdirectories in the root directory. + + Args: + dir (str): Root directory. 
+
+        Returns:
+            tuple: (classes, class_to_idx) where classes are subdirectories of 'dir',
+                   and class_to_idx is a mapping of class names to indices.
+        """
+        classes = sorted([d.name for d in os.scandir(dir) if d.is_dir()])  # List of subdirectory names (classes)
+        class_to_idx = {classes[i]: i for i in range(len(classes))}  # Map class names to indices
+        return classes, class_to_idx
+
+    # Function to get a sample and its target by index
+    def __getitem__(self, index):
+        """
+        Retrieve a sample and its target by index.
+
+        Args:
+            index (int): Index of the sample.
+
+        Returns:
+            tuple: (sample, target), where the sample is the transformed image and
+                   the target is the class index.
+        """
+        path, target = self.samples[index]  # Get the file path and target class index
+        sample = self.loader(path)  # Load the sample (image)
+
+        # Apply the transformation to the sample 'split_factor' times; a transformation is required
+        # because the transformed views are concatenated into a single tensor below
+        if self.apply_transformation is None:
+            raise NotImplementedError("DatasetFolder requires 'apply_transformation' to be set.")
+        imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)]
+
+        # Apply the target transformation if specified
+        if self.target_apply_transformation:
+            target = self.target_apply_transformation(target)
+
+        return torch.cat(imgs, dim=0), target  # Return concatenated transformed images and the target
+
+    # Return the number of samples in the dataset
+    def __len__(self):
+        return len(self.samples)
+
+# List of supported image file extensions
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+# Function to load an image using PIL
+def load_image_pil(path):
+    """
+    Load an image from the given path using PIL.
+
+    Args:
+        path (str): Path to the image.
+
+    Returns:
+        Image: RGB image.
+    """
+    with open(path, 'rb') as f:
+        img = Image.open(f)  # Open the image file
+        return img.convert('RGB')  # Convert the image to RGB format
+
+# Function to load an image using the accimage library with fallback to PIL
+def load_accimage(path):
+    """
+    Load an image using the accimage library, falling back to PIL on failure.
+
+    Args:
+        path (str): Path to the image.
+
+    Returns:
+        Image: Image loaded with accimage or PIL.
+    """
+    import accimage  # accimage is a faster image loading library
+    try:
+        return accimage.Image(path)  # Try loading with accimage
+    except IOError:
+        return load_image_pil(path)  # Fall back to PIL on error
+
+# Function to load an image using the default backend (accimage or PIL)
+def basic_loader(path):
+    """
+    Load an image using the default image backend (accimage or PIL).
+
+    Args:
+        path (str): Path to the image.
+
+    Returns:
+        Image: Loaded image.
+    """
+    from torchvision import get_image_backend  # Get the default image backend
+    return load_accimage(path) if get_image_backend() == 'accimage' else load_image_pil(path)  # Load using the appropriate backend
+
+# ImageFolder class: A dataset loader for images arranged in subdirectories by class
+class ImageFolder(DatasetFolder):
+    """
+    A dataset loader for images arranged in subdirectories by class.
+
+    Args:
+        root (str): Root directory path.
+        apply_transformation (callable, optional): Transformation applied to each image.
+        target_apply_transformation (callable, optional): Transformation applied to each target.
+        loader (callable, optional): Function to load an image from its path.
+        is_valid_file (callable, optional): Function to validate files.
+
+    Attributes:
+        classes (list): Sorted list of class names.
+        class_to_idx (dict): Mapping of class names to class indices.
+        imgs (list): List of (image_path, class_index) tuples.
+    """
+
+    def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1):
+        super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
+                         apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation,
+                         is_valid_file=is_valid_file, split_factor=split_factor)
+        self.imgs = self.samples
diff --git a/EdgeFLite/data_collection/helper_utils.py b/EdgeFLite/data_collection/helper_utils.py
new file mode 100644
index 0000000..169a823
--- /dev/null
+++ b/EdgeFLite/data_collection/helper_utils.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+import os
+import hashlib
+import gzip
+import tarfile
+import zipfile
+import urllib.request
+import urllib.error
+import requests  # Used by fetch_file_google_drive below
+from torch.utils.model_zoo import tqdm
+
+def generate_update_progress_barr():
+    """Generates a progress bar callback for tracking download progress."""
+    pbar = tqdm(total=None)
+
+    def update_progress_bar(count, block_size, total_size):
+        """Updates the progress bar based on the downloaded data size."""
+        if pbar.total is None and total_size:
+            pbar.total = total_size
+        progress_bytes = count * block_size
+        pbar.update(progress_bytes - pbar.n)
+
+    return update_progress_bar
+
+def compute_md5_checksum(fpath, chunk_size=1024 * 1024):
+    """Calculates the MD5 checksum for a given file."""
+    md5 = hashlib.md5()
+    with open(fpath, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+def verify_md5_checksum(fpath, md5):
+    """Checks if the MD5 of a file matches the given checksum."""
+    return md5 == compute_md5_checksum(fpath)
+
+def validate_integrity(fpath, md5=None):
+    """Checks the integrity of a file by verifying its existence and MD5 checksum."""
+    if not os.path.isfile(fpath):
+        return False
+    return md5 is None or verify_md5_checksum(fpath, md5)
+
+def download_url(url, root, filename=None, md5=None):
+    """Download a file from a URL and save it in the specified directory."""
+    root = os.path.expanduser(root)
+    filename = filename or os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    if validate_integrity(fpath, md5):
+        print('Using downloaded and verified file: ' + fpath)
+        return
+
+    try:
+        print('Downloading ' + url + ' to ' + fpath)
+        urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr())
+    except (urllib.error.URLError, IOError) as e:
+        if url.startswith('https'):
+            url = url.replace('https:', 'http:')
+            print('Failed download.
Retrying with http.') + urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr()) + else: + raise e + + if not validate_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + +def list_dir(root, prefix=False): + """List all directories at the specified root.""" + root = os.path.expanduser(root) + directories = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))] + + return [os.path.join(root, d) for d in directories] if prefix else directories + +def list_files(root, suffix, prefix=False): + """List all files with a specific suffix in the specified root.""" + root = os.path.expanduser(root) + files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f)) and f.endswith(suffix)] + + return [os.path.join(root, f) for f in files] if prefix else files + +def fetch_file_google_drive(file_id, root, filename=None, md5=None): + """Download a file from Google Drive and save it in the specified directory.""" + url = "https://docs.google.com/uc?export=download" + root = os.path.expanduser(root) + filename = filename or file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if os.path.isfile(fpath) and validate_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + return + + session = requests.Session() + response = session.get(url, params={'id': file_id}, stream=True) + token = _get_confirm_token(response) + + if token: + params = {'id': file_id, 'confirm': token} + response = session.get(url, params=params, stream=True) + + _store_response_content(response, fpath) + +def _get_confirm_token(response): + """Extract the download token from Google Drive cookies.""" + return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None) + +def _store_response_content(response, destination, chunk_size=32768): + """Save the response content to a file in chunks.""" + with open(destination, "wb") as f: + pbar = tqdm(total=None) + progress = 0 + for chunk in response.iter_content(chunk_size): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + progress += len(chunk) + pbar.update(progress - pbar.n) + pbar.close() + +def extract_archive(from_path, to_path=None, remove_finished=False): + """Extract an archive file (tar, zip, gz) to the specified path.""" + if to_path is None: + to_path = os.path.dirname(from_path) + + if from_path.endswith((".tar", ".tar.gz", ".tgz", ".tar.xz")): + mode = 'r' + ('.gz' if from_path.endswith(('.tar.gz', '.tgz')) else + '.xz' if from_path.endswith('.tar.xz') else '') + with tarfile.open(from_path, mode) as tar: + tar.extractall(path=to_path) + elif from_path.endswith(".gz"): + to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif from_path.endswith(".zip"): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError("Extraction of {} not supported".format(from_path)) + + if remove_finished: + os.remove(from_path) + +def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False): + """Download and extract an archive file from a URL.""" + download_root = os.path.expanduser(download_root) + extract_root = extract_root or download_root + filename = filename or os.path.basename(url) + + download_url(url, download_root, filename, md5) + archive = os.path.join(download_root, 
filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + +def iterable_to_str(iterable): + """Convert an iterable to a string representation.""" + return "'" + "', '".join(map(str, iterable)) + "'" + +def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): + """Verify that a string argument is valid and raise an error if not.""" + if not isinstance(value, str): + msg = f"Expected type str" + (f" for argument {arg}" if arg else "") + f", but got type {type(value)}." + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + msg = custom_msg or f"Unknown value '{value}' for argument {arg}. Valid values are {{{iterable_to_str(valid_values)}}}." + raise ValueError(msg) + + return value diff --git a/EdgeFLite/data_collection/pill_data_base.py b/EdgeFLite/data_collection/pill_data_base.py new file mode 100644 index 0000000..46308c1 --- /dev/null +++ b/EdgeFLite/data_collection/pill_data_base.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +from PIL import Image +from torch.utils.data import DataLoader, Dataset +import torch +import os + +# Importing the HOME configuration +from config import HOME + +class PillDataBase(Dataset): + def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1): + """ + Initialize the dataset. + + Args: + data_dir (str): Directory where the dataset is stored. + train (bool): Flag to indicate if it's a training or testing dataset. + apply_transformation (callable): Optional apply_transformationation applied to images (e.g., resizing, normalization). + split_factor (int): Number of times each image is split into parts for augmentation purposes. + """ + self.train = train + self.apply_transformation = apply_transformation + self.split_factor = split_factor + self.data_dir = data_dir + '/pill_base' + self.dataset = self._load_data() + + def __len__(self): + """Return the number of samples in the dataset.""" + return len(self.dataset) + + def _load_data(self): + """ + Load the dataset by reading the corresponding text file (train.txt or test.txt). + + The dataset text file contains the image file paths and corresponding labels. + + Returns: + dataset (list): List of image file paths and their respective labels. + """ + dataset = [] + txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt') + + with open(txt_path, 'r') as file: + lines = file.readlines() + for line in lines: + # Each line contains an image path and a label separated by space + filename, label = line.strip().split(' ') + # Adjust the image path to the correct directory structure + filename = filename.replace('/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X', self.data_dir) + # Append the image file path and label as an integer + dataset.append([filename, int(label)]) + + return dataset + + def __getitem__(self, index): + """ + Retrieve a specific sample from the dataset at the given index. + + Args: + index (int): Index of the image and label to retrieve. + + Returns: + tuple: A tensor of concatenated apply_transformationed images and the corresponding label. 
+ """ + images = [] + image_path = self.dataset[index][0] + label = torch.tensor(int(self.dataset[index][1])) + + # Open the image file + image = Image.open(image_path) + + # Apply apply_transformationations to the image if provided and split into parts as specified by split_factor + if self.apply_transformation: + for _ in range(self.split_factor): + images.append(self.apply_transformation(image)) + + # Concatenate all apply_transformationed image splits into a single tensor + return torch.cat(images, dim=0), label + +if __name__ == "__main__": + # Example of how to instantiate and use the dataset + dataset = PillDataBase() diff --git a/EdgeFLite/data_collection/pill_data_large.py b/EdgeFLite/data_collection/pill_data_large.py new file mode 100644 index 0000000..f5fd1db --- /dev/null +++ b/EdgeFLite/data_collection/pill_data_large.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import glob +from PIL import Image +import torch +from torch.utils.data import Dataset + +# Define the folder paths for training and testing datasets +FOLDER_PATHS = [ + '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/train_images', + '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/test_images' +] + +# Custom dataset class inheriting from PyTorch's Dataset class +class PillDataLarge(Dataset): + def __init__(self, train=True, apply_transformation=None, split_factor=1): + """ + Initializes the dataset object. + + Args: + - train (bool): If True, load the training dataset, otherwise load the test dataset. + - apply_transformation (callable, optional): Optional apply_transformationations to be applied on an image sample. + - split_factor (int): Number of times to apply the apply_transformationations to the image. + """ + self.train = train # Flag to determine if the dataset is for training or testing + self.apply_transformation = apply_transformation # apply_transformationation to apply to the images + self.split_factor = split_factor # Number of times to apply the apply_transformationation + self.dataset = self._load_data() # Load the dataset + + def __len__(self): + """ + Returns the total number of samples in the dataset. + """ + return len(self.dataset) + + def _load_data(self): + """ + Loads the data from the dataset folders. + + Returns: + - dataset (list): A list containing image file paths and their corresponding class IDs. + """ + folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1] # Use train or test folder path + class_names = sorted(os.listdir(folder_path)) # Get class names from folder + class_map = {name: idx for idx, name in enumerate(class_names)} # Map class names to IDs + + dataset = [] + for class_name, class_id in class_map.items(): + folder_class = os.path.join(folder_path, class_name) # Path to class folder + files_jpg = glob.glob(os.path.join(folder_class, '**', '*.jpg'), recursive=True) # Get all jpg files + for file_path in files_jpg: + dataset.append([file_path, class_id]) # Append file path and class ID to the dataset + + return dataset + + def __getitem__(self, index): + """ + Returns a sample and its corresponding label from the dataset. + + Args: + - index (int): Index of the sample. + + Returns: + - tuple: A tuple of the image tensor and the label tensor. 
+ """ + Xs = [] # List to store apply_transformationed images + image_path = self.dataset[index][0] # Get image path from dataset + label = torch.tensor(int(self.dataset[index][1])) # Get class label as tensor + + X = Image.open(image_path) # Open the image using PIL + + if self.apply_transformation: + for _ in range(self.split_factor): + Xs.append(self.apply_transformation(X)) # Apply apply_transformationation multiple times + + return torch.cat(Xs, dim=0), label # Concatenate all apply_transformationed images and return with the label + +if __name__ == "__main__": + dataset = PillDataLarge() # Create an instance of the dataset + print(len(dataset)) # Print the size of the dataset + print(dataset[0]) # Print the first sample of the dataset diff --git a/EdgeFLite/data_collection/skin_dataset.py b/EdgeFLite/data_collection/skin_dataset.py new file mode 100644 index 0000000..3e0107c --- /dev/null +++ b/EdgeFLite/data_collection/skin_dataset.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries for image processing and handling datasets. +from PIL import Image # Used for opening and manipulating images. +from cv2 import split # A function from OpenCV, though it's not used here. It may have been intended for something else. +from torch.utils.data import DataLoader, Dataset # These are PyTorch utilities for managing datasets and data loading. +import torch # PyTorch library for tensor operations and deep learning. + +# Define a custom dataset class named 'SkinData' which inherits from PyTorch's Dataset class. +class SkinData(Dataset): + # Initialize the dataset with a DataFrame (df), an optional apply_transformationation (apply_transformation), and a split factor (split_factor). + def __init__(self, df, apply_transformation=None, split_factor=1): + self.df = df # Store the DataFrame containing image paths and target labels. + self.apply_transformation = apply_transformation # Optional image apply_transformationations to apply (e.g., resizing, normalizing). + self.split_factor = split_factor # A factor determining how many times to split or augment the image. + self.test_same_view = False # A flag indicating whether to return multiple augmentations of the same image. + + # Return the number of samples in the dataset, which corresponds to the number of rows in the DataFrame. + def __len__(self): + return len(self.df) + + # Retrieve the image and corresponding label at a specific index. + def __getitem__(self, index): + Xs = [] # Create an empty list to store apply_transformationed versions of the image. + + # Open the image located at the 'path' specified by the index in the DataFrame, then resize it to 64x64. + X = Image.open(self.df['path'][index]).resize((64, 64)) + + # Retrieve the target label (as a tensor) from the 'target' column of the DataFrame and convert it to a PyTorch tensor. + y = torch.tensor(int(self.df['target'][index])) + + # If 'test_same_view' is set to True, apply the same apply_transformationation multiple times and store the augmented images. + if self.test_same_view: + if self.apply_transformation: + aug = self.apply_transformation(X) # Apply the apply_transformationation once to the image. + # Store the same augmented image multiple times in the list 'Xs' (repeated 'split_factor' times). + Xs = [aug for _ in range(self.split_factor)] + else: + # If 'test_same_view' is False, apply the apply_transformationation independently to create different augmentations. 
+ if self.apply_transformation: + # Store different augmentations of the image in the list 'Xs', each apply_transformationed independently. + Xs = [self.apply_transformation(X) for _ in range(self.split_factor)] + + # Concatenate the list of images into a single tensor along the first dimension (batch) and return it along with the label. + return torch.cat(Xs, dim=0), y diff --git a/EdgeFLite/data_collection/vision_utils.py b/EdgeFLite/data_collection/vision_utils.py new file mode 100644 index 0000000..84b32ee --- /dev/null +++ b/EdgeFLite/data_collection/vision_utils.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import torch +import torch.utils.data as data + +# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class. +# It handles the initialization and representation of a vision-related dataset, +# including optional apply_transformationation of input data and targets. +class VisionDataset(data.Dataset): + _repr_indent = 4 # Defines the indentation level for dataset representation + + def __init__(self, root, apply_transformations=None, apply_transformation=None, target_apply_transformation=None): + # Initializes the dataset by setting root directory and optional apply_transformationations + # If root is a string, expand any user directory shortcuts like "~" + self.root = os.path.expanduser(root) if isinstance(root, str) else root + + # Check if either 'apply_transformations' or 'apply_transformation/target_apply_transformation' is provided (but not both) + has_apply_transformations = apply_transformations is not None + has_separate_apply_transformation = apply_transformation is not None or target_apply_transformation is not None + + if has_apply_transformations and has_separate_apply_transformation: + raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.") + + # Set apply_transformationations + self.apply_transformation = apply_transformation + self.target_apply_transformation = target_apply_transformation + + # If separate apply_transformations are provided, wrap them in a StandardTransform + if has_separate_apply_transformation: + apply_transformations = StandardTransform(apply_transformation, target_apply_transformation) + self.apply_transformations = apply_transformations + + # Placeholder for the method to retrieve an item by index + def __getitem__(self, index): + raise NotImplementedError + + # Placeholder for the method to return dataset length + def __len__(self): + raise NotImplementedError + + # Representation of the dataset including number of datapoints, root directory, and apply_transformations + def __repr__(self): + head = f"Dataset {self.__class__.__name__}" + body = [f"Number of datapoints: {self.__len__()}"] + if self.root is not None: + body.append(f"Root location: {self.root}") + body += self.extra_repr().splitlines() # Include any additional representation details + if hasattr(self, "apply_transformations") and self.apply_transformations is not None: + body.append(repr(self.apply_transformations)) # Include apply_transformationation details if applicable + lines = [head] + [" " * self._repr_indent + line for line in body] + return '\n'.join(lines) + + # Utility to format the representation of the apply_transformation and target_apply_transformation attributes + def _format_apply_transformation_repr(self, apply_transformation, head): + lines = apply_transformation.__repr__().splitlines() + return [f"{head}{lines[0]}"] + [f"{' ' * 
len(head)}{line}" for line in lines[1:]] + + # Hook for adding extra dataset-specific information in the representation + def extra_repr(self): + return "" + + +# StandardTransform class handles the application of the apply_transformation and target_apply_transformation +# during dataset iteration or data loading. +class StandardTransform: + def __init__(self, apply_transformation=None, target_apply_transformation=None): + # Initialize with optional input and target apply_transformationations + self.apply_transformation = apply_transformation + self.target_apply_transformation = target_apply_transformation + + # Calls the appropriate apply_transformations on the input and target when invoked + def __call__(self, input, target): + if self.apply_transformation is not None: + input = self.apply_transformation(input) + if self.target_apply_transformation is not None: + target = self.target_apply_transformation(target) + return input, target + + # Utility to format the apply_transformationation representation + def _format_apply_transformation_repr(self, apply_transformation, head): + lines = apply_transformation.__repr__().splitlines() + return [f"{head}{lines[0]}"] + [f"{' ' * len(head)}{line}" for line in lines[1:]] + + # Representation of the StandardTransform including both input and target apply_transformationations + def __repr__(self): + body = [self.__class__.__name__] + if self.apply_transformation is not None: + body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ") + if self.target_apply_transformation is not None: + body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ") + + return '\n'.join(body) diff --git a/EdgeFLite/debug_tool.py b/EdgeFLite/debug_tool.py new file mode 100644 index 0000000..e925f1d --- /dev/null +++ b/EdgeFLite/debug_tool.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries +import torch # PyTorch for tensor computations and neural networks +from torch import nn # Neural network module +# "decentralized" is not a valid import in PyTorch, possibly a typo. Removed for now. + +# Check for available device (CPU or GPU) +# If a GPU is available (CUDA), the code will use it; otherwise, it falls back to CPU. 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# Define normalization layer and the number of initial input channels for the convolutional layers +batch_norm_layer = nn.BatchNorm2d # 2D Batch Normalization to stabilize training +initial_channels = 32 # Number of channels for the first convolutional layer + +# Define the convolutional neural network (CNN) architecture using nn.Sequential +network = nn.Sequential( + # 1st convolutional layer: takes 3 input channels (RGB image), outputs 'initial_channels' feature maps + # Uses kernel size 3, stride 2 for downsampling, and padding 1 to maintain spatial dimensions + nn.Conv2d(in_channels=3, out_channels=initial_channels, kernel_size=3, stride=2, padding=1, bias=False), + batch_norm_layer(initial_channels), # Apply Batch Normalization to the output + nn.ReLU(inplace=True), # ReLU activation function to introduce non-linearity + + # 2nd convolutional layer: takes 'initial_channels' input, outputs the same number of feature maps + # No downsampling here (stride 1) + nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels, kernel_size=3, stride=1, padding=1, bias=False), + batch_norm_layer(initial_channels), # Batch normalization for better convergence + nn.ReLU(inplace=True), # ReLU activation + + # 3rd convolutional layer: doubles the number of output channels (for deeper features) + # Again, no downsampling (stride 1) + nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels * 2, kernel_size=3, stride=1, padding=1, bias=False), + batch_norm_layer(initial_channels * 2), # Batch normalization for the increased feature maps + nn.ReLU(inplace=True), # ReLU activation + + # Max pooling layer to further downsample the feature maps (reduces spatial dimensions) + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Pooling with kernel size 3 and stride 2 +) + +# Create a dummy input tensor simulating a batch of 128 images with 3 channels (RGB), each of size 64x64 +sample_input = torch.randn(128, 3, 64, 64) + +# Print the defined network architecture and the shape of the output after a forward pass +print(network) +# Perform a forward pass with the sample input and print the resulting output shape +print(network(sample_input).shape) diff --git a/EdgeFLite/fedml_service/.DS_Store b/EdgeFLite/fedml_service/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ade4daf47d7e3baf51f5770ba83948b88f6427dd GIT binary patch literal 6148 zcmeH~(QeZ)6o!wt!cxYhUNDIZE-Z0d#M;q|YBy0CFF+7n04m86Z4tqql9WnRmGT7R zg?J90hyU1)#OSiSA*g($ouOQ8MEeF$z!Uh_2=Ll{NHLXIwHNgJ`YkC-mepla=mar8lFK5m>Z7BdVrRE^ z@BUt}7wiWgqYJ%^>bPE1v$%ZEPj9p?lC~ZvZ}V(9A3S=gRUPLlD@{V4W(fK4E?22u z&h$d1mC224N6-)Y^TCtVYCIXg7|Lgp>0~{WtI_dvG?b(9^mx7Q2Zv9epS+%bF1{#z zWuXe}LBKd>?6J2AIjg8e%g-4E&s zcmn@30p1@ZoHcaXST$P*3UdX3_Rwq$b-qh*j&vG2ZLFdPCTuFyrV4k(5H=m{(#3^N z8>=>*gu8qQ_h#X4D8jrQ<4Y4xBDCsDPrwtnPGHwQ4*2|k)qVfJp5&iA0Z(A72#DTk zbUMS7+}XM@IX-J8{1cpw;~J|b1%*3~O@)u*$8clllU#v@P8+Lef%!iI27@m=fxk-N E7o;EZZ2$lO literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/.DS_Store b/EdgeFLite/fedml_service/architecture/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..81a948a774c891cf2238973f78693b2bbd85a406 GIT binary patch literal 6148 zcmeHK!AiqG5Z!I7O({YS3LZQxc&*s9_E5ZpSbxBX9#m@56b;6#G;MPzg?xejkl*3= zIJ3I}i#>P}u`{syW_M;bI}ftE!x-b#oN_SXUTUST6P5A*Eh>;AK8yZyl{)!No}-Ktwn>&}0axjzl2v&;)7H#oYKG6_pP z2(RLCKD2kwWI7GvbUaZ3aWn>(+v_-uWbVmX8f7Zh*9NO$HHP-VV$tb!PFmuq*Y7P` 
zV&QcAPD?nQes{TSSbK-Zrx(Mgtv!W$?mpC}t8A?yzw)x_MNq=`%) zU|a^5L4?ErF+dCu0~^JFxdyD}MrBRO5(C7*PZ+@SK|n)v4Hg>J(E*(*0RU?-tps$e zB``)9bPX07!2-f{D4-7Iro`Yn9Oi}ba}5?6bvWau_~7QtW%0t**I|93+!=Q@QcnyJ z1FH;R?F(Y#{eKR>%-SMfxll(85Ci{}0bU#U0}rw@ck8!2@varn?xCTeUx5k;=qrx^ j7~nawuZ-F*PzOKPV4)FL!Fd%9NDl!;2zA83FEH>4KUPhp literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..852bec4136b4294112584692d306d2735adf955f GIT binary patch literal 6148 zcmeHK%Sr@646NuL7QC#UJoYbm(DSgYU*HG$8UCYdj_lkEIEWGV@BkjojQo%L!$W_9R^(E6>oT$S?DjiboOIBqesg#@1B2T z6}Nc83;k_7+;?|<->HNXFK_UP5$CvI$&U)+P|CQQKwe0$_}C_V6RaIll7VC(8At|_ zfgfZ*ZxdyCR55ijkPIXPpA6`JC>+8jaH{B62bH$~#0A|hwAD)}rvx^EQ$@~Dtfmrc zszr-oHJ$zvaZTV-rwBGIIz+)p5Xup|MACPnxG1Sun>KVtdxQ zDb6tfTO77?UgocUHue$42SVKb_rth0I?SiiHy)Jsl=pOtr(Vc##`0(!XYu~uv$K>I$Lch7TtM& zi*i^`)F}m|z_|jxd(&n<{~zhU?EhXPlLAuUpDJLJ)yrzZm#f)2dQG3Tjs8kM8f&AR mM~q@(jAH)qQG9n&*Zi6JUN|HMo$;U(^+&*Ukx7BSP~ZzH86A%R literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..efe3d1daa9b18aca9a328229f52c1fdee415103e GIT binary patch literal 6148 zcmeHKyKciU3?y3*2D*6cs9(s?%tDZSL4TlVoy}-~q}@H`yYgu{K86}@wq!Am0Cf_@ zlb{<+Q$%$0+pk1A5ozIua<*Y^wr@VOr;I2Njx+Z1BX6?qPhZ0<`~QG(TRDPxlJCF# z&%9NW_1kfN@pRD|$dL+AfvEz| zvFxn=@8CE3|0#(pDnJDmrGPfu_w5Ex%Gx@6oYmR_|A1T01#X79Q!scr26{Qh!piZh cCq-VdIreMf80d7woet#BfayY`0*_YU3OipE&j0`b literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics new file mode 100644 index 0000000000000000000000000000000000000000..face2edbe01171a41f3728aa36ac8645b88c3981 GIT binary patch literal 15273 zcmaKzd7Mr4|Hp?hckE_i=H5YFsgR|iRHEjfqVFYxSrl?rBQ%%=^(}JRETx+8RhGWp z7F#JpmP+Mh$<7S2569TaF8MOU@AZDa&*z+Dn#b?^$M^BN@Avz8-pgk>cW$Lue79<5 z$;>DI=N~*b+1Em^*!5&M{(p3KSmRz%VPRoeIm3nzdwy7EzmXX^xr2rc&8~oeUNzv! 
z95j4*W=6k^p)U;TKQz1I_+C|E9RBaDjDER;Ud_nO4j-Lesd29e7;^^>%Nd>>5vbTJ z9G?aa9iHt8RO(fUo-u6bi`kKZN+U)fXJ6zTon4`EuPXGQjNIY@T#boJ-2 zX54q*eZ8tHNB`%aPZ^fgGCRDJ_V^A1*08d^*4dRhX=%&I6~?KC4;qq@9nndve0y)U zQDGTb!=4|AVk(RnogLYDATnK9Fa2`J``NkM2p zl>k*`=)$+JzJ)mb0nZ%HAE8B zs&XF80ID-^zR%J}rcIkpK(|$z4J6t?Sllo&BTFQp8UlU4^QKi=awmzTEUkhF3>%0n z;VqKL45%`on$pi}EoWP$CD--C%aiL>S7Qmt_jRlZ4WO1}shgXrfcXYTQI=+YmIQoo zOov96s(_*xI$QCv&eo{ACzDsQ^f6fi^3KRkB&#P-4Nx>gr>=QVkPSoVQA6WM$~3}c zNtt#}3_$~`4k(79lU?Fc$VSlpFTD^Xt0zzcP%J|yo=BN*_0;2clOrk9 zM-<4D^X4qkNqPb`0o9gF-J6A5mmrO^)my(T!a1p|S<{2z$ z+@Fjwp#gXqIC6b)s06Uci758YlQ#N(eAZWV@YK_@U%`=PoNH8!j->X?;?ztkwkpwLvB?D;k zstyUDo&bq6V^mp&36Es?-e5^(&A3DwKz%9e-J=EgOw&qfqCp9;MkIi^XLCrSEH9wj z1bXzD_f;=V4R{6mO=k%x_TB|~Wc37mfNmEkK5>kK8U;C$dRZT038?O%_9ip{zw|PG z%#`wa@hc#HMw&I^1`HgUaN;5OxR)0wNXq&#$PyV6M#UwQ)f1?PEO#(;v__Y&)gGH2 z;z%m%Q;j7-gC?Jc2GCIYxOCmLfRu1QM-i}GV+p{wcvJ`)z@1XoXJ4PP`bbG7kCauQ zv!urBKVm`yxQl_q#j}^G4S77oQIx^ZSW-z-ORUE4X5ipAncb{yl5*Ms)>Axh6JL0g zNzxN&fGmxqkKv~uQhlWB@Jd;$LoAV{&i=QpER7j3SNvvYnITbBR-s=32~DdcLnA{I z>7(k06Y-hXqoU4H1gsz>H_M)9Ly!ROVc<~f;}4UIHsdq{N=kAzW5S%7S+D|VDnP|^ zIz-M* z*{y&J8b?t{f9Nc!hiL;6p|PKu1G*D;Jrc5O<_qwUly%1(^=eDu%T>zS6=W6}V#7bC zX(2-!DXDem1U24Ib&jG8&H~2ve9i(MX5iqI=)tO_BYqp$YOhD0klXub)%c;14Q$X@(xjgKDIXd@l9cu4L+@F8D)j*hlF{xoSOV}~oIn7|dKi!< zSy~R=r?Q-lgfOoDd z!yXMrc`ZNwp&t@JN9ki_=PN41%@90NQZXqxfg2X*K>~P;fdgB=IH7vj8RRG`>43?S zN~*Uh5gNeb3>+L>7_UxWslibM>=YT%DXC;m5@`TWNDnn8?@$>Q`Qee0ibzpON!J6f z>5u?A3Gi*b?^F-RC`f=cCQAZFSJt5cbY@`hyj@?a#T%z_6!#z!z|Eek4VbZb9RWSb zP*MLr2Ne|UhgY)LYmrov^`ikRfPXP?X!yny)yFuB7hr?Q5?M+YEFcLWz`%jKr#bGc_$92rXLP+6gBOXRN`MvRsh`@I9#Xn3pLh`x(%!dv82^r$?p#UJtRZj z3iDLcT@4%9A5_4#2lI7k06nFLiK`OSe)^j{Ru6ubl;OlS9U4Fi17^ykS?anuKpp{H zJ>>9B0w7gN+BhuP>LK;+pbgj;bJF$p4`?t0NMqp8+OA{qna_1f5K1!G=Q`=y+~pdq z0MZ%Qm%HIx)$|J{JOUJjSW*v%kL%C?dNHtP$ouLgnff_-1UM*?7VhF-Z1n_s0(x4? z>OC|H|VaK|G>{tRF>0A$>|mr(3gSz>rc%@7?+i**}x8w)O>4? 
z0W&f@D;XYGrmF?aB9D}`JH!$|i7!ABKtBO4Kh#HU^HCB-CGFKIiZ8Wme3JkOGH~dw zp##Z9-wwu`97TY0@~*}1(_jYBpMkyI7J5}lv{$T>oE5*e_nZbRfaj#7Es+^&N_io8 zq=y40OB(C8nL0Fp=Ox3ZZT90cPidS9Wi1{lxo+CuBnjXJ$q=#DKHH!A$;H==DU!yz z^PCPVfD8uA_RmeWRyw8Fghv49T$HT4-+&ds04b?%Qh(LMk0v|<6o*(+hCdb?Wc394 z1A0-Q_X0%<`h>jyC~Nj}8ms^YGEg*Tu>G>{?1x8yEg_aPsa>r}0vIG27Oq!652TJZ zIf@p~*|8-%@6-es%)maQ!Ut-5yg(kwa8hT945ihZY0${c z;GHK)0>~5~z3pr25RD9S6amg__{Put9Us6D1`ho-teLuJe{q0qA}PbDU+6IahDwHA zFJ4y}#_Qzb)!t*W1aR~9pbiaS7y}3Q?Tl1Fl@Lc!NzO67S$&ZqKo$f0N9ONRz#z>A zoHJ3fuBI+PwgA1y*|%U1!v^dPsD;)4fUGDfM}Q6QwNr;^6nV<_aAQjovI59uV9&_I zd(=97V33QOKBBXvrkgF)paBeLpg7^aOnl~ZJ<@;@!0Dm%mpU4(07giL^_MHEt+9YS z0u+;yQ#W#!0SREF^l*EJQL2Z&L3ji(4VEbBMrL1<0A6Nb@3oHVXQouwb$KXg1Cy8Q z0=&Y&fw~KSQ%$!qIEogIFEB3YhNmf6k>TG`Qrz9@jhfouPp&dkDxL7V4l96H8Q8NW zTW!A7uR`z$;AAKnw?=~%z-t1${kZyRHuW1F9s!P%k|}K+szU-8#lYSxFQ{KvQr`%{ zBftTTC8kvJcajbbV6*^#&TphP-&Yh=wxk=2CmOH5h(^r`S*7M-VosP zrE?U}$Inr;aL(qt5qwgI8Nizi?0)~taSAXT;Dm`F$?(TCjie_q8qiw;t^SYw`rSl< zQkL^#DVf#BAgd?vI-s{1+Vf}FT(xU=Y48fDprQ~aQk8Q3hw#Z&u7P1o(h~z1p5FYP3J;HsIR5)k-x1 zK9rK`HvgYB+O!AV#?Q=mpQzkr`sJ8g-ebs6{ZM7esvOt}ezsV=7 zCom1rXOg9HyzWd@+n z1xlYfO`XC&h_SMmENN;>qWom_1ZD#ILZImx^~gr0ZKS}mrN#A|Kvn=>3XsuYh03y* zJOZ2#vZOJ;-ZF};p1^+reI?MdH3ACSZg3>^b510cmcO4M0J8*WJ!b$ub6JOjP?E(3 z{P3erR!<-VXtqEf1-q*(2Sf0dg}nVH7r+W&jsTO#ma5fvjoIlXF88!4z+3@lZ~e)- z8Pm1~;juFKS<>CRwIfMnm?yxPNmD~?1+YMX_^5pf zaIMEVKZ0ahOR8oGkS{>zCh8wZr4<_-#rw+763xebvYRwAEEM3qcIq#iv|Tz!-3s{b zdVH2jdv z$goU+n*TknR@)?x08Z1s-}90M_(6cMt4}K6vIDqQdeP$qU=x0iEI$e~{2uiSciIs+ ztt?I-@tYngfEB<`0<_zduCiP($^pDx*Z2ikEzKB z>aXgipi>&VX+6&BEa^Zso^Lh(ivXK`DNtKvpJ@ZmipN}`uOZaOuL8V1_JYdbI(<%A z(aXjp3a~3dld;5Px{TB&ah zX=f>3GB{1gFF8(D^srKZM{n#_C7meiFz zU0Yy{b+u$@85yCbbT(uIPSf#Q{Z@uG0(g2)S0%Zcc6x{zd!PW%K@YzRkW&A$0$e?~ zHsH-#Q39+L;E`HO$^kB#TpewVU%oN~SSLWMtE&}oL*uA2O%I%s4>QiiDwMTepmTGk zs`Fj`HXr?do@Ch|K=%c&(?jS6Ich>#+atu1%KC8|X=KwMTRK%r#l zUO36hkRGntfOAyiPV<)_G`&pt3R(AcfH1(o$G6Mz{;{+vix1?LABNQ1k1@1 zb#}cqsT~50oxEKEuCw9nT<^)MR^xw2NsYJOz;k(vxbCRak_?lI-}$SF0Q9g^GW6At zsI_oy5hp|Z=#LOjz%BtAb*-rY*B#~Dg5Kjp^90x}K&MXsQB%4aqGy%4)OYsDc};tm%XEW&qQgkqNKe7{BoqDn$jr~o`0N+;*(ZM`vkaHbT>uOK{_0S`Uh~V zVHT{&u%Cf#+HU(@;eL0An57LkWvvAiNruT6I;bg~(>RJYpfgtQ#2l;X0|IDk)eqKb zu2-0|m*Pj&w|Y1zz}A=FQ%#@O?F`N>=(|1I4>Ni&1$h4}M2E5LPgI8%yyB4`S5RDj;`6;w&(zJlZL%t{pC7y|{DmtIpDTrU@=hlJ+y zbODYFaJI@-wbG>_cx0^3UohU*{P!c8J|V!D3GJ=zonFJV0Vjj+#mf2g7*F6Zppyb^ zSkpj3RrGS8_-5+g`#NO-H8SZT*y8DxH5+hkkMY^oTAUVO^1jX1q|z&fY{2Qmmsw)M zjFSEoVCO{pPafB~cJuA2=0F$j&HLzce;I@(nIT$3KF^P_Ktd@r6hUtOBA3)GVCkaNl|4=a;6mj?f8%YrB)9`cGKY|Jd(lL<~81&5rP#x+z{a4J#Exj zt5ba0lHz_?X6@LU0+_{9)mWp;1AJ}2)L|bnIw!0`SQs;faY^gdVpk%kByrlP{nR8A zRxlM5Q?UIH^%8UMD2IuEqqEE^r=nsCgQdVeBdOZ@i?t3UuC6@t=e=fM6mOQ*fs9Ging73)tCc-nMJ~!c?}Ile^WQr0F$v zIIW#PiobgM=3WdzB2N{?6h;+~S9zkzS+*%(-^&K9V5-_o_1EGQ;|ao9Hcvv%xe%;i zswt-6(hl|Y;g%UtocvjE-x&iEOm&;NUj3lj**1x_c!@26;S9V|F4pVd7Zh91It7S9MKd2wA+%Fer)A{zCKQBy_VhaBJ zR{b5~W}JTFcWkioL@TB+=II1HhBvqCt>@&a(_>{w^2FFo2Yrm%sIK*K4xabIYe1-< zSj7}vsJa&>xu*>B^7nG5HedJk1+aptt(d|JL(~sOx6F-bS<|6YvV^H)GnG30sMgd~ zo0G@8^|M4_;%p}B_wVS5bcdGz;%LL~1k@IqqXWMDrhs#+4gnbabB1N#S}(2w7>IJD3@6|!w{yf z&D2_Ne+h71jOYA_le7I+OVa44p3PKD=&JUCYi?(C#4!_l?JffnOnsYakp7CAK`nAx zGw`#dzir0+%bLM$Hq-RMU1|oduX0WvU*g^@@W^w!&9vRVOx5Pv*G@lm`gxOu@!QM; z8;+>6U2dKPk9s*Yu$k1&HSrkUJ1)k_<83%JQS#hjGfhvoQ!|J$;Ve4?Z|r)qqMwE~ z)BcpHYKt()%jd_HXM2$eE0{ZNrcu3T6yut;>nwcGG$hPjHq&NMv0}=7m-sgJHH5j_ zW~yDDt@?3&rEt}@`g1E!Bb)Jm=2eVq@3_vw)HsW2Y%|?YUs1J%Ybc`Z`SDFWVy$Tt zn`!^jOvRLYPQA$qemn_tyGJnv*M1vM52yRU&4gg&eI0)!Q~(L4sm)wBi^)q&x#!f| zp?01y&1|M*=U!D?x%;5%r;$Nnn%m5oZaHdhF2-5YgkICEJS`Mc_{g4hsy6poFkqB^ 
z{`%4niGJ?2ncm)))%mHU!&!FL-pD0n1(Rqqy#`gKK$?O3n_3_C^;Gx?10$>$U0i_(yoNyMxuTCEyUD?($GYU7I< zU%U29j1VHRYJ9C)L6!v7-#zDb?|tX@`~CBM)Hj|v!${ z+4^aVna-vsCB^^7MLT(gRdI51ii#N@JAU~1F+(PW$4rPEKQ6k2_zPhliJy##jEx-= zJ|ukHmyutLi*}wHR$ly{u~VbMhfIi^7Cs@mWL$J9uP_(!e@=)P9}^qx6673KQv4}$ zTx@jdpi*I_=y!}CH!`|RP^pO%h0T8nn{m-4yzq-+zK9$*WX$*p6QZ4ab@|YI@9<}} z+}Lg`|Ky9|!)@cEe4T$s>?9$sa5rVX@asREQOZx}`1lZ0d11Q5lf&cNa zx2NY4(0gs2gVi72Bf!!2?dk6`TM^K8`{5b@$}*65p|9uY%%26AS1>3nnSUII)})7TFJ z?zErMxK{40r-f4mxbTorbq=~ZQNL-atZ@p9jiaQ`x^nK zX1#|3sKh|_@3s;Rm-Jy^Ig(En+I}OBegFZM(oU}cRA%7C%1}=_da`gURiK`1q1OhH zYc~R{?vV!pRAJzG%SItnnw%#f#9#P5do=@tE^dj48%}`PG<+C8q*g;H)(&J_S&nfxoUic}_<}fX_XyZpzX((^XontDkrP_!vNR$6v`f z0U`A@FbjoCF_enG2%rW7kFWmMi_VmQI8i?Gq~~7{sg9!WgSA0I2>^8gblq8->XLxA zr$l3(XJEbfC5R!Y)yyRT>H(-VDqYkW1OAlz;SBr=VoI9I_w!PwOc>@4_x^ ziS4}O=K@Zvd)@8;JQ%qD#^(n8h=N`u)S!PN1HU^sVN@E`zby-(0W6gDej|RyK)eyd z_O7~t%3t&Y&=3}~qx=-GVj{@_TiAL)iLvz~*G~X=0yxmHj}(FaOIoqx6!e~3N?WZS z6FUHC1Yk(X6iPnDaIl^T!p0VMYcV#h@yrI$7(kz+exxk~um(&3pkoMrG$|N>7XbGM z>@5LjQE}wb?hjn)bPj_aZ2{1Pf&V@$-%iel{twRR^4=r#L*$NtF&i5JXbRx@k=brx z>Fm-)Ln{A>=V3(WPkOtx$21Hhe1%Cj(KMbP1g<=SSWRrzcBb7@bk$yG+9RZZ- zvzo++fI(^8W(1u0QB5Gr@_cUqod9@k>p+u$fB~Duwwul;eZr`vftSt#=*+;~qTI`J zu-gV)7Ry{822L&0Q*>0Dp#Zu7xcG3ds@CCNaWq@VrIXuv?GS0TWVnq35CGs@{Clc= za%qGh#TG%p=`q~iR2qi{-T}}Rz>VtjY0)EKYzYmd)D~Sej!PrCuo8f702Z9@sM0zk zRu|tnF)6TBT3z4A0q73kz?+pStz(TCQYdCFKI@irD}X=%yIWLKA~4Q4=~KcZ(MuwR zPOsDR0Q3MbYHedB0*j~|u;`w-p_X;uwi{*w=m{Wb8VMu^omf(fKx$)E8b0So1On&< zpxsMv@|Locn8Bw-X-yN_josVlK;6qG|;)FYR382 zD76fLnSncZRs_n;*tTXr{a%itn1y|h6Ib0dn1(!kq{PM)1`XSkwkC%ZC%#ww~=DPs&12F1%069*j zvHz^zm(s0DHGBrvhy~CeKvE%hF9DnOv3ijOqn3QbS2hG-12FGgpz7(}hPGa~!@5i^ zg?^I)AQV9Vem&$GYCG6hyL6#7?^*=3`&q7Vb^v}A2dHX2ViIQ_$LEkfhj(59__lkP z3=0kbpH=-){F#G3hBl)ZE*lmWxz+>_2B5;IRkVgt(EG~h_2<${k?W?^(>~GHrM3=W z;C6162)P;CPF5f#W0x-A;!D7l@zY`9a{%x9Emkp*Jc_M~(<+Wo-1Z5u`*{Hv2p}og zNzLgSh?HZntJT40&TrKL3<7YoQLKJ+Lbk#r)Ben9S3#($ABbUy3 ztwRh$0W3(lCs#OIM!43c%N8OS__q3c0AB#uR4HFRBG|G-2;%jFv!sWj&-+BF8p8m1 zSNlT^_Io0g;wnKvdbBtkUg3P=sY7Xn8xA0_Sq~*-j|Q^b5s-<(r5IK$OabsEfZCBRDMHoy zd>fZ20dfr|mqy>q05B53u{HhVvTnOOh%1hOt5hUYD6PI#)BXYw0U)nObLvL2kPWF| zmv*Wt>eKuEPymqtGMD>M6A+L`n*&}>E_W04&96T`GpkWvMOsGzSUlB7$$jA*eR;il zMk$|fqd%ob(xhF#vA(EmPI{+?a8h8@a3qXmh0s zfUy9+I;d{Dw!i1=p3V?eC0@=bOD}qr0x%9hhx{p01VYOhwUkL(+=T$YNm)_^#sjEy z^_VR0Q2I@BC}!#M_N$h44HlvR#ElRGa{#CHtxt-zi*R*Q*-aRH`m;fNdOyy24g}PFO;+Y}~ zdO%sq|78Iz#KNW8s{-ZQjO}mi&vbBE&;yF$tEs^NCIax<8$-XAER-wM6DUKUv~PYl z2~2`Z14G(GDW{?9xOvo7ue3A?x4nVTMr>dn-{}%sM3o$fTi>_Z_izxtR09X)TRf<4pO?O><*K1y)dpue2 zFX0W~YXIRDgrdk8#Ns4xC|89ni{+GniCvxpmRDYm*`{pMjg#0zRSBA`2}a^I#(jH;fn>hM$B>3jjPRQ&aqz0e{pq0h!KfWchWN zAuTKfaBs^MxjKaU30aoanuUFl&S6=qEWUUEH;+4%F14h6yA?N3v3@u)KM5`^0+7~! z`a`@4g!&okeZ7E}0xD_yr&2O542hJ^#uC*vILgG!r~{rMPYJU+Z#dO z(pmDoa6rbe48ZPFt(1j!I4wT7nZgYx5`jKHXB5JvT!#WMU)(Z{dj~`Lp6WH!JS63F0ILDiJvl`V6np1xx|W`6sSKVTAwSDn!@!l~2Djxc%Wf8a z%VqtU4!U=XUh-3#wE+Ci9umhzS+Wl@#_lsa&v+5E>1C1(dL09o3vT+#Mb{pVWtD2_ z#Vi%HX;>Ypd_My?yt9LR^0JTnPM3E!MmC*8QMf-9rs|36x;pbbHIsS`m0KWnlIAjNvE5$GyQ-rgm-}x*vZp{Gj8w2V4 zTZG9)*S0()k1JzV+(JppyYS?4n~>hBb^VyuvUNF$5eH*IT|0Yyyz{;dd4E z0%It?uC_u%O(J0JFnkt(e*>5lJWQSR+GPFo_)>58rgk5%0k9dsvXE_b+o7PBLX0>| zSPJMIN)2oZU<-ioq6gy73~Uh9jSuC+Ug^Q!d*A6y09ygn&-qpXKTXwQYy@pedvqJX z?*Lk!eITXMz6M=GF1<#=giGd27vwo?1K{DaSAMa;z8-#XTD?^9Sy~s4fraf1T)MS! 
zpxCQ8hHV)C1mw69OldW*-82lq4p?x1c}(tn_6Duyxu&@`v z!sGpww(f(ra9TfC^TAqfmpg!c00!(%mP?|2pQu9lQSHkMq6_f}c)2T7JK}o0A3)p5 zQOc#0hLGhB5d)I%qUMNX`5DRq0HgC}iC@Miy(fpS>Eu#wJt6%zI%&&SWkQJumMM2LZH9C@bf*{g^mtIpYkMf`2P>78VXMaPjJ?8>**I zg!B5xE@8$C3qAIGqyab#3k^>UP;!4hK>P&PvRw3%0D>|d@(y$a7P6b>C=p0AiZAz` zT2A_IskRdqjskF`sOKVkmRPFgE8J_Wf5JkajVqGJ1fqj z>k+wht}zcIEi*f z4o}$w;0yz41tT&PaNDj$AWw_2<#DX6))ZK1e*c|v>9)~vd9>$qwVHg-$kibgK!aoJ zMb0_sTkmy?=KRe9Le={FnHacq7C_YOMkE!Kr3*onOsK|g8hqRjWfva+a1KC;sDG5U zrjOKV&BYo{76KD<0GwwaEn-DERjp5q`$K^+h4)rVP@?=+%LM?3YQHCKA(w8S(js7_ z)iUbpEm%kc5H`U^V~2qIkz&uq2S)|aG;-c2kQF=W0jXA4 zxC9{Kf`@9Ln=~Gqa+Zn#3vWdLNCz+~{jpNMESmoK4pzRtfSVuxF01u20~dZsr3|u5 z*&Bq~yRn6W@obD@c)Xz}EL;Jw)%mG9>DMpx^|)ACw|*q|r3|>#2J{_67e#xg-lYbkL~Qe9U+OSfFhDN!rcP!WBXLq53kTQd=Ae=R(TDzc;rrkg+HZB&FjgB5c_*$BQsPZc#TsSfZGh5 z-)=i0v*d81!@@U0-i+lcA(t}d+X38xg~8rVYH!V<+@rkTe5PD79a=Bd`z|cp^?#}k z?t|!1`H9~fLp6>@7Qn(iSa9j!uMX~=L-Zc5W%>GgY*~5cE`a|6h+2M>9H*1cgX5H? zygH&?`Ewvso60Q!+=ol6`i6=*!oUYl@n0^eRw#}L+UguGEj$2lYNdM2v46P8HO=P$ z0in{U*-6g0hYXwxI~yZ6GRMbvc)=jxjlK?;cG}|M(jx$Ek_OA`vBTY{B?GV|`qcpN zn1Qq3yJX0Yb2RQpMa?l3>P^7x^g#yw7cBTTV5tyUa@5CG#W4s)!-w*w)W8H-cmfOe zzkMo&%+WB49|8y{C@p3HKgC-;9?k*q6u^L&Khp7%1rK8rD8!OXSz5o<4B#1nZrKrX zBXhLGvPu?;0kx9q0eH?p>ZM^-RM1@l#80rz6ym}~)mqh0UN~L=I8x%2eAjSvPS73q zMlb88Q===wLNS!$fBpLoQ-yGFWH@e}mT12J0}NgLcVu^!`y;$b*F|%X#yslkSMJBA{5% zB{Sp=B_F{2$vfm@zoV0(OK+h3;^0g>r*wve0sxnjW~j9KgzFYH`oZ+;;pzb1$QZ&O z$fp)ZOLmU4WSq45!6EsmRtO7Om+I0`pcwpW3)vV4U>j>1RcpW2tze-DK)anYRe5(W zR;_vsHGOnLZe(u(EUKnfI7g?;x<%%IrI}ZWvFh%I#iaM#C&ssf^{yBiC zr4Io5ce$q0+O-W8#Q$~aajMj%e*g?9J6ct1&le(1_%57b!IE1F-+GL35?{RqGr3A# z6-IkQ{N9$}6o|NkgNg>TS1~0tQ@@Qv2r);}b&lTrjKLLjygPQ97-!Acr?r!_%i)tj z!pFfFdei;Dcq|<;B{gH}xQ3Qb%2+@%aec=(zTQ;eje^dk?M{rZm{OVv={8#RQ88hZQi!MKLMYJ#*#y?&xhOl~F8aM?oxdP+Bvu?QU|vcJwILBzm-( z8$Y`Yri@}z{R7oX;plu;zhN3VFt01Q3QSqeRI7hTb#H(%YYN1@M~IQgfyK!Sbs1Am zGYul%ss`@4(P&^@%jWt6M!-&a%`|P4CWZ^=x|>Ch*}_#ErLbeRIj;j#K{GBs;>$Js zMR-ToM12!8sNSr&yK(Jk zl`$`n-_DDvqL~*B*2x9P5mKaAtRlStSvqgVP8L&DGZ|^u4l8~+$D{WYYt8%xuv1+zsntq) z$d%C%P;4F?$2aNl@*4cMvz*m|B{tmcCTpRUI}lNaQW7NHeAlm!yu?*36SJUaEGw zzh;5rI?%OaF8WoTX&ue%N+)UM8tq}!wb7@RZYSiK*450nf8A4ym36;SqdEtsi*@HA z*Yz~B|ERlK_X3TT)zEmWqn7MBH-q`JB#l?~>$SdOj=h8=g!1esaMhzS|Q zw+7C^QStZq<1*8`)$hU7*GzW)T6s%x^cG1MpXYI1YmK*f?e7Q1Lo*&#W~yA<(9mpW z2A2`lfmauqw+5O?wtY}rxW$-1`54~BeA#I`>@?KOi0PrKcEVis;$ifHngwd zrvg=RJ{ z*r?`qzpi?pEX3+V=M(k)OfW4q(`EawYTpUDBUGCIp=DSxed`#S4aQqDRlIk}r+9~5 YG@pDEC|b)n*-36Sl$db*6Vev{1D>Urwg3PC literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..efe3d1daa9b18aca9a328229f52c1fdee415103e GIT binary patch literal 6148 zcmeHKyKciU3?y3*2D*6cs9(s?%tDZSL4TlVoy}-~q}@H`yYgu{K86}@wq!Am0Cf_@ zlb{<+Q$%$0+pk1A5ozIua<*Y^wr@VOr;I2Njx+Z1BX6?qPhZ0<`~QG(TRDPxlJCF# z&%9NW_1kfN@pRD|$dL+AfvEz| zvFxn=@8CE3|0#(pDnJDmrGPfu_w5Ex%Gx@6oYmR_|A1T01#X79Q!scr26{Qh!piZh cCq-VdIreMf80d7woet#BfayY`0*_YU3OipE&j0`b literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics new file mode 100644 index 0000000000000000000000000000000000000000..93595aea0ef5d33dc2f21382a0968a4169345114 GIT binary patch literal 15269 zcmZvjcX$=m*2Y81ApwLWgoLUCDxe4k0YNX2p+im*2sMbvPy(Sy5}JUZgGv*XA|NQ@ zfFfOjG(p5sX;KrB-UbjA1PxapAnkkCTC?Yz&*$Dh?(?3#-(TLf_TE#@QO(e|OANWY 
zp!IES(6SB{>X@Oam%{M><8wo5W<-aCgyam%9+BNQdq}TQnZt$;%pRIs42Fz{Ff(M} zh!I0Fdu0yoH?a56+|Vf*B``Q*OipI6;RDBJ4$lo6pBr8?BLaiNv$BVc$c^xaW`yC> zz@Z~@eg5!_aQYkBLkHv*_lJ)hiOinJ9G_dPW=0hKq0HeUdJV}QK0G(HjUE%V_1ygV zA7+$N6TSQP?VO!cH#e+}{%~syhJ^fku0d{i8~ynLnWLx?17FU}jcB7sXHg=OYs z_sv2v#YT?LEnYJVNu{>W8&I!ay)69U+ULm^l7+w1;=auic=2b38Z_$jhXRUZsKl5m zEnKL692psEn_;p9~P)f$*uFirmtgtv4|vDRW3M5V<_TA?fo+ z0xHc=*ssT)a~GlcXdAN(rN>$h#cN~~Xs^W*P?_*y zEPehGfTATSR_pC1wK!?>H2O*3X0iknbMJUQG?GdJijkzOwUynZmQ%$EsAB_1L1mNP zvY-K#0u(FIV{L2NsR`{0V@nNYd5|EB;5)Zo`AVdY&YJhb863 z<8|Fr-R`E&E|k`}uECN@n={UaMoI-hj|lYovn(H3@>kK>MZjiKoKPWtk6YHmfF2X*;T0Ve)Y@bx zL0fc|NGX4_lm?9zdj!zq4E=NM%Z={7_*aP_D;-mp-fT}Zezj38`YSkjd zNl?5AqbxwFO*9OkhLkox{tcB9xDRw*Y{JAqd_hzP0ctXEx54DcRDwL|=*Hf8imrfi zMO~u>ctR2y^*re=v40tvWap2HL?7jLudwj}B|Ry?v7LR~l9KP3WaEUuYQ@iPU||$M zEh*{Qi2CHhQSnFVPy)O|N}3+^eIg`)rx^HWj;O4X6&F`T|`Z_>@W-XR?!^jv7mX1{7G(02(mxPnYZCRmudDT?9Ppuq2@G6b%}{ zvkcsAclo9|>XpeP&3~Y?B;bkWZu1QpxRbjygIu(U5jMMMd2d)O0VE8h*H21lBt0zn zSv|E<_Rx&1R*K1z67FTqfCliK^ia9SdAFn#z6b<(#_@osE+-1`yd?DdqM{3=x{qCI zx~HLlxG84|z|uBA%6|mfcSF4#q#QBXNzk*TI3 z=rfii5}HWDu?gy2rK~e)iZ0nM%|%k*E1x7PUhiw-0bS}<_C?6|F}`5HE6Bw>ib zlKPlksSp}KO9pO@YYuySG0iN!x9tEjYY<3aw5-HjBE%`GbkZ6j*E&|4wEUByxRq~(# zv=!j{%iZ0jP5oGB7Xj~UED7-SB?+LNl(h3+F|`-J(~LBpX0Qa1@aS{YIDljZin{;h zbN3=G!)6yH>=8*#r|)zyjFoBy$S=^HA460h-x4D!Usx=utXH14p#h{wS+8sxrmn&v zCc7wQoW&AAx%ht^XaK1KwBD@tBDKHHE&^V+SrX7`xZ6h>1Gi=+)mAGtL0km9X?nm7 zD<2vO?ImISFW;z@`Xc}g)jeSRl^g*&NJ4s_V{X%Fp(eX1VX@5;B_))dU_m1xU4Wv5 zefZ3WFh_@yRT}TG1Q0i5wVTjUfR83`Cl{UCr8>JPVWr8En(kGChLO-ofZmhT7qhg= zHoFM;&R|IiyOxs#&{=?W@2c-tsmpD45is3g2_WA2oHT$g0vvgKi`uy9Cc6l*9F~+| ze@7BPR{_rMwDFmHm}o*t)6+=FHIaQU4-!B(28s^7q8`a<18jCt!q*y0D(S?k8PGU@ z9RX!9bi2Vm^<^xT-e296cN~_0%2oM_AOPK^tTodT$%Um&dqig!rOXvcDQ%||Vi>>+ z0#qCSoqIge>YE<0!UM|n+R;IP9#Ybi#k_k|QUDn30dadan}0DL}@gW#pnBR_g4c9`+h6scEY;4P&J;(8o&xO)09WPVG!GO3Fc- zC7`&9>P4)V0KbKeRzQl*E=pMS;AG z=|Ue`ERho5XoC(7pdSOb(;7Zd56fvbyQr*>$RsECNRkE#AX9*eg>i0KsXa~1NLi-A zlKQBb zP}HoD{s145r6#+mhYK1@D(Q#wIy4dn3Q*`2s-^8pGm>ydV@W`2_k3sogQTSF)9Trs zw%BABC2X`=640(^AvA!&0=zIQMDN7udfZ~tN)S&^qA}QS#O~Ge69_^c$P~3ReP=h4_ku`H@ z)aM@tXtY4L#(tnmd(~zqrF>W0C2-07ajAmxfK#20ldndP*92(YB1QpS$RuaV(pgf% z7i}D90Iy3)ap&&1K>K1QW(4@lVo6QUtuX@{fXTqkn_blx_O!buyQqgdA_0_}ztO=k zfC-ZD!u7WJ%r}ESAl!t@A}OI)YZD&;ydepHZSCqVYx_DhBPHDw2_UZbL>r?3-jsyF ze)X-meKkx7&^5|oNeSm8^PmB|~?ix-;o|( ztNDZ~sZRh{qbnfx$PRb&-j$Npj~R_Q-U}@NTqM6VJ-ed=i5}h)VDJaOs~$?4n31Oc za9C2)4c{$<1~6TK6*pAVX?sj|Q4c$8mIRFJW4$^7E4;Ch_gC0 zfSC;3(tP3g%o}&rgp!hWJ1hw(uH`@jm?a5Y20N;UP@7#;(s7XhVn2Ps#4v#O8Mt}4 z(YNZlZYnMSUCZk*@hZhVvdn@6Fk2GJeL~C3Jyf!xWT#6REUD?$H}atY%n@MIfb1J0@`DHIoilmfjoeX>c z@Sy<9%6m7oTQ=GF?dGb+l7N&$1~h<=82Im|b(h?Pbaw-&k6jK-@?aTVeg;XOe=eW} z0?k`<+1>H9x$r1H*dECIz0RL}4a0a*fVC{XVs z>2<_OP=TYMvMGObfCjKgfDbCsbAn6jYO{-g(;7>pl&jt|9~z}C0JK=3RTtycBD6Ny zNze&WvfLftG$8?eD!`tT>ZLwCMZ=6Nz#*F@ZD`lO+_F9s;M9h(YHvCk?4pDtjt8uG zCJ`D5O9aTvouo<%TH<4x0%G$9>jHc(z|Soz<1;tj(t?tb4vVBc`MY=_Mv<^ofX&N~ zsDvTp7vQYLl6q+OB9(-dS`26zLw|?nsGl&?8=CB-K7KY>0*X60rUNv9j~YuVYjul!Xe6u@pk{P6)khbTU6c?w!R5y8 z(gj#0K#iGsYA@>2j5Hot+PE(2E8J=U%53{Z0qM2}9MV`))Ax$&&?xB(0is5}s*ZXq zn)&x$l+CNFVHCg`0UB0QkF@j*3o`;7&{!fN#yCZL0U%ER{j7H}1)bO0fI*R$R6-h= zq=&#o6tm}`iBSOS1sMD3Dz$Mz%Np2+n8Ak}0X7IQOpnKBKH^CbZUUVKNR$-4xv{%S z8wJSSP)U^(bohcwYTVQSi;^}8u>9sUb@+nrhCkXcaYD?SZQW)4N`TL6>oi4o!bTdD zG`-zsNvo88+JZ*HW&sw?x71k*DrqPA<$fFU7BvK5ivXLaR#QM5ja{@4J1v&f^orCO z&;Y&`;6jf)eCCpZ6FeYh|DqfLwhFMf;$5{5iQ=M!pna(Q5d4(zjQ|BTud9+;IW$F& zv#x+$_!6cBVnE)bTQ?w($Vzk`yt3y^m5GkoSAQVl3~S#6dC zRE;vA0ql^H&e~X?yb-DHdoZG{G0Y=<^L0u0)I~}yF<1d(!a6o{**QzTZXqEO-8QeqcjFp7~928)5 
zY4uOr(rY-FksfS=C3=Xd7Lo&v9`Xg4dOJe(P@PNy1Rae#>K~;P2=HpbDpk_+mInmg z$kx_2F^q(t1W0~-uv*r30brK{lkCH$1#U@)1eo;cA_b&b?4rZB!(d6r`|~Dl(?1Jv zsPkW{hiZ-oeC+|zO^Ug|VF5;cld6E{$uCWRxZFgPb`4n z1z1w+3O@5bbTnv+UmXe@mgphu+;syQz$F3ZHmaz0+OXM0J?uAG5|H+y1r6Y`0IR?9 zF3~)9%ytHv%$IAF5`5D|OFfd+6zfG^*7)Y0fnGm;Qk zrLYXYg;4-k1^94jceQanX~u0@XGsa46}O=Q{3*a({dENlG}%QZ1$qcu)lw7SF9AMY z;5`(BzCRo$R5q?_eUeBh6kvYVpEO0+L(nQ6HOWP<(qVm?>yQAh2{3WVVzo+rX-0Yo zykLcOd(^@xfd2~crO{XQ(9OV%0D)&*ShMal4B&48rbP@=%bKNQMt~C*ODbu5v|%FIfPVy7d?7?7WRTxoT9LH0MK>ILK*AjX=FZ%! z5`tFhm`wy7kMI>Ub&@{+O+a@Anz((xf(DRLQcjsHDdoT?4m5y!0?a$9esN46WwDF< zuvr3#C|^VxDYpRK7sx79zs;r(5+^}{FS6lT3w1yM9tg1Jz1pg*pzHf56DH|p^;r#a z@a+0@E=-wHa5fx{bqI?)!Un8*CO@&tzXk_x&)0{4;dqt7_PgnG;e$9#$j zYUZ?tL8FT%2za)aF1DC>}hqW53m)&|o4w=I3Lj)WHv0gTNbMMAFa1gYkLHnkwpd@bp0%CM74ZZV~no8U<6_V^%II zQnh8;G|P{Kg9c0LC*OSz4vX}d-BXguNNW(p1fB}vSNiBMAScRWwv^tkn85~=kk{Gc$0w!>j2QKrrPz=KP{B)ItohG5PPs^A+o-V-!rh z$81^FkBqcQSq_vifeSDEc0UWFU=lp$x0{Q}NKDYPFmSlTZ1{v>b5JC0Y(X;zqhKm}%+<%eua-fz1ugjfbX%B8 z9`k$4_thr7On*>vzSmh&Z3F9)MowjqxicmYLvpyuCrse|IHLY^1EXN7c+8Cz+trh? zyMsw#cIhlBXU#eX8qC8Ub8=lHwWmF7Hqt@eX0jw^W@=tf`lydg|Ak2y8#sLFZC z!KAFg4<<|0=Ie5vByt}0n5)G;RBI6QcK4&LnDA@k4PhShm^&jQF~s{5)KB2Ng`XU( z3G=we{95#x%IR-lQfdoaa}nVmSr|oa)jj5yqB>-xYa!@;D(D8hJ=PJXhGK4%e(Anq z3>)g-IpMdaTNp)7O~qXQZPF06r$G<9pf5k))-r{8!eh?()xTPqK9nXUCvcrc)Xk<* z%KZ0>5T3%SRyC<_T!|%)b^MQ_q}U7xHelc5%+hH1XD*bH^Smes{2m5#zxA?Cnfzn zR>gz_^R!~FpPUz|);in9q`N<)q@QcQYmi_PJ?6jtlT|Uzxe-@Sj{ a*xR64`F^Uiq_ytZ-Gm;QDgRY$o&N)sXho&~ literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics new file mode 100644 index 0000000000000000000000000000000000000000..bf59a3117b3c3fc93034041f7b293ee06c98e632 GIT binary patch literal 15058 zcmZXbd0fru`~MT!jTqUNCbE^iQbeV_ltbs7&RDWkWGih5QL-YV4f_T|3MTbAvN^23T32fcF~vf5CqWtMHk z!;<3P%w(IU7JC~Tn-M8VX-Tn3!}^YlPe~n|l$cyXyungQyfSQXTH3JqzVV51gZn2Y z+fK8T7k^C47!lt$b?}(@)Z~(x$)%cF?8F~a6OvNWlI_B6EhWWogA>z|ONW=Tl%mf_ zN*t72CcIR7x{%pN$jnSG(G(w)(tmJb-(g9qsmZp%zV<~+3m5NoV7Y1hO8?kcW73E= z$t8n*9emHeKRWf||M#MMa;ae7c8Rh6k83wzX=#Ip$0yqb`*sd|^FRns&rB}eG(miN zrAMF6w6gO`5TDhgllWnD^8`Wqb#M0juaPkc0`SP+;s~G&0~Ip!;w!IDAfWZh(eyR~ z0#D8hZ8R^G0MFsucLONPK)GKp+`s7NLx9`6Owx#ekehY)R=fK*0UbKs_5@IlfwDmk z4Wvq*c?%~B+X)EoDx_w{GH|NSXaMCIC}ZF6pSw3#k%W$>&Rdqgo=rf6hql4<;G(kt zDllMoe`Vcyq0R(!4tYJB%n?B3i*R9cd6Lkv0F@cA{XBCh`I3OGUK9uhQk-v+qC$etggX}i*aH~dwQ>2B zA_Dy9#PhpI!iB*q(p^UGl&Mh#K#7`%L~aof;IDxmPJJzB-yb6hU79Q@fP|_Hl$cX; zVodo|25z?zVX!BadJMXzBJH2tYZ`zr0XX(vRuq*_K;W1k`C|!KUwSGf3`wY-`UyZa z25eqBHBKoqGVq6Z|IukI;b>3gC4a~1!vHt{7`y0$OunG-l%lAGIRr#iMbL!=|I#+~ z0aOPN@z_DeA!uJ!&S?T%imOqEkxF-p4+5w`fY0ZT-dm*&L3ihiH??93J-szB^;rUd zngAxZj1=~f4Z(53zrs?IFyrd*omUfC!mjeI0n`G}&}1RS30OUb2nH6AfL07VaQzKH zZ3aHAew0P|Ab^m{%Q@pn!Y?8sA{EJo-E}qsr~@F-<~daj0nx&hBD@6LbLY|^FKy0y z1fVVhAGaAVULC%Rft7_KtBeE~DCR^yc#K&A1is0x2H-0OKD;fcEORMzA9Wc{Z)5g4u2unrrbW&N z-~?cEhnnJdk`Pu~ST8c2yp#`f2$H6-B_CAu?tDJU)bCV<8O za@$s>4n!(_YC>*e35Dnf1YG_@s?-F);EZ6>Ljd8Ggl7ax5}uS4QU97HoQk*%2~7bs zGX5kRS-6vKWG0&vTBWinZviv|P~y*BN`gO%Ku9o^J(nHbI5-myqfsb2y1851L`k|9DTf5ee5t?(wnK}lm zYz;~cIR&5<18+~abrCKn30?br#kE8#S4MU!ZtZL4w@%>#{ zqU``UpDQ8V+I5>%rHRfwy~wT6d+I_ZHvr`y?VwgcMqehq%d<+8scAGI34vMrX8>?# z;NQDN6O~(^9Oe`y8}h`Y5ay5s*PvGbJOHfyeWRRl5e==Wu}JI$!ZHB_*?RT>&>q0j zmF1)j5v6qjnU=s|q|(lxWwv?($k?-A_ST5DRvUJ>s5TSab&M+{cme2>e1t}aFgn6p zFDIr=uA-F$uu9vDo&x9qAi#Guy+puJYdvh;AX#)1?zGOP$ zK{Jbi^=HBWcmt^Y*IQ8r3{2dwHZs%ExwNbj5L|o6KmeT>c=KRnqrKJcF)&Yr0qc+H 
zci2x}x>oE9pfiADZR2I6BQ}KVNdKw56x81R8UP;vQ;PUw$%e;e)O;`%&SaKU8s-!L zz!$()j}Tg=2)Iak^Vl_Ac%r>DvQupUeo{iYq159D_-u{CgN7>kf@7kq1Ly)Eu^!D% z0wSyViDXUXY<>K_+601pk2(VI2N0YYM1zfhR)0_hF;FDx2Qx14>V6rx007m!zm)Ah zGP;IZrOhS6con!Q zqaN%*tG}}V2nJB*Lp!BX5Sh=7%xtHJVor#@AsbRIt0mEFZ@ZW>jDVV@Mgs@~uIgo&f!esF-(FDvjF^mXE znIKzbIDmx#ji%zWM1&!^)d0p3;#QES*70ve^>f`n)Qmus{TuQAXLML;TbrMU5C z9K7zh-0m#^j`p;n4V@%J(h4WPFj zE%I!`ZmUWIE{UHVI9vCf7zUsZfIbel#aLkA2aFw(FhHvm_{lH?Kn#FjcVD3{1M4YR zvFs(Qgy0$1rCa*~2+OsXt8}*=)>hfqQyCpNwUX?u{Q!7*+RICXZU=NTHmB@Rho<0E zml&wjA3#9jDB1bC9n&+;JQSmXyj1G3H-G^EYF)e_E3VtCN}O?I^l-g~1~qb!2aZ?( zHLL%k3h08h&5RV*%ScbGDR=KU23|h(nLt$|a>-z4&DL>ZZO0KmaQ0^j!~=Nsx{iv& zW9UIDjlILCOvvc_=(A)`mNae~o<0OW7Xl5sFJpUK0S zBz*f&(hgX15{ozLKqC0?;%qsJp8!stNH*dhSK z036GiE3bSB^?PqnqBBc2YawTZe$1v0(tlG4SH*zjmrbhocBcrG2{d`G>~Hz-0n>w)}?j zQXhRRGw&AbqS%*(4FSWw&Pg#aRM|giG2|Hkp0Q|2HtO%7R0Js&gP_@bg zSW3V)@i*+=0Z#Aa*v$g4>2Mc$QZ!^@e94Q zd9Sbwd#nG!%?ANY1h92o4Qh%Mhs-xusai?GZoOayOl~0WW^w@h+@gV#rGqVHz> zx7U-K@e}}?8f}zIqG6JCJ{(8!Q5?=Jl^fYq00-8D$yM4g)H>r1llgpo8es2GA2v(_ zu>ENdRnwEGe!gZKe!TmfV|MFl;UsY~B9%h_=?P#ifW|c_kgU={dOx41&9{D3 zZVhbcQUD+ifR`mmN!VI~S4mRo`edb&|BKjt_nEzCADBkm?8F#e0FP|k6ca`0w{a1oKoqabzizq z4(9Jr0!zHT0+sRsH2GgY)$TW-phVxmS)7cXU-lGK zS_U9~+$&krQGHOzWW!TxYCIJEQ;tdrKLIG)R^7ry4MIaD;H8dqz_Mx5t;-p>|9k|k zBy7XTJ*uXoiyBg&<@9#Bz62_*0I;pLj~xyiQR8uVBMC1=N5csvpxk76#9s*@y>hsc zz>6_=jp%p!SQfCP`XNYI1;FoR7n!Y5Kj>y0U2NoVQltxa%4}T?z|~P3y#PCpP|0Fz zZTFo&%dK(^1NWY;`y}sXqE@_6k+#^Oyh*~Git<>t7QpkKLDcRjHCF0oY$>fb0#I`y z?UHQtIXxdNb!w{l;P=}?*(x^!C}C@o%SqG)yo-#kg%ws9?UyocH&oifz};(4-m7}J z2R#U=p=-KJ#rr7$wgMPi+a$l+je6lM<|*HtSRAvtr;&vEQ856v0Vw$4lM3AZ=fWEP z#?Mk;pA`N2ws!-t9l%7NIC--k^_o(e$F8Me4$l@Lp-VldMF4gH@SRv)N-)}ut0q#2 z-P&l5YL$N8M@|CR37}SGQkUFnw8Oi|XqN}7fc$P$Sqxwo0MFr{WUGuS7IDUvm&F-v zmW&>-qX&T94BYu#%2)i(6274J2&0=@1NYZ>Zva07_>e%GG6R)Ma{`e{&icgeXZ}q_ zdJlj{(>$e@jFnOO1h`Vu;I9b%TKUS=;THglzKNFEYILUD<9j|!Yip!K%gZ?I1u)5F zvTXOp=E!N1(C(7x(0sq-*Q&-8*su>kL2O5v-o~aPL$S?RTYIpT!evIHg9jQ}V!Z5J0<%v*i7f zvE2@}YgpXTmr26jy)rd^1etae1S%b3;P%s_17#yK z23h-3$H{8J>f&=vw)?}7;Ns~fuUd?L=w2)#LX9lHMp<&v{S6YXpLi~7+USpH5#VR7 z=>_{`n>hktb9aAv;4tX!(JGgGw|AGPtP=p@ zH+H0kN+sG4ofRTqh zRW8K|-DECx&_mI0r@2`CRh&>x0q7E)DUUS9(LYh0ahnOC4$T+ke#6?G0&p5Y^L>Y@ zdyz_MX(DoO80d=9Ae9z9EduZd0M}FhQxX!W@j*hUwJ-g<)d=7W0H-KLH$A{7|Wsjqy(s%+@v&xpFQf zT!2cp9TMoXNJ9DpZ9{~wcoRQI=C^i|oIwABgooq0%aLUqd!Nb%HY`znsY}gBS4g-B zU{m@Qd89FZgYLyfQ&#b#ZGP3ZWdXPZU|Gj`@|waphYsxgeVirQRp`mz9{Ww%BaZ`@ z0nE>CsVeR}YXKRPbuLZ+=^!Nh1z>UU990i_a1E&xT|ueTv3X|+Tmdj?d_~n}=3ARU z6dJjRwCAG=^C96XfMMTkl$(HYo;7g2(4qw3Ad-k%ge4OoudUPcNc?X4Xa zERYAk>i~NAexZzBX`Q0hr@TC8B*=;@1Q0S~kO&t?dg(RYW@2#gC6(G#l-tY=2LAr& z?JW;}#&tc#^2yK5SO)Y}_23m(WiM>F2?_PKJW_#MDc&!thTIzG&E}AVm^f+0@{t1^Whd$dN;dJ8NCg4Nx*<%DmA>?$IgO;+W^i@dM|fx zMU;^ zOf|>p)JXYR16NeVJpho{YmMxDrl!{9OWjYq6XlY-kF!j^hXBkKHPCpZ-n~cts3I*) z91WEo0q9e;v&>f0SJa{Sky}gp8#N!?*Z-CP;4y%huq(8}h&E%YMQMnxk?zbpFS)he zrpo~S0pNQiPM#D^^~LlTO_1V{`GQY3WJAb4a+C1V698TIjFT&zse?5&(mGJNlF@GF zx8nT-JO$A7`dS%jQ){?}B#g|WcQJ6ZZ)*V000jBF(Mx1Q8|!j1ZiTYJy}{H107U>w zI$fs(BA^pm1xc8oz2vt0{R05c0n~1trX;i!jSr9KS;jrtN&5jAeLhL5^a6lq)$uBq zT9xB-J=;KiER}7@xW5?^UIJ*b|CJmlCRezIfQ(+EgYaFYdu*tzhgS^T_`z3o4O74j zttycCdIJ6-UEgx8Q@^Wzwmt^bKb1~oN)xZ3SSIav>HAmR6rZM0yJmqM-E%qS{2 z-o4%Hc9H|-UjT(0ePkmu`8uwmAK9(QbW&;MWg}F23zbTH^p($%iPM(O*3mFW03Mag z>;~`-5{^H6B)f*mNBkg`0+vuyZDek}!W;p-2k`tW_2{$7=R6S%WZY3pqWkfyvK~GF zczlnJN^G=SwC;QpPOvbtA^c=hNGJxd>qsb-8F{HQnShNfXSxF1YxIPU~tr4h`zv|>X|0G|QOc6cnmJvIfF)|>ID z(`xs2Pa6UoQfx}t*ns(~#yC~h6hwR#WXXU$3PsGfmsv1I;Z)w7U}G<+^}wU+C2Ac)P8rS2FHWToC*8tP9Gt~e&!Z`z9-*r&V9IKy 
zO}SAt!iX{T(>}<)ubCas_kt;>na-nci(+NQ@`C#_%hA2vV{JQ`2jw-BZ?i@ohfQ5E zDw&zD?OAmboi3$gRecSS%(c`hSiEFa8Jmo zqL{)D4YS0bnCa6rQ>`>(kMZt7Ox!LvFjX~kr)r4Y{Y-t4Fy!l;OI(S>WaX{~^QC4U z>`hbyH399OnDI$m=EMyDwJw-ynmL!M9v3$yU6;6{IB4e4DsDF9gML}! zMci9ViBrwiV@0!JV5%$ThV36zo&T>7CQ;(@wXjE@XVbvcP)y;;Hf>bP?MK1oN7*eo z2rucDSn(#9nu@s*Y~M(pgH3_f`8omnHZd!TUxBHmn8NKlBjx$u)X$pNQ?28v=cjxy zwKWr8?}%D0;;pO2tZtM?T-!r?Hv&^fF*hn4%u_YjAJdh5JqD{i>E^r97ffBn6lUI1 z&s>-WV`(O4njSkIBU6uqsi&Do*FUK|h`OwI?i@#MCnV?FE9=43*Gx>-G_}3;`A65n z9KA(*4C%8NOasjn)Z|c59>fIeMS4=zP%#+T2P*@m4;(dkME5pK+Je+o$ha21Ubz$6MDx( zFakxPD{-UETO(gFb#gAcZw->)zGn4R+^dZ)=t^ezqRPN{J$kEldU!P&A9NW_1kfN@pRD|$dL+AfvEz| zvFxn=@8CE3|0#(pDnJDmrGPfu_w5Ex%Gx@6oYmR_|A1T01#X79Q!scr26{Qh!piZh cCq-VdIreMf80d7woet#BfayY`0*_YU3OipE&j0`b literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics new file mode 100644 index 0000000000000000000000000000000000000000..98f322b3eda04b8cd4b9af9dd776b506547a1377 GIT binary patch literal 15256 zcmZvjcU;Z;|G*PTlT=!yAwor`W0X-+a#hsnr0%s+BBY!YS42csQskBf*`v@RdynjS zUE@YtN=DX|{d>ROulM_XzJL7w@u1H0y!Lp%KOc#zt2HCP=Z^SdF=LOPi9@2QZ)0on zXM#e-Hn@X|ipuPm=-BA6=*UTP!(--5kB(BbA|D9WA-{>79vd4OJ}EqE%Jj)m3e}at zn&gME^Jj-onln8Y2{qSZfkf>E$G|Hb66FnRQBM@4Hd zH_xb3?f;+O^;5L*avQs^GSKr*EEN?yeO9A@ib@3e;*n)zeb;Pjj3D0=6ki_gux5J%;O6CJ1dnLztxBMF6#8%IuQ2#E zK#v=>8;1V28BK>d>%n`1IW9Inn0mbBSM*^VMz{d2(=-Q zCPLjtKPNu~$S3ocqmkzw3SHU4)f<+uEgK;<0%=j`-TU%~-oy|B`5q}54qp#8HAh|` zegpM}jdkmUP+J0NBlM)}15<|-3jNf&&Y}0tlbB7zvhrpiq)s3m3N=^UIirh?7&D^l zelU+n3U2>7=}znfD(T*y5QN$hNS8uyk4QJ|Auc28`I)wxHo<5Fg?f(>(6&?c2x$AJK>W6IWMngl1A!l?m$SJKphb}T6tSaJ_(Y-+e~0_ z3O2vYtp!IApc4X# zh9A|K6Dn48<}s2#dUA{efw81aG5(VQUrqYCf*v#l4ps>LI z5HcW;DMCj~I%zS;b2$kOCh4TLZrsiPgjt|!W)yh&@$oMB6i8V<`}O~vp)^k;C%<luY$@?4*irb8h{x|=X*|JG8D5bn zr56IGwTF1-K613o#2FJD>qlENB;t^b#}Md^cKB|41MUF{{}qW28DH5O8Vg8BF3U#1 z4kcJGI1V8OpdxuLd3*P&A4?rM5GuG1K4r)0qyNOczQQ)Qgw%rg`-vuRNh|aAz2C@m@K7R{L zDI%diO1QVB3Uq;JzZov-VnGzoTK7t~3DbnY3v;wqzB@dwcc^H9U_z(o#jMA;p4w=u1*SwIH`tY1`-{r;Z7>a7` z{N|L7v5?LAMmsW1kx9e_bU|~!p}ro~awbw-Db%RDqAH5zm7RChN4A&D&*P)aJ<)~r zS~mo80-Z>)(0H`O(JlZ|7ER^rg?rSZ!Dz=Yv?Eh>^K1lU(nVtSE$W#EjUD2VlNf=3 zJK7;Rx|_!%N1e6}>M5|BwjB3uk07UrdxsD^hNG0@8=F5f_YP@C(-okYUJXpj@DbIy zD8&PT;WqjOSR1l{3$J-WEs52Hi~+TtoJ~Z)lLC!fM_98FAZs0@4qporqW>fx!bF^! 
z@DYI#Xh&Wf&6~`%aa*ai0IWjc2 zi4ulQJi~#Ei*%O&2`ds7L8*fcJ9e_#_zi(m5qYfoWKnxB8xe;Au&j)3u>hQ?7>^Q0 zQJ_BG_y7opwsZJWC|4=%8gI_-?vF@tAr_t_7Tzf3St~1O_5e+PR}#TSz`}H1AFgSm zicpFV1sYx@>}4e@>tP{=%+$eLE!Ph&ClK&O33>0W6Io8p=`Bij)fc|9x@PZUZD%wB zuJ%o8xaP^y3q)-EN1{4zk z{18~t@|66W8ok6z?1W1n@iccm+K3VxA3tj*M_jA6DW(gQP_XB7L!17=FK`t0Q%Y?aAj=&|`Gpzl~ zLT?Eh);!?deCWHItmOnEko|jC0~YCs05OWwuAo#Cu3f7~N&&K+b zzk*6@7e3-D9ZqDa5rn|g^PTh;8lF&o_4Te>0e2dd;}ybwAONXWi$@D8(KWg=f5 zU1R#OrE46jRN-q|$#UsbfY@+SZjkz;t;RJlX9@^w>phm(F`hy%C->Qq$3}7e3eSsd zrI{fh|$hEG5#S)r3AFezuJi*~cR16ON+?%P_F5=@|p6l$3E$d^MSib>Jp z`F!PcUPuE%Z1YjdBm}zHN`EjZ<4eUjUT4qu!Tu$dU<;8DLV*`UOrDUGrryfx%efb-efW2pT?ts1=4%mZquaF+0-IXoCf4$!orVNq%#bqa9OF%9P2E>sSZ5*IR5l z>vr%SC9&zq*3@v6kahVG@3ouqPCyr>nY)AE%s778S%ng&B5)vRAa9#%28)5exRy^O ziEadwFb#nR$)4ZX2n#ze?nzQ&=z)f~Rx+!KwVVhF)K!MRh3o-OH7na@wkAI=I5zxN zh@J`~o|=wQhCP`DJrAJm=AxjMcjYH!iR=sPVg#5$f%@oO4=b5#FV7L4N*~R;h-6MY z`8)tKQNp}G@5@=q=DLY}D#eq3X90h(aJvJ5<3umFg}Q;r0D#7X{aj?FYp~Y>1{neUN_= zA7PDwL?gW4NnIdpP&7Y8U^Yq!&zk@xs8YRmKobkxzN4abu8rc2O|tH6J_2(PFkL+AK8wR^VgcI@ z)~EAkAkhpTh(Ig?QIB76qpv5>jgi{mBQkcukUtTKqd-HQ&mL&>kV}JG(k`;57H!Fo zZj#-TMApSFo;; zM1i_Bs{h90@*AKNO!WYj^0x6vOHG@>tz>R5BCcJAQd)-Vfwhq0$FsyiB&8>b3cV?E zs{d4mYAr`#@qN!wHXyI$i{&?a2w&o*!>zs}kW7IWQ?=Ejn5Qxi(EB}5E1gsm2|Tew z8h<(sffWe+?WrKYqbfzOr~3l{YvnuvrMjwZ5Lk&o(2Chmn_xrw6ERp;pA-^oemjG} zDzu@*K_Ai-fN~@L=ps0M6kkoH-Q!v3SdGAd9TnG@mu__v_a7PS`6P1s`ZEP3tU;B0 zdQB~0!1h{As90)42tON1?am%SAcX>TS93>ZvX=khws31s3g0xHnp%nwSc?*@?^l zQ5#ap1sfjMoQ`j1sgXv3`pJ)d&ahT<^$@+Z0i!by@TznAaKAIElum)V$PW^pv4-k& zFaWS=9p9FvgX$iF+_v5+BxBd3l)5caEvD9?+2ZyzWe_h_>1cblF>gRA1{%gJ+5z^F zwA?^S{&&9FNOfwmP)aJ1l7Ucp{9iocKU_pp%XJYcqL;zEwGn~0WqX!lT?D9R&{UhG7|q|vm&T^!qUGe8@)VKgsjyPrgusmY+%ndf?}UiWW@7`tZk3va+M?Asu=#0)ZR~R7crpgrf}s&nA~Zj{&1gs<~VHIoLs;B@zFN z#3C1=Ke{U{Fs}mM&kz@)B_DYJ`_=yi5hqY4f$|V??G|;BN&1{bpFn}6veA4LORe}5 ztSka;M#%ot4Cb+bZ}_YWphMmG)-26DcLF7CBG49u92|E;MUZYHZ{3qt6O{y)H26$& z%$S7;WfLeLA=CY#stnS-CZ_D+1P(c>jD^8ZY|0_fR)o5*&Z|O4$n_G*Z?Pf|^%oU6 z%WXFNw5@k8fwm#!_)DElDtX7Q;?ljOkZ(Os$CIZZlt-ZL2$}62`Uab}+)zz*4NFd? 
zdQ!U#sC^=7Gl6y>)T90(D*<`O2=P#->t(*v=$82eb zXaiKO5P?DK#qpa@1Q;#&rFYsy&4EN^CfYfLG?1kJx48#?Io#`|v}4YcDmM>W4rP0{&Hn zT!Ntnf2Se%a6jQ*b0>&8kx-1l6!V+ym<^phH}EcVWF(KcRMWT_ff5A9?XKXNYmiTV zgI+2RrhkPU#W?pN`8$zNiomQ1{oPpb%zBBODrx36IK5qGi$EEZV6mZ$8Ev*i?Cqs* zMc}k2LlqJUhY*k*{>r^%j!!^nq?N^_?c4!@!w6(Jyo_N`5P1 zf|p7Sg$=*QNKwKG1PVR|kz%18q{D`7&`hXX6NQ(2-7F9|iNLz+q7xd~ikVoti}zfq z-qte+oI>E(j_**6(4%yp(eP50L^E(|vSZ128i9ZH8md?>sSczUwiNKD+QDM>wi3o0 z@zNOtPI|?d;-W9__(imvGI6Anx~;v9z*z*24u1ZX1J0sE%c8gqP7O(W5IBdxwW_{) z%!W=auVUT%i7>kV);1{N8Uh-v z=i9Q*r!zp5Xt5jbe1Y|$@ERP+rRxaTe@UFj>OliHVsLBmSH2^TRwS{fi8s)Om+!N( zSuS<95fkXJ6~D0dy*rMnbQ6I$sfGS5H4GzZA%lbxJHEY*=;2$9D&0brqSI!d!$ubH zUu$Y91y1qJO`hA-7lGRd1QeU{n%01n#+oiH;*lO3etHQ4cM$k(+;7#aRjLPwY%M*) zfuO=X4Fv8Y;M3*IV`j95IjtJ1k_$IFC}SsN7;)=81m^wz4D^8XRyPw5aEioHD5&49 z-3Z)AKs(Ykoq4IyXP+KmJ59Jqms~r2~K@$T1ATa*^Hy38Ky7JDsScA8Fc~VF-0uK>zyyKtA za>)R%9%(M^=etIbB6=SJj}VyOqwp(A2>2lpR~#j4=tBo$bkI%h2m~G@5SOrKBbyIx zaZ99ez}MOUSU-~$&=Ur{?az%?^%P6ifd%~PHK^BT1xolAff42(&#}~KmnqzOaK145 zMf_X@DiH|EXbxrB`u!JuPY_fpQxit7n6MdvrwGWWHIiF+OpRFawnmva_XVxcI*Y(F z1TtbSvz;&C*JfG-Aff0MUmb$lY%D;a3W1F$91?M$1boR>c0N%+65ZwhBJdo6HQAS& z*b1k)Oq`I5#RE;JyEz{bs7B!G`KWx9Aa9FD7EtLRDJl|95(mkpBsQ{Y5ZE5yQAns% z#$gqo50dsLq$r^lf!!vNY0QQnC&gVtaR@I_iS9zSYt$j|%GW^;3n<`Y75x;bR6K>Z z8Hw}MGbo`Rfswry@m<3-L8w%k%O{ZIbN74%ULY{gdYL0D(eC;5F%?KC$rIO>_JxH6 zAZMa#(g+$5vR!cW4ufnSDxo-@9*&QlnUqEZrp{l$2cw~uI2bB7OG%H@a#ZUj0)ywg zb7vXbO-mefl`D|M4ls>#NH2P>5LQJiCn&ULg>+EWHaGTjmeM`;dnw z@{0$@AG<6Nc#Xi!g(kcabh{?HpE7uk4;l;*c!R(omu{a~=P=C?*Wjf~`6QLJFJS5M z7J=}?Nh?`@HO2k{u_$XLY-p9wY-naSL@(r{w`(tYZ3lp|tgbtM68*Jw0tr9829VgwQ`o;U@ycUNcEArODUDv6&lPHjsRX4)?(9 zRnQYjoB4&niec(i%my26v1BO^P65-lkEcM4N-Gr=WQtCN$Fopay%O)XDLYM|R$mNu zOe?``N?jJtqS3of9MBFH@p#K?*L)^&NDurTmk^^WxFYAqk<4XXZ_+2L-~pw6zD>n; zs9$TroG)%1!rRVmk?uvq_$UtW8!-);Hi9XdIVq9(p~FXElk%!ytmfQ<$fybC^JA$I zt3%@n!lsg>ZLq_bF7*P|L zgSZZYJ)=MNLq?x7wU!sNMAum?o@o}J=9i^`6+iG@=okae)V%sNF^CW1{^D3(cAW3a zfn9F@L^+0n`Lry*m`zHv4dNY}L*k*gyl}!jWIA!CR&#jHPS!`vyp-qEOp+1w9#h`n zjPpCjNHF?C8+lAkeu@7nqC9Q~>g4059b+sQ+wc^gJ9>`dSXJyG&XQv)T~Rj^!RUr> zb!WL_ctG4VC~Hps>E~T!OgU5YRPl=2V~Sr+>t{Jlap)oCiSSLNCz^4l)?(hFL{>ir z*ToYfB5;>t%(b`ZYC4O`lhT? 
z*~X|~WXw5J^V2clE2{;w_hS1g!T%El_T+?IL#8WdYP)yXnZWwB$p8@_f1E*bsH=CYUa{Equ<| z&JZu+98BRCn1Ks69!17RFh0IIddvr=?ZrclqO-jB20C?0K&CrqYMMV?xXA0LUd+=H z1vL$HTU>YpnI4>}wOQ*-(vkK=<4)quX621Rpm%c$GPZ&-)pDz3F|~l&#OyAtx!HlRHn$Ggzf9h1ZBgwYTV|}!* yVAShJaou{g6FGYD3(uVZRoQ!#<0u%n<3Q( literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics new file mode 100644 index 0000000000000000000000000000000000000000..b80014da696557fcd42b45253376560007316265 GIT binary patch literal 15058 zcmZXbd0fre`@oYnNg88|kccw2sHjk`dzDabx2`OuTcU`RtP?`^G6Giy2u z|859SGIG^XQc_wG=pW?o5}^IqV+e4(#md4Tdx zS54uMLF-m{FIefj-g{+0#|;50My{&DA6NSL2L=VGPFHsADE#KTJSaeIx{9j``3(Q% ziv!fBtE^fj$aE8AZV2dL1RoTr@Lj&Z&wu600A*Y2wxLs&Jq&WAa)anA3NJ5b{}m8B-f06I~iA$p!B`78mLsonS&c^8pT z=|P1Fz^s3#js`%J0`)mEMF^EZGsvwM0cMSPrHX*@vqF^c>Zca~+7$SFZEU>oI|UYYddA}5YGt5Gv=byqtrjl= zpeqGFeUx=3!ij{lb4i@3N|P6LC2|EoHr7W20386te^JE=@M4r6c^5I-{>O2q((E~J zegoha3Vgg5zlyv>fW|tEwCl|4G$euqXB|ftxNZQq3BOX{!}Yi2Ecv99JBFi|+*d4SZguWI?Fj(g0a#M$kAU0R@8n%Xg2#DG4O#v6 z0|4}(!26c3cUT;xp($4o@K&{nuf9lC(sADfKu-V;%`;|3OCu*P5niGayg?YTVbtCA z0Q3T2?h#CHY5eAaEL&auf6irba5mCB4?u4KREJ+cm9D)MUI7V#)zo1`g1!>z972ie zQlKWU*9VdW#I3I)&53rz=)h$M(M#%mp90VafcUvTP{Q-U85oD4uieq;<)88a=nH^- zuX^%XM8e0IGK|COdTI_)DQI;!0Q3M*uP$RH>ZCbx5BXT4((lVLmz)!{%>d{}f$Aso z#mok$p5fzBrA=+LTB(F_;Q;gpps1k&CG;!exfEQ@Q$uf?B>)2ehzfRKso`We7Q2S) zM&o1TuS9}h?kxcH0a&p0Pn0k@c!2N{wIO6uBl(d4MnkUyFc1JU$5X74IZYozE1wv> zamq`U-m-u93UsliwWPW< zKBx7e+~}}gx@Iew#jXI6LfE<#N1irUDW6p$`T1mE3E(+4#3kf5|of?=7JK!Qg0C7 z5kM!{1Aq|#Tzc>Y-I~@cSOd*?|F9`6aI)ANtfof-5FGt0Ndlono&Fx?hi&GlK*D^G zy&Ap{B#Z(e;reP+>Fk+0l(26WZvsk%!vPo#z{2{>CT8^A8T4a`(fhakhUwk1`Bwl; z0GQ!57A4#pBq#?HXn)I0q8-Vlw17$g#sJ{z6DRymZMgSr9oleM!%iKYQC9)K%`yHP@QBAw<$LbMOogEZss=KxqR;5hX-k=HOKUk}jsaa9i5D8J|a)ftL8xpTB1HcM^!Fn zmG)s0WZFkfutZ}j=xiY>1-=Rdz#4!tJGbV;HY3wbqob2ZIPS*daPx!%05$;3HPvMm zC)2vI9SbOibZA=B(!{8V089p8aEtIOO>bF`G~NW_5@-V+35_1SYXQ-Y+As&_wDjoV8UVxqXtyN$!0s&@@OmBE5Icze zNK`6*Oint&t|6g7Rkd9k0wyN6(5-QvDMT0n@}{ufIu#_SgjcilmKlB*-Ve)3-06i_ z4=s9ZtF#3P&1p4gwDCMXr%$NSErLiW{6`s7vIC$}X(viBpL&NIZI2^M8h4@zfN9JI zlgtnnhjAn=VX()3;6`8X1bO0`|k0-tqDo$o~ zpY97J4)FuhSod~F^mz*sW-tlD5f2tD*+g@`FD2w-q#e#5i~?XL0Oz~766X^M)`y}o z4kz~0MlN{C(fa8%02}~F%6*m&ySHr82AZN&rAoBH;r13*agG3lc{-xel5HGFSb}bK zsIHF%39|rLZN0-1PKvVW_Pn>AsKOcN(3-~rCk4P$djLk-WfAX7C)M?cD+C)He)uvk zIROyz)4ZH{$?h}vQhWjCl0&hFE2ty`Kw+W7TBS_BlumPE^eF|_gTvKZq$Y**b_QVX zH!7UC)w$jtTV-MvN^nf;;|_oe0Mmj$G_go~wD4w}yd5Wyqee>v0CE7#uXjMeYb|Zr zR3%Lu><-PTp#aPVz*x6c3a4F}M=meXq@K8jIyAjE0KgT1!AC}596Z+3BQB9}()KDb zhq$$Q6+0=;0bpRaR^bC^E_t5dgZ)%B4iv|ZqYisV^c2SDj)Ge&yZdU|s}Y)CD`z|AOJpA8b+DDb+W zumGEYzZ#E2{7^J{s&3ET0JsD2anS-yjey#%BoB#%S1DD|HP4(z0-5iAJ1byM{LaY5R>aaG@(+laD39=_K4WM54hV zAYl;zr_W0;TX)z~J*b4g@OUf|sUZhaIU8kcH=ot z+O(z~y{CZZlDE&qa;?$N`HTZ2buWz`?Q+ zRXSb3YdSNS+hF}GE1*>XjLW-)aX5aRXKPvwc0O^ZfoyMG4S?gJL8wwf`-wgM0uFYu z>eu&R!x{i4^zaE{6DYZb?;4pg+y>=bX7pMBX7|*^R(Vv7cklE`xZR6;j4;Vt${|ToZ}ZEnbuEOn-TZ4W|96K0Miu%uxlji zQN^hQ7oJN7Q!PQIU;s=awXkcPZsT_XXTEWjjBVLa+yH?7!Y{b~B*T(KBxHtT#Z75D z#UyM5pnFU*O2}xxw#)p4mug~DN9LtX0O*_^gsG9L$9rq0t`Ho$&4?wH-NA+s3cPCh zTctA`{AB4@_zC4~9p24Msf}fk-V729s;{CAiBI|EX8K{S($LA@K*AP~@H6#mK0DI< z)xa+|&s5e*m@vaA#{6Cg1fg^ejdKm)(UM-C%zJfE@r- zC$B{>ojb-)Sy}z?94Br(%EsjF@SX73vjqLB=_yIg4 zkDf9GmBh6sY)K3Q;9d$|=E!aycS4nNcHk714At5VD(wZJ%;+C%#<|OqgqNt%IZfEe zBr2sWm%;%^^ZPWCZ30C_{Hi6ZoF3_jmn?g+!2JP0erm4)Sm5>n zaP?;idg*Q$%?A>PEJqx>lEI^0!G`?+gs(V=UE^^ozm?7&ff6J-EGG{HFLR0EYlLVXcA{SJJN2W?=W0bds|9a0Gw{?&ENUyM9NDutcSdR#Zu1 
zqt24=C;}Pa7lGI3?!a83R1zucUGJ&b&9C@2BiMP&hFWG!x*&0oO@{1h}$q`qG zgL73oP*{v-QAxiGtdYe~;CYF5Ftfq=Yi_fk2=&r~-nbb{2Fzw>pkn}-7$veuyY!jN z?=>F6qEBov&ti=%mIBY7-c4Z=T>4+&*LKB2aG=;eD@p{n#sM(fhIDgcgUh(AAy}ft zOZi5&YfKRU#{tlK;mG=u%eZ77>5_L?529BCzXI?l1)e4a#S=@3gvry+5_3o#%7kMq zYh-GFl3i0c?ZyMJXH638tuEq1{v7BDzs45V-f{*Y0f5n~=CHuI%&q75OC_y5a2DsU z0dN8U&5N}tVZIShzOn|qwv$+{B&|VE=_CcplKO8Z9hlhQJByATYC{XURiaVKN;DCG zi*qjuA3y=0M83_GN>L@zxz3kn(AtYC9aQ6MX!#C2fJ;ntO903Kz->xB zs&wQhy>%frJio(LvOL3pOaNv_J0FE6;BxFG-|k;CB*rY zY9$iNg)J2BDkYlf%K{PAfweBjZW)_-==V_h{t*mUwU{RfE)n+u)@17M>)-Wj``ZJ{eY0bZ*L@29uY;-a8*0R94C_u|voHO^%5`=yug)K*A7am7n^ zo<0wNgZWNW>B>MFEh6EiHr}jD44REV!UX^fKPRBkSN46yIJ`>1z9cRTo&mr`05oUK z#{#-3r_U~^glIuedV?iVzsWX%TmX#C$k~*};nFw0kyXZ`O5$8CACQnofhP^|6AW0Z zyzRwj+$(r4OpJc*%%LATpSsK?H zBwPi+)2qY>Zm?V)=a9Bb>*3v8jDtk&?Pmb40kCY%6_ilgN{fKl@Wu@#i0jQR18^OH zUHVg-SoeP1gkQ9AtqPzlX=IkEF9Em#z>E>wNj{KPSvrVLLL#AZ6-tm;%pMLvJ^=RJ zz1gne@@x_Rp1_+ec!MReTEQ-cNc)tk^K><8NnZDzagU)n!pykSuX$}&Eq?|wb6ubhBR>8OU{p9odT5#0Z1v) z!;1U-gYT_xJ@7H7hl({DyY~P%xYH1gZl>QNqq+2U5N?&SODd!!g=TyofCx9drf_Mz z#*Ka(f@dJN!k!WU9suyWg3LUc8uj6Ppi~QG3)_rK^?kM^76EX;$fXyX4|U(Um#X72 zH9Y5Jdt*r2#-O0HPnUxc+?4<~ND&PT@x8waCI7 zfJXq>^p|7czGU+StJ+M^lWqcDGuB!QMhadk0YH295OnL$DBj54@u!v^#fnD&JO*Hf zWIIc5d4~Y{EQdt;9sdf6m(gP>sPqJY&U1$`8{{fWxzW|5@U-hW&r=P6QVNv3=#tH9 zTCQnF_D34&>TFDJMNYy7Q0X54zI$wBM$0uy_;bcnoEBvxSLKy&UV?%|gl1R%n zNAS+~emzFo^TF5dAmJ$hSGL%(z9jE@jNc?ySJ5+|&}O{WB<2F}41kiT4VJKx$vY?W z?f$*K@Fw~|+q2>O+7KFt=K$nC%R!@c+7BGnbFh&qV&!btC+^P_xk{c* z>TLI}03fOICTlbDo;Lgp^nsrrJWmdN3@W_OnrViBHk? zz=?8$*{<=10*|Z@oZ0>#&il+jD_Dt_zm`4Ubb(bbZL z0K5aBAo@)R^d-4GhbLdHP*u>EM9Fo208|68=TZtOcaq*7&O9~#W3+GbApmLsuqYIv zgcY@XpwwOFUOJ-BcJKE9m_B%s&m;tH=F56*9+s%+PM(C#PzX+<`&0kg>YS;(bY3OuaqDcb|rr}EXX!cd?A;~?(Yn^jyL z16rq{4J%6MhW>v5gPX%ZLOlR?|8QVaRK8&^fh*>R zrlC4j0DJ@Bz)B@SU7CEm`p|qJKz$E%t2hXs`27pO{t@H@C=l8}pCM2QaLOdfC(}{^ z32gwZZ>_`B*s_2h!0QWW#}NRJg(0N?dw(N4`-h0TVePHDs#qH?RU(Pou>TTD)py?JfI8g zeDk6nAg3c|e3Z|VS4juiJCz# z!aUfun5Sm#WQ^&2@6k5EXmVzf>QJn(gJQla)bd`ch?YJ9rZZ=x4d(f5=|E9e%rBf7yy_q9$Vad7FY4Ac*tA6Xyp1mggPPh@eqR(+JzdwOz4+GjZC&i?iX_4WK~R*F=EX~6VCrs$2% z1Z;?L(|JtmJ7V8a{5_l1X>Vj6cI>th=WE0Re#clxMl(HsDbmMlfjzp&6xBSlM_Wc8ph-4L9wO45aj4{ zrar)sXhHHV@)NhGo=ifz&M1bf`~YSkXP&F!{~wS?j}vG(*8E4evBq~os1FP{bK{Lc zK9h6o6#ueUBjCiIRkBQA263j;McPKpgV+F? 
zFto?>UgfFjVFmUWa;E6OaP)Q5nD&+KFvdr5?NK!7HiR>mmx-|EPP8xl|H+*zse3`r zP|g%yHNt*y*oh_#t!=&uDi-LGt7c)&8F3~{cLdg4?0x=lr6CV}pcu2*7<4n{Owmn4 zte?mRz694dA>*EW4OXMTVVpVH#%F!P03M%)33!jLh}GB$a(?4X`ikL4S$yK){+#-n z=Z=TW#Riz+oGISuE@eAKq#wVJY0m?X=}UhFW&~%lZ6eTb@$K`paUS*^kFf=%z>MU~ z-VLV0%QSbA>-h54Kpt(;2hAQ~4O^6`uSap_(P~%pTiiZ=#A>v{J;bB+bOOj3&6$es hv~iI7iT2JTs1*?DvRj>qO*u!+#+i`9CxRITZi^ literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e21dabd3920b518f627e029190a11b1b4f03fdc4 GIT binary patch literal 6148 zcmeHKF>V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """Defines the forward pass through the block.""" + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + """Bottleneck block used in ResNet. 
Has three layers: 1x1, 3x3, and 1x1 convolutions.""" + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(out_channels * (base_width / 64.)) * groups + + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """Defines the forward pass through the bottleneck block.""" + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, + width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False): + """Defines the ResNet architecture.""" + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 16 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple.") + + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64 * block.expansion, num_classes) + self.KD = KD + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + """Creates a layer in ResNet using the specified block type.""" + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) 
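As a minimal sketch (not part of the patched files; it assumes only that PyTorch is installed and uses hypothetical variable names), the projection that `_create_model_layer` attaches as `downsample` can be reproduced in isolation to see why it is needed whenever a stage changes stride or channel count:

```python
# Illustrative sketch only, not part of the patch: why _create_model_layer adds a
# 1x1 "downsample" projection when stride != 1 or the channel count changes.
# Channel sizes mirror the layer2 -> layer3 transition of the Bottleneck
# configuration above (inplanes = 32 * 4 = 128, planes = 64, expansion = 4).
import torch
import torch.nn as nn

x = torch.randn(2, 128, 16, 16)  # B x C x H x W feature map entering layer3

main_branch = nn.Sequential(      # conv1/bn1 -> conv2/bn2 -> conv3/bn3 of one Bottleneck
    nn.Conv2d(128, 64, kernel_size=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.Conv2d(64, 256, kernel_size=1, bias=False), nn.BatchNorm2d(256),
)
downsample = nn.Sequential(       # identity projection built by _create_model_layer
    nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False), nn.BatchNorm2d(256),
)

out = main_branch(x) + downsample(x)  # shapes agree: torch.Size([2, 256, 8, 8])
print(out.shape)
```

If the stride is 1 and the channel count is unchanged, `downsample` stays `None` and the raw input is added back directly, as in the remaining blocks of each stage.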
+ + def forward(self, x): + """Defines the forward pass of the ResNet model.""" + x = self.layer1(x) # Output: B x 16 x 32 x 32 + x = self.layer2(x) # Output: B x 32 x 16 x 16 + x = self.layer3(x) # Output: B x 64 x 8 x 8 + + x = self.avgpool(x) # Output: B x 64 x 1 x 1 + x_f = x.view(x.size(0), -1) # Flatten: B x 64 + x = self.fc(x_f) # Output: B x num_classes + return x + +def resnet56_server(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-110 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool): If True, returns a model pre-trained on ImageNet. + path (str): Path to the pre-trained model. + """ + logging.info("Loading model with path: " + str(path)) + model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs) + + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + model.load_state_dict(new_state_dict) + + return model diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py new file mode 100644 index 0000000..8055cb0 --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import torch +import torch.nn as nn + +def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1): + """ + Creates a 3x3 convolutional layer with padding. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride of the convolution. Default is 1. + groups (int, optional): Number of blocked connections from input to output. Default is 1. + dilation (int, optional): Spacing between kernel elements. Default is 1. + + Returns: + nn.Conv2d: A 3x3 convolutional layer. + """ + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def apply_1x1_convolution(in_channels, out_channels, stride=1): + """ + Creates a 1x1 convolutional layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride of the convolution. Default is 1. + + Returns: + nn.Conv2d: A 1x1 convolutional layer. + """ + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + """ + A basic block for ResNet. + + This block consists of two convolutional layers with batch normalization and ReLU activation. + + Attributes: + expansion (int): The expansion factor of the block. + conv1 (nn.Conv2d): First convolutional layer. + bn1 (nn.BatchNorm2d): First batch normalization layer. + conv2 (nn.Conv2d): Second convolutional layer. + bn2 (nn.BatchNorm2d): Second batch normalization layer. + downsample (nn.Module): Downsample layer if input and output dimensions differ. + """ + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None): + """ + Initializes the BasicBlock. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer if input dimensions differ. 
Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(BasicBlock, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample + + def forward(self, x): + """ + Defines the forward pass for the block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the block. + """ + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + """ + A bottleneck block for ResNet. + + This block reduces the number of input channels before performing convolution and then expands it back. + + Attributes: + expansion (int): The expansion factor of the block. + conv1 (nn.Conv2d): First 1x1 convolutional layer. + conv2 (nn.Conv2d): 3x3 convolutional layer. + conv3 (nn.Conv2d): Second 1x1 convolutional layer. + downsample (nn.Module): Downsample layer if input and output dimensions differ. + """ + expansion = 4 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None): + """ + Initializes the Bottleneck block. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(Bottleneck, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + width = int(out_channels * (64 / 64)) # Base width + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + self.conv2 = apply_3x3_convolution(width, width, stride) + self.bn2 = norm_layer(width) + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + """ + Defines the forward pass for the bottleneck block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the block. + """ + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + """ + ResNet architecture. + + This class constructs a ResNet model with a specified block type and layer configuration. + + Attributes: + conv1 (nn.Conv2d): Initial convolutional layer. + bn1 (nn.BatchNorm2d): Initial batch normalization layer. + layer1 (nn.Sequential): First residual layer. + layer2 (nn.Sequential): Second residual layer. + layer3 (nn.Sequential): Third residual layer. + fc (nn.Linear): Fully connected output layer. 
+ """ + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None): + """ + Initializes the ResNet architecture. + + Args: + block (nn.Module): The block type (BasicBlock or Bottleneck). + layers (list of int): Number of blocks per layer. + num_classes (int, optional): Number of output classes. Default is 10. + zero_init_residual (bool, optional): Whether to zero-initialize residual layers. Default is False. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(ResNet, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self.in_channels = 16 + + self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.in_channels) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64 * block.expansion, num_classes) + + self._init_model_weights(zero_init_residual) + + def _create_model_layer(self, block, out_channels, blocks, stride=1): + """ + Creates a residual layer. + + Args: + block (nn.Module): The block type. + out_channels (int): Number of output channels. + blocks (int): Number of blocks in the layer. + stride (int, optional): Stride for the first block. Default is 1. + + Returns: + nn.Sequential: A sequence of residual blocks. + """ + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.in_channels, out_channels * block.expansion, stride), + nn.BatchNorm2d(out_channels * block.expansion), + ) + + layers = [block(self.in_channels, out_channels, stride, downsample)] + self.in_channels = out_channels * block.expansion + layers.extend(block(self.in_channels, out_channels) for _ in range(1, blocks)) + return nn.Sequential(*layers) + + def _init_model_weights(self, zero_init_residual): + """ + Initializes the weights of the model. + + Args: + zero_init_residual (bool): If True, initializes residual layers to zero. + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if zero_init_residual and isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif zero_init_residual and isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def forward(self, x): + """ + Defines the forward pass of the ResNet. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Logits and extracted features. + """ + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + extracted_features = x + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x_f = x.view(x.size(0), -1) + logits = self.fc(x_f) + return logits, extracted_features + +def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-32 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool, optional): If True, loads pretrained weights. Default is False. + path (str, optional): Path to the pretrained weights. Default is None. + + Returns: + ResNet: A ResNet-32 model. 
+ """ + model = ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes, **kwargs) + if models_pretrained: + model.load_state_dict(_load_models_pretrained_weights(path)) + return model + +def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-56 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool, optional): If True, loads pretrained weights. Default is False. + path (str, optional): Path to the pretrained weights. Default is None. + + Returns: + ResNet: A ResNet-56 model. + """ + logging.info("Loading pretrained model from: " + str(path)) + model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs) + if models_pretrained: + model.load_state_dict(_load_models_pretrained_weights(path)) + return model + +def _load_models_pretrained_weights(path): + """ + Loads pretrained weights from a checkpoint. + + Args: + path (str): Path to the checkpoint file. + + Returns: + dict: State dictionary with the loaded weights. + """ + checkpoint = torch.load(path, map_location=torch.device('cpu')) + state_dict = checkpoint['state_dict'] + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in state_dict.items(): + new_state_dict[k.replace("module.", "")] = v + + return new_state_dict diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py new file mode 100644 index 0000000..cc340b8 --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +__all__ = ['ResNet'] + +# Function to define a 3x3 convolution layer with padding +def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function to define a 1x1 convolution layer +def apply_1x1_convolution(in_channels, out_channels, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + +# BasicBlock class for ResNet architecture +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer is BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + # First convolution and batch normalization layer + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) # ReLU activation + # Second convolution and batch normalization layer + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample # If downsample is provided, use it + + def forward(self, x): + identity = x # Keep original input as identity for residual connection + + # Forward pass through first convolution, batch norm, and ReLU + out = self.conv1(x) + out = self.bn1(out) + 
out = self.relu(out) + + # Forward pass through second convolution and batch norm + out = self.conv2(out) + out = self.bn2(out) + + # Downsample the identity if downsample is provided + if self.downsample is not None: + identity = self.downsample(x) + + # Add residual connection (identity) + out += identity + out = self.relu(out) # Apply ReLU activation after addition + + return out + +# Bottleneck class for deeper ResNet architectures +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer is BatchNorm2d + width = int(out_channels * (base_width / 64.)) * groups # Calculate width based on group size + + # First 1x1 convolution + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + # Second 3x3 convolution + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + # Third 1x1 convolution to match output channels + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Downsample if provided + + def forward(self, x): + identity = x # Keep original input as identity for residual connection + + # First 1x1 convolution and ReLU + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # Second 3x3 convolution and ReLU + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + # Third 1x1 convolution + out = self.conv3(out) + out = self.bn3(out) + + # Add downsampled identity if necessary + if self.downsample is not None: + identity = self.downsample(x) + + # Add residual connection (identity) + out += identity + out = self.relu(out) # Apply ReLU activation after addition + + return out + +# ResNet class to build the entire ResNet model +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, + width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer + self._norm_layer = norm_layer + + self.inplanes = 16 # Initial number of channels + self.dilation = 1 # Dilation factor + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] # Default stride behavior + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + + self.groups = groups # Number of groups for convolutions + self.base_width = width_per_group # Base width for groups + + # Initial convolutional layer with 3 input channels (RGB image) + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes) # Batch normalization + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Max pooling layer + self.layer1 = self._create_model_layer(block, 16, layers[0]) # First block layer + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive average pooling + self.fc = nn.Linear(16 * block.expansion, num_classes) # Fully 
connected layer + + self.KD = KD # Knowledge Distillation flag + for m in self.modules(): + # Initialize convolutional weights using He initialization + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # Initialize batch normalization weights + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last batch norm layer if zero_init_residual is True + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Helper function to create layers of blocks + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + # Forward pass of the ResNet model + def forward(self, x): + x = self.conv1(x) # Initial convolution + x = self.bn1(x) # Batch normalization + x = self.relu(x) # ReLU activation + extracted_features = x # Feature extraction point + x = self.layer1(x) # Pass through the first layer + x = self.avgpool(x) # Adaptive average pooling + x_f = x.view(x.size(0), -1) # Flatten the features + logits = self.fc(x_f) # Fully connected layer for classification + return logits, extracted_features # Return logits and extracted features + +# Function to create ResNet-5 model +def resnet5_56(num_classes, models_pretrained=False, path=None, **kwargs): + """Constructs a ResNet-5 model.""" + model = ResNet(BasicBlock, [1, 2, 2], num_classes=num_classes, **kwargs) + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k.replace("module.", "") + new_state_dict[name] = v + + model.load_state_dict(new_state_dict) + return model + +# Function to create ResNet-8 model +def resnet8_56(num_classes, models_pretrained=False, path=None, **kwargs): + """Constructs a ResNet-8 model.""" + model = ResNet(Bottleneck, [2, 2, 2], num_classes=num_classes, **kwargs) + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k.replace("module.", "") + new_state_dict[name] = v + + model.load_state_dict(new_state_dict) + return model diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e21dabd3920b518f627e029190a11b1b4f03fdc4 GIT binary patch literal 6148 zcmeHKF>V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj 
zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) # First 3x3 convolution + self.bn1 = norm_layer(planes) # First batch normalization + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.conv2 = apply_3x3_convolution(planes, planes) # Second 3x3 convolution + self.bn2 = norm_layer(planes) # Second batch normalization + self.downsample = downsample # If there's downsampling (e.g., stride mismatch) + + def forward(self, x): + identity = x # Preserve the input as identity for skip connection + out = self.conv1(x) # Apply the first convolution + out = self.bn1(out) # Apply first batch normalization + out = self.relu(out) # Apply ReLU activation + out = self.conv2(out) # Apply the second convolution + out = self.bn2(out) # Apply second batch normalization + + # If downsample exists, apply it to the identity + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add skip connection + out = self.relu(out) # Final ReLU activation + + return out # Return the result + + +class Bottleneck(nn.Module): + """Bottleneck block for ResNet.""" + + expansion = 4 # Bottleneck expands the channels by a factor of 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups # Width of the block + + # 1x1 convolution (bottleneck) + self.conv1 = apply_1x1_convolution(inplanes, width) + self.bn1 = norm_layer(width) # Batch normalization after 1x1 convolution + # 3x3 convolution (main block) + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) # Batch normalization after 3x3 convolution + # 1x1 convolution (bottleneck exit) + self.conv3 = apply_1x1_convolution(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) # Batch normalization after 1x1 exit + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Downsampling for skip connection, if needed + + def forward(self, x): + identity = x # Store input as identity for the skip connection + out = self.conv1(x) # Apply first 1x1 convolution + out = self.bn1(out) # Apply batch normalization + out = self.relu(out) # Apply ReLU + out = self.conv2(out) # Apply 3x3 convolution + out = self.bn2(out) # Apply batch normalization + out = self.relu(out) # Apply ReLU + out = self.conv3(out) # Apply 1x1 convolution + out = self.bn3(out) # Apply batch normalization + + # If downsample exists, apply it to the identity + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add skip connection + out = self.relu(out) # Final ReLU activation + + return out # Return the result + + +class PrimaryResNetClient(nn.Module): + """Main ResNet model for client.""" + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(PrimaryResNetClient, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling before fully connected 
layer + + # Dictionary to store input channel size based on dataset and split factor + inplanes_dict = { + 'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3}, + 'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4}, + 'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24}, + 'pill_base': {1: 64, 2: 44, 4: 32, 8: 24}, + 'medical_images': {1: 64, 2: 44, 4: 32, 8: 24}, + } + self.inplanes = inplanes_dict[dataset][split_factor] # Set initial input channels + + self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes) # Fully connected layer for classification + + # Initialize all layers + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Optionally initialize the last batch normalization layer to zero + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + """Create a residual layer consisting of several blocks.""" + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), # Adjust input size for downsampling + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) # Add the first block with downsample + self.inplanes = planes * block.expansion # Update inplanes for the next block + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) # Add the remaining blocks + + return nn.Sequential(*layers) # Return the stacked blocks + + def _forward_impl(self, x): + """Implementation of the forward pass.""" + x = self.layer0(x) # Initial layer + extracted_features = x # Save features after the initial layer + x = self.layer1(x) # First layer + x = self.avgpool(x) # Global average pooling + x = torch.flatten(x, 1) # Flatten the features into a 1D tensor + logits = self.fc(x) # Pass through the fully connected layer + return logits, extracted_features # Return logits and extracted features + + def forward(self, x): + """Standard forward method.""" + return self._forward_impl(x) diff --git a/EdgeFLite/fedml_service/data_cleaning/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2dea013af7aec74104ee244974b26cd37e3bd569 GIT binary patch literal 6148 zcmeHK+e!m55Iv&>3;NJUAM*+J4??M5;0LJfDu{~}Ti^5Pp1G`aai67>nUb7jF4@e1 zCK&*-`tFv%0>GS2QFJgNs-84mg!7Cjw#FBhxa;=YezTic=oD4D_bY7ihF3h}X#cyO zDKYPnX-IA`elo7-Z^PI?Bm=(;==o6C6l-8}w4;Nal>kJ2Mw@V~wS>l$z#7;bX`zUR5pldiQ+@z%R})Z_I0jb%pFo4vrh(+fiVM@K3(Yjf6u?nWRX7($yqXx44fGQ zGFv~aS4=AI)<5a#U0c|m*i_W7>43uV=n=q)o+GzrbbHY;{c2!y)Kye(;lz9h7$Mmu I1OLFl4<%G6Qvd(} literal 0 HcmV?d00001 diff --git 
a/EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e21dabd3920b518f627e029190a11b1b4f03fdc4 GIT binary patch literal 6148 zcmeHKF>V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4&C`P}G|{ZI9(%oG z%Tv643&7UL;TBi{SkfKw<-^>3-+g8m6)_^6XMErZ2W&CmVV3=Sz`56W4bFDLf5STt z!{^KHdB5LznoJ5v0VyB_q<|DSfdW-jmsck^OGiinDR5p2`1hgF9ed%J7@rOf(E<<` z42N+Zy#%p&fY=MiL`G~;h30oST!G3H(PZm7Ps^K7U{5_s8I?? zfvEzQxmhNmW~CkMmmF=#O;I`KG&Z9uy8yj)_r@ hx$ts)6G@rZe9rw|I3@<2@t_m+GvK<&q`-eG@B{tD7N!6I literal 0 HcmV?d00001 diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store b/EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e21dabd3920b518f627e029190a11b1b4f03fdc4 GIT binary patch literal 6148 zcmeHKF>V4u473vpqBN8#_Y3*K3XvD&0ha)gQXo-8^jGmNp2m#rBIrU#8Z?%?v+MQj zYNt4#&CHjF?W@_s%;s>S9XU*m`}B!DRK$UBoUyTQ`?A4??aw6p?*O?^WPtnmo4?tu z_xl~UWl}&2NC7Dz1*E_Y6{zAmf4 0 else 0 + + +def compute_accuracy(output, target, topk=(1,)): + """Compute the precision@k for the specified values of k.""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, dim=1, largest=True, sorted=True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk] + + +class KLDivergenceLoss(nn.Module): + """Kullback-Leibler Divergence Loss.""" + def __init__(self, temperature=1): + super(KLDivergenceLoss, self).__init__() + self.temperature = temperature + + def forward(self, output_batch, teacher_outputs): + output_batch = F.log_softmax(output_batch / self.temperature, dim=1) + teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7 + return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs) + + +class CELoss(nn.Module): + """Cross-Entropy Loss.""" + def __init__(self, temperature=1): + super(CELoss, self).__init__() + self.temperature = temperature + + def forward(self, output_batch, teacher_outputs): + output_batch = F.log_softmax(output_batch / self.temperature, dim=1) + teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0) + + +def save_dict_to_json(data, json_path): + """Save a dictionary of floats to a JSON file.""" + with open(json_path, 'w') as f: + json.dump({k: float(v) for k, v in data.items()}, f, indent=4) + + +def get_optimized_params(model, model_params, master_params): + """Filter out batch norm parameters from weight decay to improve accuracy.""" + bn_params, remaining_params = split_bn_params(model, model_params, master_params) + return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}] + + +def split_bn_params(model, model_params, 
master_params): + """Split parameters into batch norm and non-batch norm.""" + def get_bn_params(module): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + return set(module.parameters()) + return {p for child in module.children() for p in get_bn_params(child)} + + mod_bn_params = get_bn_params(model) + zipped_params = zip(model_params, master_params) + + mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params] + mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params] + + return mas_bn_params, mas_rem_params diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py b/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py new file mode 100644 index 0000000..c066bf4 --- /dev/null +++ b/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import os +import shutil + +import torch +from torch import nn, optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +from utils import metric +from fedml_service.decentralized.federated_gkt import utils + +# List to store filenames of saved checkpoints +saved_ckpt_filenames = [] + +class GKTServerTrainer: + def __init__(self, client_num, device, server_model, args, writer): + # Initialize the trainer with the number of clients, device (CPU/GPU), global server model, training arguments, and a writer for logging + self.client_num = client_num + self.device = device + self.args = args + self.writer = writer + + """ + Notes: Using data parallelism requires adjusting the batch size accordingly. + For example, with a single GPU (batch_size = 64), an epoch takes 1:03; + using 4 GPUs (batch_size = 256), it takes 38 seconds, and with 4 GPUs (batch_size = 64), it takes 1:00. + If batch size is not adjusted, the communication between CPU and GPU may slow down training. 
+ """ + + # Server model setup + self.model_global = server_model + self.model_global.train() # Set model to training mode + self.model_global.to(self.device) # Move model to the specified device (CPU or GPU) + + # Model parameters for optimization + self.model_params = self.master_params = self.model_global.parameters() + optim_params = self.master_params + + # Choose optimizer based on arguments (SGD or Adam) + if self.args.optimizer == "SGD": + self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd) + elif self.args.optimizer == "Adam": + self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True) + + # Learning rate scheduler to reduce the learning rate when the accuracy plateaus + self.scheduler = ReduceLROnPlateau(self.optimizer, 'max') + + # Loss functions: CrossEntropy for classification, KL for knowledge distillation + self.criterion_CE = nn.CrossEntropyLoss() + self.criterion_KL = utils.KL_Loss(self.args.temperature) + + # Best accuracy tracking + self.best_acc = 0.0 + + # Client data dictionaries to store features, logits, and labels + self.client_extracted_feature_dict = {} + self.client_logits_dict = {} + self.client_labels_dict = {} + self.server_logits_dict = {} + + # Testing data dictionaries + self.client_extracted_feature_dict_test = {} + self.client_labels_dict_test = {} + + # Miscellaneous dictionaries to store model info, sample numbers, training accuracy, and loss + self.model_dict = {} + self.sample_num_dict = {} + self.train_acc_dict = {} + self.train_loss_dict = {} + self.test_acc_avg = 0.0 + self.test_loss_avg = 0.0 + + # Dictionary to track if the client model has been uploaded + self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)} + + # Add results from a local client model after training + def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict, + extracted_feature_dict_test, labels_dict_test): + logging.info(f"Adding model for client index = {index}") + self.client_extracted_feature_dict[index] = extracted_feature_dict + self.client_logits_dict[index] = logits_dict + self.client_labels_dict[index] = labels_dict + self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test + self.client_labels_dict_test[index] = labels_dict_test + self.flag_client_model_uploaded_dict[index] = True + + # Check if all clients have uploaded their models + def check_whether_all_receive(self): + if all(self.flag_client_model_uploaded_dict.values()): + self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)} + return True + return False + + # Get logits from the global model for a specific client + def get_global_logits(self, client_index): + return self.server_logits_dict.get(client_index) + + # Main training function based on the round index + def train(self, round_idx): + if self.args.sweep == 1: # Sweep mode + self.sweep(round_idx) + else: # Normal training process + if self.args.whether_training_on_client == 1: # Check if training occurs on client + self.train_and_distill_on_client(round_idx) + else: # No training on client, just evaluate + self.do_not_train_on_client(round_idx) + + # Training and knowledge distillation on client side + def train_and_distill_on_client(self, round_idx): + # Set the number of server epochs (based on testing mode) + epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0] + self.train_and_eval(round_idx, 
epochs_server, self.writer, self.args) # Train and evaluate + self.scheduler.step(self.best_acc, epoch=round_idx) # Update learning rate scheduler + + # Skip client-side training + def do_not_train_on_client(self, round_idx): + self.train_and_eval(round_idx, 1) + self.scheduler.step(self.best_acc, epoch=round_idx) + + # Training with sweeping strategy + def sweep(self, round_idx): + self.train_and_eval(round_idx, self.args.epochs_server) + self.scheduler.step(self.best_acc, epoch=round_idx) + + # Strategy for determining the number of epochs (used in testing) + def get_server_epoch_strategy_test(self): + return 1, True + + # Different strategies for determining the number of epochs based on training round + def get_server_epoch_strategy_reset56(self, round_idx): + epochs = 20 if round_idx < 20 else 15 if round_idx < 30 else 10 if round_idx < 40 else 5 if round_idx < 50 else 3 if round_idx < 150 else 1 + whether_distill_back = round_idx < 150 + return epochs, whether_distill_back + + # Another variant of epoch strategy + def get_server_epoch_strategy_reset56_2(self, round_idx): + return self.args.epochs_server, True + + # Main training and evaluation loop + def train_and_eval(self, round_idx, epochs, val_writer, args): + for epoch in range(epochs): + logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}") + train_metrics = self.train_large_model_on_the_server() # Training step + + if epoch == epochs - 1: + # Log metrics for the final epoch + val_writer.add_scalar('average training loss', train_metrics['train_loss'], global_step=round_idx) + test_metrics = self.eval_large_model_on_the_server() # Evaluation step + test_acc = test_metrics['test_accTop1'] + + val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx) + val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx) + + # Save best accuracy model + if test_acc >= self.best_acc: + logging.info("- Found better accuracy") + self.best_acc = test_acc + + val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx) + + # Save model checkpoints + if args.save_weight: + filename = f"checkpoint_{round_idx}.pth.tar" + saved_ckpt_filenames.append(filename) + if len(saved_ckpt_filenames) > args.max_ckpt_nums: + os.remove(os.path.join(args.model_dir, saved_ckpt_filenames.pop(0))) + + ckpt_dict = { + 'round': round_idx + 1, + 'arch': args.arch, + 'state_dict': self.model_global.state_dict(), + 'best_acc1': self.best_acc, + 'optimizer': self.optimizer.state_dict(), + } + metric.save_checkpoint(ckpt_dict, test_acc >= self.best_acc, args.model_dir, filename=filename) + + # Print metrics for the current round + print(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | Test Loss: {test_metrics['test_loss']:.3g} | Test Acc: {test_metrics['test_accTop1']:.3f}") + + # Function to train the model on the server side + def train_large_model_on_the_server(self): + # Clear the logits dictionary and set model to training mode + self.server_logits_dict.clear() + self.model_global.train() + + # Track loss and accuracy + loss_avg = utils.RollingAverage() + accTop1_avg = utils.RollingAverage() + accTop5_avg = utils.RollingAverage() + + # Iterate over clients' extracted features + for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items(): + logits_dict = self.client_logits_dict[client_index] + labels_dict = self.client_labels_dict[client_index] + + s_logits_dict = {} + self.server_logits_dict[client_index] = s_logits_dict + + # 
Iterate over batches of features for each client + for batch_index, batch_feature_map_x in extracted_feature_dict.items(): + batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device) + batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device) + batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device) + + # Forward pass + output_batch = self.model_global(batch_feature_map_x) + + # Knowledge distillation loss + if self.args.whether_distill_on_the_server == 1: + loss_kd = self.criterion_KL(output_batch, batch_logits).to(self.device) + loss_true = self.criterion_CE(output_batch, batch_labels).to(self.device) + loss = loss_kd + self.args.alpha * loss_true + else: + # Standard cross-entropy loss + loss = self.criterion_CE(output_batch, batch_labels).to(self.device) + + # Backward pass and optimization + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Compute accuracy metrics + metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5)) + accTop1_avg.update(metrics[0].item()) + accTop5_avg.update(metrics[1].item()) + loss_avg.update(loss.item()) + + # Store logits for the batch + s_logits_dict[batch_index] = output_batch.cpu().detach().numpy() + + # Aggregate and log training metrics + train_metrics = {'train_loss': loss_avg.value(), + 'train_accTop1': accTop1_avg.value(), + 'train_accTop5': accTop5_avg.value()} + logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}") + return train_metrics + + # Function to evaluate the model on the server side + def eval_large_model_on_the_server(self): + # Set model to evaluation mode + self.model_global.eval() + loss_avg = utils.RollingAverage() + accTop1_avg = utils.RollingAverage() + accTop5_avg = utils.RollingAverage() + + # Disable gradient computation for evaluation + with torch.no_grad(): + # Iterate over clients' extracted features for testing + for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items(): + labels_dict = self.client_labels_dict_test[client_index] + + # Iterate over batches for each client + for batch_index, batch_feature_map_x in extracted_feature_dict.items(): + batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device) + batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device) + + # Forward pass + output_batch = self.model_global(batch_feature_map_x) + loss = self.criterion_CE(output_batch, batch_labels) + + # Compute accuracy metrics + metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5)) + accTop1_avg.update(metrics[0].item()) + accTop5_avg.update(metrics[1].item()) + loss_avg.update(loss.item()) + + # Aggregate and log test metrics + test_metrics = {'test_loss': loss_avg.value(), + 'test_accTop1': accTop1_avg.value(), + 'test_accTop5': accTop5_avg.value()} + logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}") + return test_metrics diff --git a/EdgeFLite/helpers/.DS_Store b/EdgeFLite/helpers/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e8a5f766af18690798d77e802a85876f1da03fbf GIT binary patch literal 6148 zcmeHKF>V4u473vpqBN8#_Y3@Bh2RA|fDnk30*M0AU&Xufw9MF61Uk}0N)wGGdv<-E zTkRC*vzhtgyq}xR%xnrL+J#|kd`=(PLuDKY$2)fRVehuF&Hgm1zMmlXi41VJ*Zke# z?RY%gCn+lhq<|EV0#ZN<{8|C_UfA?GQBev=0V(jU0KX3nPV9w4VthIfVgvwAkPgE- zW(irB!C&F5`GnK34hllFV_>vnY}k%p cqA2Sc*LdCwhs2;GA9SF82B?cn3jDPKUtz@+9smFU literal 0 HcmV?d00001 diff --git 
a/EdgeFLite/helpers/evaluation_metrics.py b/EdgeFLite/helpers/evaluation_metrics.py new file mode 100644 index 0000000..51febc6 --- /dev/null +++ b/EdgeFLite/helpers/evaluation_metrics.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import shutil +import torch + +def store_model(state, best_model, directory, filename='checkpoint.pth'): + """ + Stores the model checkpoint in the specified directory. If it's the best model, + it saves another copy named 'best_model.pth'. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + filename (str): Name of the file to save the checkpoint (default 'checkpoint.pth'). + """ + save_path = os.path.join(directory, filename) + torch.save(state, save_path) + if best_model: + # If the current model is the best, save another copy as 'best_model.pth' + shutil.copy(save_path, os.path.join(directory, 'best_model.pth')) + +def save_main_client_model(state, best_model, directory): + """ + Saves the model for the main client if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best main client model") + torch.save(state, os.path.join(directory, 'main_client_best.pth')) + +def save_proxy_clients_model(state, best_model, directory): + """ + Saves the model for proxy clients if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best proxy client model") + torch.save(state, os.path.join(directory, 'proxy_clients_best.pth')) + +def save_individual_client_model(state, best_model, directory): + """ + Saves the model for individual clients if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best client model") + torch.save(state, os.path.join(directory, 'client_best.pth')) + +def save_server_model(state, best_model, directory): + """ + Saves the model for the server if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best server model") + torch.save(state, os.path.join(directory, 'server_best.pth')) + +class MetricTracker(object): + """ + A helper class to track and compute the average of a given metric. + + Args: + metric_name (str): Name of the metric to track. + fmt (str): Format for printing metric values (default ':f'). + """ + def __init__(self, metric_name, fmt=':f'): + self.metric_name = metric_name + self.fmt = fmt + self.reset() + + def reset(self): + """Resets all metric counters.""" + self.current_value = 0 + self.total_sum = 0 + self.count = 0 + self.average = 0 + + def update(self, value, n=1): + """ + Updates the metric value. + + Args: + value (float): New value of the metric. + n (int): Weight or count for the value (default 1). 
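+
+        Example (illustrative): tracker.update(0.73, n=batch_size) records a batch whose
+        mean metric value is 0.73 and weights the running average by the batch size.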
+ """ + self.current_value = value + self.total_sum += value * n + self.count += n + self.average = self.total_sum / self.count + + def __str__(self): + """Returns the formatted metric string showing current value and average.""" + return f'{self.metric_name} {self.current_value{self.fmt}} ({self.average{self.fmt}})' + +class ProgressLogger(object): + """ + A class to log and display the progress of training/testing over multiple batches. + + Args: + total_batches (int): Total number of batches. + *metrics (MetricTracker): Metrics to log during the process. + prefix (str): Prefix for the progress log (default "Progress:"). + """ + def __init__(self, total_batches, *metrics, prefix="Progress:"): + self.batch_format = self._get_batch_format(total_batches) + self.metrics = metrics + self.prefix = prefix + + def log(self, batch_idx): + """ + Logs the current progress of training/testing. + + Args: + batch_idx (int): The current batch index. + """ + output = [self.prefix + self.batch_format.format(batch_idx)] + output += [str(metric) for metric in self.metrics] + print(' | '.join(output)) + + def _get_batch_format(self, total_batches): + """Creates a format string to display the batch index.""" + num_digits = len(str(total_batches)) + return '[{:' + str(num_digits) + 'd}/{}]'.format(total_batches) + +def compute_accuracy(prediction, target, top_k=(1,)): + """ + Computes the accuracy for the top-k predictions. + + Args: + prediction (Tensor): Model predictions. + target (Tensor): Ground truth labels. + top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)). + + Returns: + List[Tensor]: List of accuracies for each top-k value. + """ + with torch.no_grad(): + max_k = max(top_k) + batch_size = target.size(0) + + # Get the top-k predictions + _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True) + top_predictions = top_predictions.t() + + # Compare top-k predictions with targets + correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions)) + + accuracy_results = [] + for k in top_k: + # Count the number of correct predictions within the top-k + correct_k = correct_predictions[:k].view(-1).float().sum(0, keepdim=True) + accuracy_results.append(correct_k.mul_(100.0 / batch_size)) + return accuracy_results + +def count_model_parameters(model, trainable_only=False): + """ + Counts the total number of parameters in the model. + + Args: + model (nn.Module): The PyTorch model. + trainable_only (bool): Whether to count only trainable parameters (default False). + + Returns: + int: Total number of parameters in the model. + """ + if trainable_only: + # Count only the parameters that require gradients (trainable parameters) + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + # Count all parameters (trainable and non-trainable) + return sum(p.numel() for p in model.parameters()) diff --git a/EdgeFLite/helpers/normalization.py b/EdgeFLite/helpers/normalization.py new file mode 100644 index 0000000..c5f238f --- /dev/null +++ b/EdgeFLite/helpers/normalization.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + + +class PassThrough(nn.Module): + """ + A placeholder module that simply returns the input tensor unchanged. 
+ """ + def __init__(self, **kwargs): + super(PassThrough, self).__init__() + + def forward(self, input_tensor): + return input_tensor + + +class LayerNormalization2D(nn.Module): + """ + A custom layer normalization module for 2D inputs (typically used for + convolutional layers). It optionally applies learned scaling (weight) + and shifting (bias) parameters. + + Arguments: + epsilon: A small value to avoid division by zero. + use_weight: Whether to learn and apply weight parameters. + use_bias: Whether to learn and apply bias parameters. + """ + def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs): + super(LayerNormalization2D, self).__init__() + + self.epsilon = epsilon + self.use_weight = use_weight + self.use_bias = use_bias + + def forward(self, input_tensor): + # Initialize weight and bias parameters if they are not nn.Parameter instances + if (not isinstance(self.use_weight, nn.parameter.Parameter) and + not isinstance(self.use_bias, nn.parameter.Parameter) and + (self.use_weight or self.use_bias)): + self._initialize_parameters(input_tensor) + + # Apply layer normalization + return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:], + weight=self.use_weight, bias=self.use_bias, + eps=self.epsilon) + + def _initialize_parameters(self, input_tensor): + """ + Initialize weight and bias parameters for layer normalization. + Arguments: + input_tensor: The input tensor to the normalization layer. + """ + channels, height, width = input_tensor.shape[1:] + param_shape = [channels, height, width] + + # Initialize weight parameter if applicable + if self.use_weight: + self.use_weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + else: + self.register_parameter('use_weight', None) + + # Initialize bias parameter if applicable + if self.use_bias: + self.use_bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + else: + self.register_parameter('use_bias', None) + + +class NormalizationLayer(nn.Module): + """ + A flexible normalization layer that supports different types of normalization + (batch, group, layer, instance, or none). This class is a wrapper that selects + the appropriate normalization technique based on the norm_type argument. + + Arguments: + norm_type: The type of normalization to apply ('batch', 'group', 'layer', 'instance', or 'none'). + epsilon: A small value to avoid division by zero (Default: 1e-05). + momentum: Momentum for updating running statistics (Default: 0.1, applicable for batch norm). + use_weight: Whether to learn weight parameters (Default: True). + use_bias: Whether to learn bias parameters (Default: True). + track_stats: Whether to track running statistics (Default: True, applicable for batch norm). + group_norm_groups: Number of groups to use for group normalization (Default: 32). + """ + def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1, + use_weight=True, use_bias=True, track_stats=True, group_norm_groups=32, **kwargs): + super(NormalizationLayer, self).__init__() + + if norm_type not in ['batch', 'group', 'layer', 'instance', 'none']: + raise ValueError('Unsupported norm_type: {}. 
Supported options: ' + '"batch" | "group" | "layer" | "instance" | "none".'.format(norm_type)) + + self.norm_type = norm_type + self.epsilon = epsilon + self.momentum = momentum + self.use_weight = use_weight + self.use_bias = use_bias + self.affine = self.use_weight and self.use_bias # Check if affine apply_transformationation is needed + self.track_stats = track_stats + self.group_norm_groups = group_norm_groups + + def forward(self, num_features): + """ + Select and apply the appropriate normalization technique based on the norm_type. + + Arguments: + num_features: The number of input channels or features. + Returns: + A normalization layer corresponding to the norm_type. + """ + if self.norm_type == 'batch': + # Apply Batch Normalization + normalizer = nn.BatchNorm2d(num_features=num_features, eps=self.epsilon, + momentum=self.momentum, affine=self.affine, + track_running_stats=self.track_stats) + elif self.norm_type == 'group': + # Apply Group Normalization + normalizer = nn.GroupNorm(self.group_norm_groups, num_features, + eps=self.epsilon, affine=self.affine) + elif self.norm_type == 'layer': + # Apply Layer Normalization + normalizer = LayerNormalization2D(epsilon=self.epsilon, use_weight=self.use_weight, use_bias=self.use_bias) + elif self.norm_type == 'instance': + # Apply Instance Normalization + normalizer = nn.InstanceNorm2d(num_features, eps=self.epsilon, affine=self.affine) + else: + # No normalization applied, just pass the input through + normalizer = PassThrough() + + return normalizer diff --git a/EdgeFLite/helpers/optimizer_rmsprop.py b/EdgeFLite/helpers/optimizer_rmsprop.py new file mode 100644 index 0000000..52792ca --- /dev/null +++ b/EdgeFLite/helpers/optimizer_rmsprop.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +from torch.optim import Optimizer + + +class CustomRMSprop(Optimizer): + """ + Implements a modified version of the RMSprop algorithm with TensorFlow-style epsilon handling. + + Main differences in this implementation: + 1. Epsilon is incorporated within the square root operation. + 2. The moving average of squared gradients is initialized to 1. + 3. The momentum buffer accumulates updates scaled by the learning rate. + """ + + def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0, centered=False, decoupled_decay=False, lr_in_momentum=True): + """ + Initializes the optimizer with the provided parameters. 
+ + Arguments: + - params: iterable of parameters to optimize or dicts defining parameter groups + - lr: learning rate (default: 0.01) + - alpha: smoothing constant for the moving average (default: 0.99) + - eps: small value to prevent division by zero (default: 1e-8) + - momentum: momentum factor (default: 0) + - weight_decay: weight decay (L2 penalty) (default: 0) + - centered: if True, compute centered RMSprop (default: False) + - decoupled_decay: if True, decouples weight decay from gradient update (default: False) + - lr_in_momentum: if True, applies learning rate within the momentum buffer (default: True) + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if eps < 0.0: + raise ValueError(f"Invalid epsilon value: {eps}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight decay: {weight_decay}") + if alpha < 0.0: + raise ValueError(f"Invalid alpha value: {alpha}") + + # Store the optimizer defaults + defaults = { + 'lr': lr, + 'alpha': alpha, + 'eps': eps, + 'momentum': momentum, + 'centered': centered, + 'weight_decay': weight_decay, + 'decoupled_decay': decoupled_decay, + 'lr_in_momentum': lr_in_momentum + } + super().__init__(params, defaults) + + def step(self, closure=None): + """ + Performs a single optimization step. + + Arguments: + - closure: A closure that reevaluates the model and returns the loss. + """ + # Get the loss value if a closure is provided + loss = closure() if closure is not None else None + + # Iterate over parameter groups + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + weight_decay = group['weight_decay'] + alpha = group['alpha'] + eps = group['eps'] + + # Iterate over parameters in the group + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data # Get gradient data + if grad.is_sparse: + raise RuntimeError("RMSprop does not support sparse gradients.") + + # Get the state of the parameter + state = self.state[p] + + # Initialize state if it doesn't exist + if not state: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # Initialize moving average of squared gradients to 1 + if momentum > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) # Initialize momentum buffer + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) # Initialize moving average of gradients if centered + + square_avg = state['square_avg'] + one_minus_alpha = 1 - alpha + state['step'] += 1 # Update the step count + + # Apply weight decay + if weight_decay != 0: + if group['decoupled_decay']: + p.data.mul_(1 - lr * weight_decay) # Apply decoupled weight decay + else: + grad.add_(p.data, alpha=weight_decay) # Apply traditional weight decay + + # Update the moving average of squared gradients + square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha) + + # Compute the denominator for gradient update + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) + avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_() # Centered RMSprop + else: + avg = square_avg.add_(eps).sqrt_() # Standard RMSprop + + # Apply momentum if needed + if momentum > 0: + buf = state['momentum_buffer'] + if group['lr_in_momentum']: + buf.mul_(momentum).addcdiv_(grad, avg, value=lr) # Apply learning rate inside momentum buffer + p.data.add_(-buf) + else: + buf.mul_(momentum).addcdiv_(grad, avg) # Standard momentum update + p.data.add_(buf, alpha=-lr) + 
else: + p.data.addcdiv_(grad, avg, value=-lr) # Update parameter without momentum + + return loss # Return the loss if closure was provided diff --git a/EdgeFLite/helpers/pace_controller.py b/EdgeFLite/helpers/pace_controller.py new file mode 100644 index 0000000..1e84654 --- /dev/null +++ b/EdgeFLite/helpers/pace_controller.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import math + +class CustomScheduler: + def __init__(self, mode='cosine', + initial_lr=0.1, + num_epochs=100, + iters_per_epoch=300, + lr_milestones=None, + lr_step=100, + step_multiplier=0.1, + slow_start_epochs=0, + slow_start_lr=1e-4, + min_lr=1e-3, + multiplier=1.0, + lower_bound=-6.0, + upper_bound=3.0, + decay_factor=0.97, + decay_epochs=0.8, + staircase=True): + """ + Initialize the learning rate scheduler. + + Parameters: + mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential'). + initial_lr (float): Initial learning rate. + num_epochs (int): Total number of epochs. + iters_per_epoch (int): Number of iterations per epoch. + lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode. + lr_step (int): Epoch step size for learning rate reduction in 'step' mode. + step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode. + slow_start_epochs (int): Number of slow start epochs for warm-up. + slow_start_lr (float): Learning rate during warm-up. + min_lr (float): Minimum learning rate limit. + multiplier (float): Multiplication factor for applying to different parameter groups. + lower_bound (float): Lower bound for the tanh function in 'HTD' mode. + upper_bound (float): Upper bound for the tanh function in 'HTD' mode. + decay_factor (float): Factor by which learning rate decays in 'exponential' mode. + decay_epochs (float): Number of epochs over which learning rate decays in 'exponential' mode. + staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode. + """ + # Ensure valid mode selection + assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode." 
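+        # Illustrative usage sketch (not executed here): the scheduler is intended to be
+        # driven once per training iteration. The names `optimizer` and `train_loader`
+        # below are assumptions for the example, not attributes of this class.
+        #
+        #   scheduler = CustomScheduler(mode='cosine', initial_lr=0.1, num_epochs=100,
+        #                               iters_per_epoch=len(train_loader),
+        #                               slow_start_epochs=5, slow_start_lr=1e-4)
+        #   for epoch in range(100):
+        #       for iteration, (images, targets) in enumerate(train_loader):
+        #           scheduler.update_lr(optimizer, iteration, epoch)
+        #           ...  # forward, backward, optimizer.step()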
+ + # Initialize learning rate settings + self.initial_lr = initial_lr + self.current_lr = initial_lr + self.min_lr = min_lr + self.mode = mode + self.num_epochs = num_epochs + self.iters_per_epoch = iters_per_epoch + self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch + self.slow_start_iters = slow_start_epochs * iters_per_epoch + self.slow_start_lr = slow_start_lr + self.multiplier = multiplier + self.lr_step = lr_step + self.lr_milestones = lr_milestones + self.step_multiplier = step_multiplier + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.decay_factor = decay_factor + self.decay_steps = decay_epochs * iters_per_epoch + self.staircase = staircase + + print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.") + + def update_lr(self, optimizer, iteration, epoch): + """Update the learning rate based on the current iteration and epoch.""" + current_iter = epoch * self.iters_per_epoch + iteration + + # During slow start, linearly increase the learning rate + if current_iter <= self.slow_start_iters: + lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters) + else: + # After slow start, calculate learning rate based on the selected mode + lr = self._calculate_lr(current_iter - self.slow_start_iters) + + # Ensure learning rate does not fall below the minimum limit + self.current_lr = max(lr, self.min_lr) + self._apply_lr(optimizer, self.current_lr) + + def _calculate_lr(self, adjusted_iter): + """Calculate the learning rate based on the selected scheduling mode.""" + if self.mode == 'cosine': + # Cosine annealing schedule + return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations)) + elif self.mode == 'poly': + # Polynomial decay schedule + return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9 + elif self.mode == 'HTD': + # Hyperbolic tangent decay schedule + ratio = adjusted_iter / self.total_iterations + return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio)) + elif self.mode == 'step': + # Step decay schedule + return self._step_lr(adjusted_iter) + elif self.mode == 'exponential': + # Exponential decay schedule + power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps + return self.initial_lr * (self.decay_factor ** power) + else: + raise NotImplementedError("Unknown learning rate mode.") + + def _step_lr(self, adjusted_iter): + """Calculate the learning rate for the 'step' mode.""" + epoch = adjusted_iter // self.iters_per_epoch + # Count how many milestones or steps have passed + if self.lr_milestones: + num_steps = sum([1 for milestone in self.lr_milestones if epoch >= milestone]) + else: + num_steps = epoch // self.lr_step + return self.initial_lr * (self.step_multiplier ** num_steps) + + def _apply_lr(self, optimizer, lr): + """Apply the calculated learning rate to the optimizer.""" + for i, param_group in enumerate(optimizer.param_groups): + # Apply multiplier to parameter groups beyond the first one + param_group['lr'] = lr * (self.multiplier if i > 1 else 1.0) + + +def adjust_hyperparameters(args): + """Adjust the learning rate and momentum based on the batch size.""" + print(f'Adjusting LR and momentum. 
Original LR: {args.lr}, Original momentum: {args.momentum}') + # Set standard batch size for scaling + standard_batch_size = 128 if 'cifar' in args.dataset else NotImplementedError + # Scale momentum and learning rate + args.momentum = args.momentum ** (args.batch_size / standard_batch_size) + args.lr *= (args.batch_size / standard_batch_size) + print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}') + return args + + +def separate_parameters(model, weight_decay_for_norm=0): + """Separate the model parameters into two groups: regular parameters and norm-based parameters.""" + regular_params, norm_params = [], [] + for name, param in model.named_parameters(): + if param.requires_grad: + # Parameters related to normalization and biases are treated separately + if 'norm' in name or 'bias' in name: + norm_params.append(param) + else: + regular_params.append(param) + # Return parameter groups with corresponding weight decay for norm parameters + return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}] diff --git a/EdgeFLite/helpers/preloader_module.py b/EdgeFLite/helpers/preloader_module.py new file mode 100644 index 0000000..e771a80 --- /dev/null +++ b/EdgeFLite/helpers/preloader_module.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch + +class DataPrefetcher: + def __init__(self, dataloader): + # Initialize with the dataloader and create an iterator + self.dataloader = iter(dataloader) + # Create a CUDA stream for asynchronous data transfer + self.cuda_stream = torch.cuda.Stream() + # Load the next batch of data + self._load_next_batch() + + def _load_next_batch(self): + try: + # Fetch the next batch from the dataloader iterator + self.batch_input, self.batch_target = next(self.dataloader) + except StopIteration: + # If no more data, set inputs and targets to None + self.batch_input, self.batch_target = None, None + return + + # Transfer data to GPU asynchronously using the created CUDA stream + with torch.cuda.stream(self.cuda_stream): + self.batch_input = self.batch_input.cuda(non_blocking=True) + self.batch_target = self.batch_target.cuda(non_blocking=True) + + def get_next_batch(self): + # Synchronize the current stream with the prefetching stream to ensure data is ready + torch.cuda.current_stream().wait_stream(self.cuda_stream) + + # Return the preloaded batch of input and target data + current_input, current_target = self.batch_input, self.batch_target + + # Preload the next batch in the background while the current batch is processed + self._load_next_batch() + + return current_input, current_target diff --git a/EdgeFLite/helpers/report_summary.py b/EdgeFLite/helpers/report_summary.py new file mode 100644 index 0000000..e186210 --- /dev/null +++ b/EdgeFLite/helpers/report_summary.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +__all__ = ['model_summary'] + +import torch +import torch.nn as nn +import numpy as np +import os +import json +from collections import OrderedDict + +# Format FLOPs value with appropriate unit (T, G, M, K) +def format_flops(flops): + units = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')] + for scale, suffix in units: + if flops >= scale: + return f"{flops / scale:.1f}{suffix}" + return f"{flops:.1f}" + +# Calculate the number of trainable or non-trainable parameters +def calculate_grad_params(param_count, param): + if param.requires_grad: + return param_count, 0 + else: + return 0, param_count + +# Compute FLOPs and parameters for a convolutional 
layer +def compute_conv_flops(layer, input, output): + oh, ow = output.shape[-2:] # Output height and width + kh, kw = layer.kernel_size # Kernel height and width + ic, oc = layer.in_channels, layer.out_channels # Input/output channels + groups = layer.groups # Number of groups for grouped convolution + + total_trainable = 0 + total_non_trainable = 0 + flops = 0 + + # Compute parameters and FLOPs for the weight + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // groups) + + # Compute parameters and FLOPs for the bias + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + total_trainable += trainable + total_non_trainable += non_trainable + flops += oh * ow * (oc // groups) + + return total_trainable, total_non_trainable, flops + +# Compute FLOPs and parameters for normalization layers (BatchNorm, GroupNorm) +def compute_norm_flops(layer, input, output): + total_trainable = 0 + total_non_trainable = 0 + + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + total_trainable += trainable + total_non_trainable += non_trainable + + if hasattr(layer, 'running_mean'): + total_non_trainable += np.prod(layer.running_mean.shape) + + if hasattr(layer, 'running_var'): + total_non_trainable += np.prod(layer.running_var.shape) + + # FLOPs for normalization operations + flops = np.prod(input[0].shape) + if layer.affine: + flops *= 2 + + return total_trainable, total_non_trainable, flops + +# Compute FLOPs and parameters for linear (fully connected) layers +def compute_linear_flops(layer, input, output): + ic, oc = layer.in_features, layer.out_features # Input/output features + total_trainable = 0 + total_non_trainable = 0 + flops = 0 + + # Compute parameters and FLOPs for the weight + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + flops += (2 * ic - 1) * oc + + # Compute parameters and FLOPs for the bias + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + total_trainable += trainable + total_non_trainable += non_trainable + flops += oc + + return total_trainable, total_non_trainable, flops + +# Model summary function: calculates the total parameters and FLOPs for a model +@torch.no_grad() +def model_summary(model, input_data, target_data=None, is_coremodel=True, return_data=False): + model.eval() + + summary_info = OrderedDict() + hooks = [] + + # Hook function to register layer and compute its parameters/FLOPs + def register_layer_hook(layer): + def hook(layer, input, output): + layer_name = 
f"{layer.__class__.__name__}-{len(summary_info) + 1}" + summary_info[layer_name] = OrderedDict() + summary_info[layer_name]['input_shape'] = list(input[0].shape) + summary_info[layer_name]['output_shape'] = list(output.shape) if not isinstance(output, (list, tuple)) else [list(o.shape) for o in output] + + if isinstance(layer, nn.Conv2d): + trainable, non_trainable, flops = compute_conv_flops(layer, input, output) + elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)): + trainable, non_trainable, flops = compute_norm_flops(layer, input, output) + elif isinstance(layer, nn.Linear): + trainable, non_trainable, flops = compute_linear_flops(layer, input, output) + else: + trainable, non_trainable, flops = 0, 0, 0 + + summary_info[layer_name]['trainable_params'] = trainable + summary_info[layer_name]['non_trainable_params'] = non_trainable + summary_info[layer_name]['total_params'] = trainable + non_trainable + summary_info[layer_name]['flops'] = flops + + if not isinstance(layer, (nn.Sequential, nn.ModuleList, nn.Identity)): + hooks.append(layer.register_forward_hook(hook)) + + model.apply(register_layer_hook) + + if is_coremodel: + model(input_data, target=target_data, mode='summary') + else: + model(input_data) + + for hook in hooks: + hook.remove() + + total_params, trainable_params, total_flops = 0, 0, 0 + for layer_name, layer_info in summary_info.items(): + total_params += layer_info['total_params'] + trainable_params += layer_info['trainable_params'] + total_flops += layer_info['flops'] + + param_size_mb = total_params * 4 / (1024 ** 2) + print(f"Total parameters: {total_params:,} ({format_flops(total_params)})") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Non-trainable parameters: {total_params - trainable_params:,}") + print(f"Total FLOPs: {total_flops:,} ({format_flops(total_flops)})") + print(f"Model size: {param_size_mb:.2f} MB") + + if return_data: + return total_params, total_flops + +# Example usage with a convolutional layer +if __name__ == '__main__': + conv_layer = nn.Conv2d(50, 10, 3, padding=1, groups=5, bias=True) + model_summary(conv_layer, torch.rand((1, 50, 10, 10)), target_data=torch.ones(1, dtype=torch.long), is_coremodel=False) + + for name, param in conv_layer.named_parameters(): + print(f"{name}: {param.size()}") + +# Save the model's summary details as a JSON file +def save_model_as_json(args, model_content): + """Save the model's details to a JSON file.""" + os.makedirs(args.model_dir, exist_ok=True) + filename = os.path.join(args.model_dir, f"model_{args.split_factor}.txt") + + with open(filename, 'w') as f: + f.write(str(model_content)) diff --git a/EdgeFLite/helpers/smoothing_labels.py b/EdgeFLite/helpers/smoothing_labels.py new file mode 100644 index 0000000..965f358 --- /dev/null +++ b/EdgeFLite/helpers/smoothing_labels.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Define the SmoothEntropyLoss class, which inherits from nn.Module +class SmoothEntropyLoss(nn.Module): + def __init__(self, smoothing=0.1, reduction='mean'): + # Initialize the parent class (nn.Module) and set the smoothing factor and reduction method + super(SmoothEntropyLoss, self).__init__() + self.smoothing = smoothing # Label smoothing factor + self.reduction_method = reduction # Reduction method to apply to the loss + + def forward(self, predictions, targets): + # Ensure that the batch sizes of predictions and targets match + if predictions.shape[0] != targets.shape[0]: + 
raise ValueError(f"Batch size of predictions ({predictions.shape[0]}) does not match targets ({targets.shape[0]}).") + + # Ensure that the predictions tensor has at least 2 dimensions (batch_size x num_classes) + if predictions.dim() < 2: + raise ValueError(f"Predictions should have at least 2 dimensions, got {predictions.dim()}.") + + # Get the number of classes from the last dimension of predictions (num_classes) + num_classes = predictions.size(-1) + + # Convert targets (class indices) to one-hot encoded format + target_one_hot = F.one_hot(targets, num_classes=num_classes).type_as(predictions) + + # Apply label smoothing: smooth the one-hot encoded targets by distributing some probability mass across all classes + smooth_targets = target_one_hot * (1.0 - self.smoothing) + (self.smoothing / num_classes) + + # Compute the log probabilities of predictions using softmax (log-softmax for numerical stability) + log_probabilities = F.log_softmax(predictions, dim=-1) + + # Compute the per-sample loss by multiplying log probabilities with the smoothed targets and summing across classes + loss_per_sample = -torch.sum(log_probabilities * smooth_targets, dim=-1) + + # Apply the specified reduction method to the computed loss + if self.reduction_method == 'none': + return loss_per_sample # Return the unreduced loss for each sample + elif self.reduction_method == 'sum': + return torch.sum(loss_per_sample) # Return the sum of the losses over all samples + elif self.reduction_method == 'mean': + return torch.mean(loss_per_sample) # Return the mean loss over all samples + else: + raise ValueError(f"Invalid reduction option: {self.reduction_method}. Expected 'none', 'sum', or 'mean'.") diff --git a/EdgeFLite/info_map.csv b/EdgeFLite/info_map.csv new file mode 100644 index 0000000..6ecea73 --- /dev/null +++ b/EdgeFLite/info_map.csv @@ -0,0 +1,68 @@ +import pandas as pd +import os +from glob import glob +from PIL import Image +import torch +from sklearn.model_selection import train_test_split +import pickle +from torch.utils.data import Dataset, DataLoader +from torch import nn +from torchvision import apply_transformations + +# Loading the info_mapdata for the skin_dataset dataset +info_mapdata = pd.read_csv('dataset_hub/skin_dataset/data/skin_info_map.csv') +print(info_mapdata.head()) + +# Mapping lesion abbreviations to their full names +lesion_labels = { + 'nv': 'Melanocytic nevi', + 'mel': 'Melanoma', + 'bkl': 'Benign keratosis-like lesions', + 'bcc': 'Basal cell carcinoma', + 'akiec': 'Actinic keratoses', + 'vasc': 'Vascular lesions', + 'df': 'Dermatofibroma' +} + +# Combine images from both dataset parts into one dictionary +image_paths = {os.path.splitext(os.path.basename(img))[0]: img + for img in glob(os.path.join("dataset_hub/skin_dataset/data", '*', '*.jpg'))} + +# Mapping the image paths and cell types to the DataFrame +info_mapdata['image_path'] = info_mapdata['image_id'].map(image_paths.get) +info_mapdata['cell_type'] = info_mapdata['dx'].map(lesion_labels.get) +info_mapdata['label'] = pd.Categorical(info_mapdata['cell_type']).workspaces + +# Display the count of each cell type and their enworkspaced labels +print(info_mapdata['cell_type'].value_counts()) +print(info_mapdata['label'].value_counts()) + +# Custom Dataset class for PyTorch +class SkinDataset(Dataset): + def __init__(self, dataframe, apply_transformation=None): + self.dataframe = dataframe + self.apply_transformation = apply_transformation + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, idx): + img 
= Image.open(self.dataframe.loc[idx, 'image_path']).resize((64, 64)) + label = torch.tensor(self.dataframe.loc[idx, 'label'], dtype=torch.long) + + if self.apply_transformation: + img = self.apply_transformation(img) + + return img, label + +# Splitting the data into train and test sets +train_data, test_data = train_test_split(info_mapdata, test_size=0.2, random_state=42) +train_data = train_data.reset_index(drop=True) +test_data = test_data.reset_index(drop=True) + +# Save the train and test data to pickle files +with open("skin_dataset_train.pkl", "wb") as train_file: + pickle.dump(train_data, train_file) + +with open("skin_dataset_test.pkl", "wb") as test_file: + pickle.dump(test_data, test_file) diff --git a/EdgeFLite/process_data.py b/EdgeFLite/process_data.py new file mode 100644 index 0000000..155e18a --- /dev/null +++ b/EdgeFLite/process_data.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import glob +import numpy as np + +# Define paths to the training and testing datasets +train_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/train_images' +test_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/test_images' + +def list_image_files_by_class(directory): + """ + Returns a list of image file paths and their corresponding class indices. + + Args: + directory (str): The path to the directory containing class folders. + + Returns: + list: A list of image file paths and their class indices. + """ + # Get the sorted list of class labels (folder names) + class_labels = sorted(os.listdir(directory)) + # Create a mapping from class names to indices + class_to_idx = {class_name: idx for idx, class_name in enumerate(class_labels)} + + image_dataset = [] # Initialize an empty list to store image data + + # Iterate through each class + for class_name in class_labels: + class_folder = os.path.join(directory, class_name) # Path to the class folder + # Find all JPG images in the class folder and its subfolders + image_files = glob.glob(os.path.join(class_folder, '**', '*.jpg'), recursive=True) + + # Append image file paths and their class indices to the dataset + for image_file in image_files: + image_dataset.append([image_file, class_to_idx[class_name]]) + + return image_dataset + +if __name__ == "__main__": + # Retrieve and print the number of files in the training and testing datasets + train_images = list_image_files_by_class(train_data_path) + test_images = list_image_files_by_class(test_data_path) + + print(f"Training dataset size: {len(train_images)}") # Output the size of the training dataset + print(f"Testing dataset size: {len(test_images)}") # Output the size of the testing dataset diff --git a/EdgeFLite/resnet_federated.py b/EdgeFLite/resnet_federated.py new file mode 100644 index 0000000..4459937 --- /dev/null +++ b/EdgeFLite/resnet_federated.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import argparse +import os +import torch +from dataset import factory +from params import train_params +from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10 +from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100 +from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset +from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase +from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt +from 
fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer +from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer +from params.train_params import save_hp_to_json +from config import HOME +from tensorboardX import SummaryWriter + +# Set CUDA device to be used for training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +device = torch.device("cuda:0") + +# Initialize TensorBoard writers for logging +def initialize_writers(args): + log_dir = os.path.join(args.model_dir, 'val') # Create a log directory inside the model directory + return SummaryWriter(log_dir=log_dir) # Initialize SummaryWriter for TensorBoard logging + +# Initialize dataset and data loaders +def initialize_dataset(args, data_split_factor): + # Fetch training data and sampler based on various input parameters + train_data_local_dict, train_sampler = factory.obtain_data_loader( + args.data, + split_factor=data_split_factor, + batch_size=args.batch_size, + crop_size=args.crop_size, + dataset=args.dataset, + split="train", # Split data for training + is_decentralized=args.is_decentralized, + is_autoaugment=args.is_autoaugment, + randaa=args.randaa, + is_cutout=args.is_cutout, + erase_p=args.erase_p, + num_workers=args.workers, + is_fed=args.is_fed, + num_clusters=args.num_clusters, + cifar10_non_iid=args.cifar10_non_iid, + cifar100_non_iid=args.cifar100_non_iid + ) + + # Fetch global test data + test_data_global = factory.obtain_data_loader( + args.data, + batch_size=args.eval_batch_size, + crop_size=args.crop_size, + dataset=args.dataset, + split="val", # Split data for validation + num_workers=args.workers, + cifar10_non_iid=args.cifar10_non_iid, + cifar100_non_iid=args.cifar100_non_iid + ) + return train_data_local_dict, test_data_global # Return both train and test data loaders + +# Setup models based on the dataset +def setup_models(args): + if args.dataset == "cifar10": + return load_partition_data_cifar10, wide_resnet16_8_gkt() # Model for CIFAR-10 + elif args.dataset == "cifar100": + return load_partition_data_cifar100, resnet110_gkt() # Model for CIFAR-100 + elif args.dataset == "skin_dataset": + return load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt() # Model for skin dataset + elif args.dataset == "pill_base": + return load_partition_data_pillbase, wide_resnet_model_50_2_gkt() # Model for pill base dataset + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") # Raise error for unsupported dataset + +# Initialize trainers for each client in the federated learning setup +def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict): + client_trainers = [] + # Initialize a trainer for each client + for i in range(client_number): + trainer = GKTTrainer( + client_idx=i, + train_data_local_dict=train_data_local_dict, + test_data_local_dict=test_data_local_dict, + device=device, + model_client=model_client, + args=args + ) + client_trainers.append(trainer) # Add client trainer to the list + return client_trainers + +# Main function to initialize and run the federated learning process +def main(args): + args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid)) # Set model directory based on home directory and spid + + # Save hyperparameters if not in summary or evaluation mode + if not args.is_summary and not args.evaluate: + save_hp_to_json(args) + + # Initialize the TensorBoard writer for logging + val_writer = initialize_writers(args) + data_split_factor = args.loop_factor if args.is_diff_data_train 
else 1 # Set data split factor based on training mode + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized # Check if decentralized learning is needed + + print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'") + + # Load dataset and initialize data loaders + train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor) + + # Setup models for the clients and server + data_loader, (model_client, model_server) = setup_models(args) + client_number = args.num_clusters * args.split_factor # Calculate the number of clients + + # Load data for federated learning + train_data_num, test_data_num, train_data_global, _, _, _, test_data_local_dict, class_num = data_loader( + args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size + ) + + dataset_info = [train_data_num, test_data_num, train_data_global, test_data_global, train_data_local_dict, test_data_local_dict, class_num] + + print("Server and clients initialized.") + round_idx = 0 # Initialize the training round index + + # Initialize client trainers and server trainer + client_trainers = initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict) + server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer) + + # Start federated training rounds + for current_round in range(args.num_rounds): + # For each client, perform local training and send results to the server + for client_idx in range(client_number): + extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train() + print(f"Client {client_idx} finished training.") + server_trainer.add_local_trained_result(client_idx, extracted_features, logits, labels, test_features, test_labels) + + # Check if the server has received all clients' results + if server_trainer.check_whether_all_receive(): + print("All clients' results received by server.") + server_trainer.train(round_idx) # Server performs training using the aggregated results + round_idx += 1 + + # Send global model updates back to clients + for client_idx in range(client_number): + global_logits = server_trainer.get_global_logits(client_idx) + client_trainers[client_idx].update_large_model_logits(global_logits) + print("Server sent updated logits back to clients.") + +# Entry point of the script +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args = train_params.add_parser_params(parser) + + # Ensure that federated learning mode is enabled + assert args.is_fed == 1, "Federated learning requires 'args.is_fed' to be set to 1." 
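+    # Round structure implemented by main(), summarized for quick reference:
+    #   1. every client trains locally and returns (features, logits, labels, test_features, test_labels);
+    #   2. the server collects each result via add_local_trained_result();
+    #   3. once check_whether_all_receive() is True, server_trainer.train(round_idx) distills
+    #      on the aggregated features and logits;
+    #   4. the refreshed global logits are pushed back with update_large_model_logits().
+    # Hypothetical launch (exact flag spellings come from params/train_params.py and are assumptions):
+    #   python resnet_federated.py --is_fed 1 --dataset cifar10 --num_clusters 20 --num_rounds 300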
+ + # Create the model directory if it does not exist + os.makedirs(args.model_dir, exist_ok=True) + + print(args) # Print the parsed arguments for verification + main(args) # Start the main process diff --git a/EdgeFLite/run_federated.py b/EdgeFLite/run_federated.py new file mode 100644 index 0000000..7e83319 --- /dev/null +++ b/EdgeFLite/run_federated.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import torch.decentralized as dist +import torch.multiprocessing as mp +import torch.cuda.amp as amp +from torch.backends import cudnn +from tensorboardX import SummaryWriter +import warnings +import argparse +import os +import numpy as np +from tqdm import tqdm +from dataset import factory +from model import coremodel +from utils import metric, label_smoothing, lr_scheduler, prefetch +from params.train_params import save_hp_to_json +from params import train_params + +# Global variable to track the best accuracy +best_accuracy = 0 + +def calculate_average(values): + """Calculate the average of a list of values""" + return sum(values) / len(values) + +def initialize_processes(rank, world_size, args): + """ + Initialize decentralized processes. + This function is used to set up distributed training across multiple GPUs. + """ + ngpus = torch.cuda.device_count() + args.ngpus = ngpus + args.is_decentralized = world_size > 1 + + if args.multiprocessing_decentralized: + # If running decentralized with multiple GPUs, spawn processes for each GPU + mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args)) + else: + print(f"INFO:PyTorch: Using {ngpus} GPUs") + # If single GPU, start the training worker directly + train_single_worker(args.gpu, ngpus, args) + +def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None): + """ + Perform training for a single client model in the federated learning setup. + This method will train the model for a given number of epochs. + """ + model.train() # Set model to training mode + for epoch in range(epochs): + # Prefetch data to improve efficiency + prefetcher = prefetch.data_prefetcher(dataloader) + images, targets = prefetcher.next() + step = 0 + + while images is not None: + # Update the learning rate using the scheduler + scheduler(optimizer, step, current_round) + optimizer.zero_grad() # Clear the gradients + + # Enable mixed precision training to optimize memory and computation speed + with amp.autocast(enabled=args.is_amp): + outputs, ce_loss, cot_loss = model(images, target=targets, mode='train') + + # Combine losses and normalize by accumulation steps + loss = (ce_loss + cot_loss) / args.accumulation_steps + loss.backward() # Backpropagate the gradients + + # Perform optimization step after enough accumulation + if step % args.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() # Clear gradients after the step + + images, targets = prefetcher.next() # Get the next batch of images and targets + step += 1 + + return loss.item() # Return the final loss value + +def combine_model_parameters(global_model, client_models): + """ + Aggregate the weights of multiple client models to update the global model. + This is the core of the Federated Averaging (FedAvg) algorithm. 
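+    Every entry of the global state_dict is set to the per-parameter arithmetic mean over
+    the K client copies: w_global = (1 / K) * sum_k w_k.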
+ """ + global_state = global_model.state_dict() + for key in global_state.keys(): + # Average the weights of the corresponding layers from all client models + global_state[key] = torch.stack([client.state_dict()[key].float() for client in client_models], dim=0).mean(dim=0) + + # Load the averaged weights into the global model + global_model.load_state_dict(global_state) + # Update the client models with the new global model weights + for client in client_models: + client.load_state_dict(global_model.state_dict()) + +def validate_model(validation_loader, model, args): + """ + Perform model validation on the validation dataset. + Calculate and return the average accuracy across the dataset. + """ + model.eval() # Set the model to evaluation mode + accuracy_values = [] + + with torch.no_grad(): + for images, targets in validation_loader: + if args.gpu is not None: + images, targets = images.cuda(args.gpu), targets.cuda(args.gpu) + + # Use mixed precision for inference + with amp.autocast(enabled=args.is_amp): + ensemble_output, outputs, ce_loss = model(images, target=targets, mode='val') + + # Calculate the top-1 accuracy for the current batch + avg_acc1 = metric.accuracy(ensemble_output, targets, topk=(1,)) + accuracy_values.append(avg_acc1) + + return calculate_average(accuracy_values) # Return the average accuracy + +def train_single_worker(gpu, ngpus, args): + """ + Training worker function that runs on a single GPU. + This function handles the entire federated learning workflow for the assigned GPU. + """ + global best_accuracy + args.gpu = gpu + cudnn.performance_test = True # Enable performance optimization for CuDNN + + # Optionally, resume from a checkpoint if provided + if args.resume: + checkpoint = torch.load(args.resume) + args.start_round = checkpoint['round'] + best_accuracy = checkpoint['best_acc1'] + + # Initialize global and client models + model = coremodel.coremodel(args).cuda() + client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)] + optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models] + + # Training and validation loop + for round_num in range(args.start_round, args.num_rounds): + # Perform training for each client model + for client_num in range(args.num_clients): + client_training_step(args, round_num, client_models[client_num], optimizers[client_num], lr_scheduler, args.train_loader) + + # Aggregate client models to update the global model + combine_model_parameters(model, client_models) + + # Validate the updated global model and track the best accuracy + validation_accuracy = validate_model(args.val_loader, model, args) + best_accuracy = max(best_accuracy, validation_accuracy) + print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}") + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser(description='FedAvg decentralized Training') + args = train_params.add_parser_params(parser) + initialize_processes(0, args.world_size, args) # Initialize distributed training diff --git a/EdgeFLite/run_local.py b/EdgeFLite/run_local.py new file mode 100644 index 0000000..460173c --- /dev/null +++ b/EdgeFLite/run_local.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import argparse +import warnings +import setproctitle +from torch import nn, decentralized # Used for decentralized training +from torch.backends import cudnn # Optimizes performance for convolutional networks +from tensorboardX import SummaryWriter # For logging 
metrics and results to TensorBoard +import torch.cuda.amp as amp # For mixed precision training +from config import * # Custom configuration module +from params import train_params # Training parameters +from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch # Utility functions +from model import coremodel # Core model implementation +from dataset import factory # Dataset and data loader factory +from params.train_params import save_hp_to_json # Function to save hyperparameters to JSON + +# Global variable to store the best accuracy obtained during training +best_acc1 = 0 + +def main(args): + # Warn if a specific GPU is chosen, as this will disable data parallelism + if args.gpu is not None: + warnings.warn("Selecting a specific GPU will disable data parallelism.") + + # Adjust loop factor based on specific training configurations + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + # Check if decentralized training is needed + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized + + # Get the number of available GPUs on the machine + num_gpus = torch.cuda.device_count() + args.ngpus_per_node = num_gpus + print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}") + + # If multiprocessing is needed for decentralized training + if args.multiprocessing_decentralized: + # Adjust world size to account for multiple GPUs + args.world_size *= num_gpus + # Spawn multiple processes for each GPU + torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args)) + else: + # If using a single GPU + print("INFO:PyTorch: Using GPU 0 for single GPU training") + args.gpu = 0 + # Call main worker for single GPU + execute_worker_process(args.gpu, num_gpus, args) + +def execute_worker_process(gpu, num_gpus, args): + global best_acc1 + args.gpu = gpu + # Set the directory where models will be saved + args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid)) + + # Initialize the decentralized training process group if needed + if args.is_decentralized: + print("INFO:PyTorch: Initializing process group for decentralized training.") + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_decentralized: + args.rank = args.rank * num_gpus + gpu + decentralized.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) + + # Set the GPU to be used for training or evaluation + if args.gpu is not None: + print(f"INFO:PyTorch: GPU {args.gpu} in use for training (Rank: {args.rank})" if not args.evaluate else f"INFO:PyTorch: GPU {args.gpu} in use for evaluation (Rank: {args.rank})") + + # Set process title for better identification in system process monitors + setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}") + + # Initialize a SummaryWriter for TensorBoard logging + val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val')) + + # Use label smoothing if enabled, otherwise use standard cross-entropy loss + criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss() + + # Instantiate the model + model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters") + + # If summary is requested, print model and exit + if args.is_summary: + 
print(model) + return + + # Save model configuration and hyperparameters + summary.save_model_to_json(args, model) + + # Convert BatchNorm layers to synchronized BatchNorm for decentralized training + if args.is_decentralized and args.world_size > 1 and args.is_syncbn: + print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm") + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # Set up the model for GPU-based training + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + args.batch_size = int(args.batch_size / num_gpus) # Adjust batch size for multiple GPUs + args.workers = int((args.workers + num_gpus - 1) / num_gpus) # Adjust number of workers + # Use decentralized data parallel model + model = nn.parallel.decentralizedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + else: + # Use standard DataParallel for multi-GPU training + model = nn.DataParallel(model).cuda() + + # Create the optimizer + optimizer = create_optimizer(args, model) + # Set up the gradient scaler for mixed precision training, if enabled + scaler = amp.GradScaler() if args.is_amp else None + + # If resuming from a checkpoint, load model and optimizer state + if args.resume: + load_checkpoint(args, model, optimizer, scaler) + + cudnn.performance_test = True # Enable cuDNN performance optimizations + + # Set up data loader parameters + data_loader_params = { + 'split_factor': args.loop_factor if args.is_diff_data_train else 1, + 'batch_size': args.batch_size, + 'crop_size': args.crop_size, + 'dataset': args.dataset, + 'is_decentralized': args.is_decentralized, + 'num_workers': args.workers, + 'randaa': args.randaa, + 'is_autoaugment': args.is_autoaugment, + 'is_cutout': args.is_cutout, + 'erase_p': args.erase_p, + } + + # Get the training and validation data loaders + train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params) + val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size, crop_size=args.crop_size, num_workers=args.workers) + + # Set up the learning rate scheduler + scheduler = lr_scheduler.create_scheduler(args, len(train_loader)) + + # If evaluating, run the validation function and exit + if args.evaluate: + validate(val_loader, model, args) + return + + # Begin training and evaluation + train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus) + +# Function to create the optimizer +def create_optimizer(args, model): + param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model) + # Select the optimizer based on input arguments + if args.optimizer == 'SGD': + return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.is_nesterov) + elif args.optimizer == 'AdamW': + return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-4, weight_decay=args.weight_decay) + elif args.optimizer == 'RMSprop': + return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, momentum=0.9, weight_decay=args.weight_decay) + else: + # Raise error if unsupported optimizer is selected + raise NotImplementedError(f"Optimizer {args.optimizer} not implemented") + +# Function to load a checkpoint and resume training +def load_checkpoint(args, model, optimizer, scaler): + if os.path.isfile(args.resume): + print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'") + loc = f'cuda:{args.gpu}' if 
args.gpu is not None else None + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + global best_acc1 + best_acc1 = checkpoint['best_acc1'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + if "scaler" in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}") + else: + print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'") + +# Function to train and evaluate the model over multiple epochs +def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus): + for epoch in range(args.start_epoch, args.epochs + 1): + if args.is_decentralized: + train_sampler.set_epoch(epoch) + + train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args) + + # Evaluate the model every 'eval_per_epoch' epochs + if (epoch + 1) % args.eval_per_epoch == 0: + acc_all = validate(val_loader, model, args) + global best_acc1 + is_best = acc_all[0] > best_acc1 # Track the best accuracy + best_acc1 = max(acc_all[0], best_acc1) + # Save the model checkpoint + save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best) + +# Function to perform one training epoch +def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args): + metric_storage = create_metric_storage(args.loop_factor) + model.train() # Set the model to training mode + data_loader = prefetch.data_prefetcher(train_loader) # Use data prefetching to improve efficiency + images, target = data_loader.next() + + optimizer.zero_grad() # Reset gradients + while images is not None: + # Adjust the learning rate based on the scheduler + scheduler(optimizer, epoch) + + # Perform forward pass with mixed precision if enabled + if args.is_amp: + with amp.autocast(): + ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch) + else: + ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch) + + # Calculate total loss and normalize + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + val_writer.add_scalar('average_training_loss', total_loss, global_step=epoch) + + # Perform backward pass and update gradients with mixed precision if enabled + if args.is_amp: + scaler.scale(total_loss).backward() + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + images, target = data_loader.next() # Fetch the next batch of data + +# Function to save the model checkpoint +def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best): + ckpt = { + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + } + if args.is_amp: + ckpt['scaler'] = scaler.state_dict() + metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar") + +# Function to validate the model on the validation dataset +def validate(val_loader, model, args): + metric_storage = create_metric_storage(args.loop_factor) + model.eval() # Set the model to evaluation mode + + with torch.no_grad(): + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # Perform forward pass with mixed precision if enabled + if args.is_amp: + 
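+ # The mixed-precision path in this script follows the standard autocast/GradScaler
+ # recipe; a minimal stand-alone sketch of the training-side pattern (model, loader
+ # and optimizer are placeholders) would be:
+ #     scaler = torch.cuda.amp.GradScaler()
+ #     for images, target in loader:
+ #         optimizer.zero_grad()
+ #         with torch.cuda.amp.autocast():
+ #             loss = model(images, target)
+ #         scaler.scale(loss).backward()
+ #         scaler.step(optimizer)
+ #         scaler.update()
+ # During validation (below) only autocast is needed, since no gradients are scaled.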
with amp.autocast(): + ensemble_output, outputs, ce_loss = model(images, target=target, mode='val') + else: + ensemble_output, outputs, ce_loss = model(images, target=target, mode='val') + + batch_size = images.size(0) + acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5)) + + metric_storage.update(acc1, acc5, ce_loss, batch_size) + + return metric_storage.results() + +# Helper function to create a storage for metrics during training and validation +def create_metric_storage(loop_factor): + # Initialize metrics for accuracy and other performance metrics + top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)] + avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f') + return metric.ProgressMeter(len(top1_all), top1_all, avg_top1) + +# Main entry point for the script +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Centralized Training') + args = train_params.add_parser_params(parser) # Add parameters to the argument parser + assert args.is_fed == 0, "Centralized training requires args.is_fed to be False" + os.makedirs(args.model_dir, exist_ok=True) # Create model directory if it doesn't exist + main(args) # Call the main function diff --git a/EdgeFLite/run_prox.py b/EdgeFLite/run_prox.py new file mode 100644 index 0000000..e050816 --- /dev/null +++ b/EdgeFLite/run_prox.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import warnings +import torch +import torch.cuda.amp as autocast +from torch import nn +from torch.backends import cudnn +from tensorboardX import SummaryWriter +from config import * +from params import train_settings +from utils import label_smooth, metrics, scheduler, prefetch_loader +from model import net_splitter +from dataset import data_factory +import numpy as np +from tqdm import tqdm +from params.train_settings import save_hyperparams_to_json + +# Set the visible GPU to use for training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Variable to store the best accuracy achieved during training +best_accuracy = 0 + + +# Helper function to compute the average of a list +def compute_average(lst): + return sum(lst) / len(lst) + + +# Main function to initialize the training process +def main(args): + if args.gpu_index is not None: + # Warn if a specific GPU is selected, disabling data parallelism + warnings.warn("Specific GPU chosen, disabling data parallelism.") + + # Adjust loop factor based on training setup + args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor + # Determine if decentralized training is required + args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized + num_gpus = torch.cuda.device_count() + args.num_gpus = num_gpus + + # If decentralized multiprocessing is enabled, spawn multiple processes + if args.multiprocessing_decentralized: + args.world_size = num_gpus * args.world_size + torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args)) + else: + # Otherwise, proceed with single-GPU training + print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.") + args.gpu_index = 0 + worker_process(args.gpu_index, num_gpus, args) + + +# Client-side training function for federated learning updates +def client_train_update(args, round_num, client_model, global_model, sched, opt, train_loader, epochs=5, scaler=None): + client_model.train() + + for epoch in range(epochs): + # Prefetch data for training + loader = prefetch_loader.DataPrefetcher(train_loader) + images, 
targets = loader.next() + batch_idx = 0 + opt.zero_grad() + + while images is not None: + # Apply learning rate scheduling + sched(opt, batch_idx) + + # Use automatic mixed precision if enabled + if args.amp_enabled: + with autocast.autocast(): + ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train', + epoch=epoch) + else: + ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train', + epoch=epoch) + + # Compute accuracy for top-1 predictions + batch_size = images.size(0) + for j in range(args.loop_factor): + top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,)) + + # Compute the proximal term for FedProx loss + prox_term = sum((param - global_param).norm(2) for param, global_param in + zip(client_model.parameters(), global_model.parameters())) + # Compute the total loss (cross-entropy + contrastive loss + proximal term) + total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term + + # Backward pass with mixed precision scaling if enabled + if args.amp_enabled: + scaler.scale(total_loss).backward() + if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)): + scaler.step(opt) + scaler.update() + opt.zero_grad() + else: + total_loss.backward() + if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)): + opt.step() + opt.zero_grad() + + images, targets = loader.next() + + return total_loss.item() + + +# Function to aggregate model weights from clients on the server +def server_compute_average_weights(global_model, client_models): + global_state_dict = global_model.state_dict() + # Average weights across all clients + for key in global_state_dict.keys(): + global_state_dict[key] = torch.stack( + [client_models[i].state_dict()[key].float() for i in range(len(client_models))], 0).mean(0) + global_model.load_state_dict(global_state_dict) + + # Update clients with the averaged global model + for model in client_models: + model.load_state_dict(global_model.state_dict()) + + +# Function to validate the model on the validation set +def validate_model(val_loader, model, args): + model.eval() + acc1_list, acc5_list, loss_ce_list = [], [], [] + + # Perform validation without gradient calculation + with torch.no_grad(): + for images, targets in val_loader: + if args.gpu_index is not None: + images, targets = images.cuda(args.gpu_index, non_blocking=True), targets.cuda(args.gpu_index, + non_blocking=True) + + if args.amp_enabled: + with autocast.autocast(): + ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val') + else: + ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val') + + for j in range(args.loop_factor): + acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5)) + + avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5)) + acc1_list.append(avg_acc1) + acc5_list.append(avg_acc5) + loss_ce_list.append(loss_ce) + + return compute_average(loss_ce_list), compute_average(acc1_list) + + +# Function to handle the worker process for training on a specific GPU +def worker_process(gpu_index, num_gpus, args): + global best_accuracy + args.gpu_index = gpu_index + args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id)) + + # Create summary writer for validation if not using decentralized training + if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0): + val_summary_writer = 
SummaryWriter(log_dir=os.path.join(args.model_path, 'validation')) + + # Set the loss function based on the label smoothing option + criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss() + # Initialize the global model and client models + global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion) + client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion) for _ in + range(args.num_clients)] + + # Save hyperparameters to JSON if required + if args.save_summary: + save_hyperparams_to_json(args) + return + + # Move models to GPU + global_model = global_model.cuda() + for model in client_models: + model.cuda() + model.load_state_dict(global_model.state_dict()) + + # Create optimizers for each client + opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, + nesterov=args.use_nesterov) for client in client_models] + + # Initialize gradient scaler if AMP is enabled + scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None + cudnn.performance_test = True + + # Resume training from checkpoint if specified + if args.resume_training: + if os.path.isfile(args.resume_checkpoint): + checkpoint = torch.load(args.resume_checkpoint, + map_location=f'cuda:{args.gpu_index}' if args.gpu_index else None) + args.start_round = checkpoint['round'] + best_accuracy = checkpoint['best_acc1'] + global_model.load_state_dict(checkpoint['state_dict']) + for opt in opt_list: + opt.load_state_dict(checkpoint['optimizer']) + if "scaler" in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + for client_model in client_models: + client_model.load_state_dict(global_model.state_dict()) + else: + args.start_round = 0 + else: + args.start_round = 0 + + # Load training and validation data + train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor, + dataset_name=args.dataset_name, split="train", + num_workers=args.num_workers, decentralized=args.decentralized_training) + val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor, + dataset_name=args.dataset_name, split="val", num_workers=args.num_workers) + + # Federated learning rounds + for round_num in range(args.start_round, args.num_rounds + 1): + if args.fixed_cluster: + # Select clients from fixed clusters for each round + selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients] + for i in tqdm(range(args.num_clients)): + selected_clients = np.arange(start=selected_clusters[i] * args.split_factor, + stop=(selected_clusters[i] + 1) * args.split_factor) + for client in selected_clients: + loss = client_train diff --git a/EdgeFLite/run_splitfed.py b/EdgeFLite/run_splitfed.py new file mode 100644 index 0000000..083c316 --- /dev/null +++ b/EdgeFLite/run_splitfed.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import numpy as np +import argparse +import warnings +from tqdm import tqdm +from tensorboardX import SummaryWriter +from dataset import factory +from config import * +from model import coremodelsl +from utils import label_smoothing, norm, metric, lr_scheduler, prefetch +from params import train_params +from params.train_params import save_hp_to_json + +# Set the visible GPU devices for the training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Global best accuracy to track the performance +best_acc1 = 0 # 
Global best accuracy + +def average(values): + """Calculate the average of a list of values.""" + return sum(values) / len(values) + +def combine_model_weights(global_model_client, global_model_server, client_models, server_models): + """ + Aggregate weights from client and server models using the mean method. + This function updates the global model weights by averaging the weights + from all client and server models. + """ + # Get the state dictionaries (weights) for both client and server models + client_state_dict = global_model_client.state_dict() + server_state_dict = global_model_server.state_dict() + + # Average the weights across all client models + for key in client_state_dict.keys(): + client_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in client_models], dim=0).mean(0) + global_model_client.load_state_dict(client_state_dict) + + # Average the weights across all server models + for key in server_state_dict.keys(): + server_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in server_models], dim=0).mean(0) + global_model_server.load_state_dict(server_state_dict) + + # Load the updated global model weights back into the client models + for model in client_models: + model.load_state_dict(global_model_client.state_dict()) + + # Load the updated global model weights back into the server models + for model in server_models: + model.load_state_dict(global_model_server.state_dict()) + +def client_training(args, round_num, client_model, server_model, scheduler_client, scheduler_server, optimizer_client, optimizer_server, data_loader, epochs=5, streams=None): + """ + Perform client-side model training for the given number of epochs. + The client model performs the forward pass and sends intermediate outputs + to the server model for further computation. + """ + client_model.train() + server_model.train() + + for epoch in range(epochs): + # Prefetch data to improve data loading speed + prefetcher = prefetch.data_prefetcher(data_loader) + images, target = prefetcher.next() + i = 0 + optimizer_client.zero_grad() + optimizer_server.zero_grad() + + while images is not None: + # Adjust learning rates using the schedulers + scheduler_client(optimizer_client, i, round_num) + scheduler_server(optimizer_server, i, round_num) + i += 1 + + # Forward pass on the client model + outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train', epoch=epoch, streams=streams) + client_fx = [outputs.clone().detach().requires_grad_(True) for outputs in outputs_client] + + # Forward pass on the server model and compute losses + ensemble_output, outputs_server, ce_loss, cot_loss = server_model(client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams) + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + total_loss.backward() + + # Backpropagate the gradients to the client model + for fx, grad in zip(outputs_client, client_fx): + fx.backward(grad.grad) + + # Perform optimization step when the accumulation condition is met + if i % args.iters_to_accumulate == 0 or i == len(data_loader): + optimizer_client.step() + optimizer_server.step() + optimizer_client.zero_grad() + optimizer_server.zero_grad() + + # Fetch the next batch of data + images, target = prefetcher.next() + + return total_loss.item() + +def validate_model(val_loader, client_model, server_model, args, streams=None): + """ + Validate the performance of client and server models. 
+ This function performs forward passes without updating the model weights + and computes validation accuracy and loss. + """ + client_model.eval() + server_model.eval() + + acc1_list, acc5_list, ce_loss_list = [], [], [] + + with torch.no_grad(): + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # Forward pass on the client model + outputs_client = client_model(images, target=target, mode='val') + client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client] + + # Forward pass on the server model + ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val') + + # Calculate accuracy and losses + acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5)) + acc1_list.append(acc1) + acc5_list.append(acc5) + ce_loss_list.append(ce_loss) + + # Calculate average accuracy and loss over the validation dataset + avg_acc1 = average(acc1_list) + avg_acc5 = average(acc5_list) + avg_ce_loss = average(ce_loss_list) + + return avg_ce_loss, avg_acc1, avg_acc5 + +def main(args): + """ + The main entry point for the federated learning process. + Initializes models, handles multiprocessing setup, and starts training. + """ + if args.gpu is not None: + warnings.warn("A specific GPU has been chosen. Data parallelism is disabled.") + + # Set loop factor based on training configuration + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + ngpus_per_node = torch.cuda.device_count() + args.ngpus_per_node = ngpus_per_node + + if args.multiprocessing_decentralized: + # Spawn a process for each GPU in decentralized setup + args.world_size = ngpus_per_node * args.world_size + torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Use only a single GPU in non-decentralized setup + args.gpu = 0 + execute_worker_process(args.gpu, ngpus_per_node, args) + +def execute_worker_process(gpu, ngpus_per_node, args): + """ + Worker function that handles model initialization, training, and validation. 
+ """ + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print(f"Using GPU {args.gpu} for training.") + + # Create tensorboard writer for logging validation metrics + if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized and args.rank % ngpus_per_node == 0): + val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val')) + + # Define loss criterion with label smoothing or cross-entropy + criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss() + + # Initialize global client and server models + global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + + # Initialize client and server models for each selected client + client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] + server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] + + # Save hyperparameters to a JSON file + save_hp_to_json(args) + + # Move global models and client/server models to GPU + global_model_client = global_model_client.cuda() + global_model_server = global_model_server.cuda() + for model in client_models + server_models: + model.cuda() + + # Load global model weights into each client and server model + for model in client_models: + model.load_state_dict(global_model_client.state_dict()) + for model in server_models: + model.load_state_dict(global_model_server.state_dict()) + + # Initialize learning rate schedulers for clients and servers + schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] + schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] + + # Start the training and validation loop for the specified number of rounds + for r in range(args.start_round, args.num_rounds + 1): + # Randomly select client indices for training in each round + client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop diff --git a/EdgeFLite/scripts/.DS_Store b/EdgeFLite/scripts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0V4u473vpqBN8#_Y3@Bh2RA|fDnk30*M0AU&Xufw9MF61Uk}0N)wGGdv<-E zTkRC*vzhtgyq}xR%xnrL+J#|kd`=(PLuDKY$2)fRVehuF&Hgm1zMmlXi41VJ*Zke# z?RY%gCn+lhq<|EV0#ZN<{8|C_UfA?GQBev=0V(jU0KX3nPV9w4VthIfVgvwAkPgE- zW(irB!C&F5`GnK34hllFV_>vnY}k%p cqA2Sc*LdCwhs2;GA9SF82B?cn3jDPKUtz@+9smFU literal 0 HcmV?d00001 diff --git a/EdgeFLite/thop/helper_utils.py b/EdgeFLite/thop/helper_utils.py new file mode 100644 index 0000000..b0346b0 --- /dev/null +++ b/EdgeFLite/thop/helper_utils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan +from collections.abc import Iterable + +# Define a function named 'clever_format' that takes two arguments: +# 1. 
'nums' - either a single number or a list of numbers to format. +# 2. 'fmt' - an optional string argument specifying the format for the numbers (default is "%.2f", meaning two decimal places). +def clever_format(nums, fmt="%.2f"): + + # Check if the input 'nums' is not an instance of an iterable (like a list or tuple). + # If it is not iterable, convert the single number into a list for uniform processing later. + if not isinstance(nums, Iterable): + nums = [nums] + + # Create an empty list to store the formatted numbers. + formatted_nums = [] + + # Loop through each number in the 'nums' list. + for num in nums: + # Check if the number is greater than 1 trillion (1e12). If so, format it by dividing it by 1 trillion and appending 'T' (for trillions). + if num > 1e12: + formatted_nums.append(fmt % (num / 1e12) + "T") + # If the number is greater than 1 billion (1e9), format it by dividing by 1 billion and appending 'G' (for billions). + elif num > 1e9: + formatted_nums.append(fmt % (num / 1e9) + "G") + # If the number is greater than 1 million (1e6), format it by dividing by 1 million and appending 'M' (for millions). + elif num > 1e6: + formatted_nums.append(fmt % (num / 1e6) + "M") + # If the number is greater than 1 thousand (1e3), format it by dividing by 1 thousand and appending 'K' (for thousands). + elif num > 1e3: + formatted_nums.append(fmt % (num / 1e3) + "K") + # If the number is less than 1 thousand, simply format it using the provided format and append 'B' (for base or basic). + else: + formatted_nums.append(fmt % num + "B") + + # If only one number was passed, return just the formatted string for that number. + # If multiple numbers were passed, return a tuple containing all formatted numbers. + return formatted_nums[0] if len(formatted_nums) == 1 else tuple(formatted_nums) diff --git a/EdgeFLite/thop/hooks_basic.py b/EdgeFLite/thop/hooks_basic.py new file mode 100644 index 0000000..f3acdc4 --- /dev/null +++ b/EdgeFLite/thop/hooks_basic.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import argparse +import logging + +import torch +import torch.nn as nn +from torch.nn.modules.conv import _ConvNd + +multiply_adds = 1 + +def count_parameters(m, x, y): + """Counts the number of parameters in a model.""" + total_params = sum(p.numel() for p in m.parameters()) + m.total_params[0] = torch.DoubleTensor([total_params]) + +def zero_ops(m, x, y): + """Sets total operations to zero.""" + m.total_ops += torch.DoubleTensor([0]) + +def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): + """Counts operations for convolutional layers.""" + x = x[0] + kernel_ops = m.weight[0][0].numel() # Kw x Kh + bias_ops = 1 if m.bias is not None else 0 + total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops) + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): + """Alternative method for counting operations for convolutional layers.""" + x = x[0] + output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel() + kernel_ops = m.weight.numel() + (m.bias.numel() if m.bias is not None else 0) + m.total_ops += torch.DoubleTensor([output_size * kernel_ops]) + +def count_bn(m, x, y): + """Counts operations for batch normalization layers.""" + x = x[0] + nelements = x.numel() + if not m.training: + total_ops = 2 * nelements + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_relu(m, x, y): + """Counts operations for ReLU activation.""" + x = x[0] + nelements = x.numel() + 
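+ # A ReLU performs one comparison per element, so under the convention used in this
+ # module its cost is simply the element count, accumulated into the `total_ops`
+ # buffer that profiling.py registers on each hooked module. For readability the raw
+ # counts are usually passed through clever_format above, e.g. (illustrative values):
+ #     clever_format([1234567890, 11689512], "%.2f")  ->  ("1.23G", "11.69M")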
m.total_ops += torch.DoubleTensor([nelements]) + +def count_softmax(m, x, y): + """Counts operations for softmax.""" + x = x[0] + batch_size, nfeatures = x.size() + total_ops = batch_size * (2 * nfeatures - 1) + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_avgpool(m, x, y): + """Counts operations for average pooling layers.""" + num_elements = y.numel() + m.total_ops += torch.DoubleTensor([num_elements]) + +def count_adap_avgpool(m, x, y): + """Counts operations for adaptive average pooling layers.""" + kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor(list((m.output_size,))).squeeze() + kernel_ops = torch.prod(kernel) + 1 + num_elements = y.numel() + m.total_ops += torch.DoubleTensor([kernel_ops * num_elements]) + +def count_upsample(m, x, y): + """Counts operations for upsample layers.""" + if m.mode not in ("nearest", "linear", "bilinear", "bicubic"): + logging.warning(f"Mode {m.mode} is not implemented yet, assuming zero ops") + return zero_ops(m, x, y) + + if m.mode == "nearest": + return zero_ops(m, x, y) + + total_ops = { + "linear": 5, + "bilinear": 11, + "bicubic": 259, # 224 muls + 35 adds + "trilinear": 31 # 2 * bilinear + 1 * linear + }.get(m.mode, 0) * y.nelement() + + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_linear(m, x, y): + """Counts operations for linear layers.""" + total_ops = m.in_features * y.numel() + m.total_ops += torch.DoubleTensor([total_ops]) diff --git a/EdgeFLite/thop/hooks_rnn.py b/EdgeFLite/thop/hooks_rnn.py new file mode 100644 index 0000000..700041d --- /dev/null +++ b/EdgeFLite/thop/hooks_rnn.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +def _count_rnn_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single RNN cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the RNN cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the RNN cell. + """ + ops = hidden_size * (input_size + hidden_size) + hidden_size + if bias: + ops += hidden_size * 2 + return ops + +def count_rnn_cell(cell: nn.RNNCell, x: torch.Tensor): + """Count operations for the RNNCell over a batch of input. + + Args: + cell (nn.RNNCell): The RNNCell to count operations for. + x (torch.Tensor): Input tensor. + """ + ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_gru_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single GRU cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the GRU cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the GRU cell. + """ + ops = (hidden_size + input_size) * hidden_size + hidden_size + if bias: + ops += hidden_size * 2 + ops *= 2 # For reset and update gates + + ops += (hidden_size + input_size) * hidden_size + hidden_size # Calculate new gate + if bias: + ops += hidden_size * 2 + ops += hidden_size # Hadamard product + ops += hidden_size * 3 # Final output + + return ops + +def count_gru_cell(cell: nn.GRUCell, x: torch.Tensor): + """Count operations for the GRUCell over a batch of input. + + Args: + cell (nn.GRUCell): The GRUCell to count operations for. + x (torch.Tensor): Input tensor. 
+ """ + ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_lstm_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single LSTM cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the LSTM cell. + """ + ops = (input_size + hidden_size) * hidden_size + hidden_size + if bias: + ops += hidden_size * 2 + ops *= 4 # For input, forget, output, and cell gates + + ops += hidden_size * 3 # Cell state update + ops += hidden_size # Final output + + return ops + +def count_lstm_cell(cell: nn.LSTMCell, x: torch.Tensor): + """Count operations for the LSTMCell over a batch of input. + + Args: + cell (nn.LSTMCell): The LSTMCell to count operations for. + x (torch.Tensor): Input tensor. + """ + ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size): + """Calculate the total operations for RNN layers. + + Args: + model (nn.RNN): The RNN model. + num_layers (int): Number of layers in the RNN. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the RNN layers. + """ + ops = _count_rnn_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_rnn_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_rnn(model: nn.RNN, x: torch.Tensor): + """Count operations for the entire RNN over a batch of input. + + Args: + model (nn.RNN): The RNN model. + x (torch.Tensor): Input tensor. + """ + batch_size = x[0].size(0) if model.batch_first else x[0].size(1) + num_steps = x[0].size(1) if model.batch_first else x[0].size(0) + + ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size) + total_ops = ops * num_steps * batch_size + model.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size): + """Calculate the total operations for GRU layers. + + Args: + model (nn.GRU): The GRU model. + num_layers (int): Number of layers in the GRU. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the GRU layers. + """ + ops = _count_gru_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_gru_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_gru(model: nn.GRU, x: torch.Tensor): + """Count operations for the entire GRU over a batch of input. + + Args: + model (nn.GRU): The GRU model. + x (torch.Tensor): Input tensor. 
+ """ + batch_size = x[0].size(0) if model.batch_first else x[0].size(1) + num_steps = x[0].size(1) if model.batch_first else x[0].size(0) + + ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size) + total_ops = ops * num_steps * batch_size + model.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size): + """Calculate the total operations for LSTM layers. + + Args: + model (nn.LSTM): The LSTM model. + num_layers (int): Number of layers in the LSTM. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the LSTM layers. + """ + ops = _count_lstm_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_lstm_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_lstm(model: nn.LSTM, x: torch.Tensor): + """Count operations for the entire LSTM over a batch of input. + + Args: + model (nn.LSTM): The LSTM model. + x (torch.Tensor): Input tensor. + """ + batch_size = x[0].size(0) if model.batch_first else x[0].size(1) + num_steps = x[0].size(1) if model.batch_first else x[0].size(0) + + ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size) + total_ops = ops * num_steps * batch_size + model.total_ops += torch.DoubleTensor([int(total_ops)]) diff --git a/EdgeFLite/thop/profiling.py b/EdgeFLite/thop/profiling.py new file mode 100644 index 0000000..158ec8d --- /dev/null +++ b/EdgeFLite/thop/profiling.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Importing necessary modules +from distutils.version import LooseVersion # Used for version comparisons +from .hooks_basic import * # Importing basic hooks (functions for profiling operations) +from .hooks_rnn import * # Importing hooks specific to RNN operations + +# Uncomment the following for logging purposes +# import logging +# logger = logging.getLogger(__name__) # Creating a logger instance +# logger.setLevel(logging.INFO) # Setting the log level to INFO + +# Functions to print text in different colors +# Useful for visually differentiating output in terminal +def prRed(skk): + print("\033[91m{}\033[00m".format(skk)) # Print red text +def prGreen(skk): + print("\033[92m{}\033[00m".format(skk)) # Print green text +def prYellow(skk): + print("\033[93m{}\033[00m".format(skk)) # Print yellow text + +# Checking if the installed version of PyTorch is outdated +if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): + # If the version is below 1.0.0, print a warning + logging.warning( + f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future."
+ ) + +# Setting the default data type for tensors +default_dtype = torch.float64 # Using 64-bit float as the default precision + +# Register hooks for different layers in PyTorch +# Each layer type is mapped to its respective counting function +register_hooks = { + nn.ZeroPad2d: zero_ops, + nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd, + nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd, + nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn, + nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu, + nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops, + nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops, + nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool, + nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool, + nn.Linear: count_linear, nn.Dropout: zero_ops, + nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample, + nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell, + nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm, +} + +# Function for profiling model operations and parameters +# This tracks how many operations (ops) and parameters (params) a model uses +def profile_origin(model, inputs, custom_ops=None, verbose=True): + handler_collection = [] # Collection of hooks + types_collection = set() # Keep track of registered layer types + custom_ops = custom_ops or {} # Custom operation handling + + def add_hooks(m): + # Ignore compound modules (those that contain other modules) + if len(list(m.children())) > 0: + return + + # Check if the module already has the required attributes + if hasattr(m, "total_ops") or hasattr(m, "total_params"): + logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.") + + # Add buffers to store the total number of operations and parameters + m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) + m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) + + # Count the number of parameters for this module + for p in m.parameters(): + m.total_params += torch.DoubleTensor([p.numel()]) + + # Determine which function to use for counting operations + m_type = type(m) + fn = custom_ops.get(m_type, register_hooks.get(m_type, None)) + + if fn: + # If the function exists, register the forward hook + if m_type not in types_collection and verbose: + print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.") + handler = m.register_forward_hook(fn) + handler_collection.append(handler) + else: + # Warn if no counting rule is found + if m_type not in types_collection and verbose: + prRed(f"[WARN] Cannot find rule for {m_type}. 
Treat it as zero MACs and zero Params.") + + types_collection.add(m_type) + + # Set the model to evaluation mode (no gradients) + model.eval() + model.apply(add_hooks) + + # Run a forward pass with no gradients + with torch.no_grad(): + model(*inputs) + + # Sum up the total operations and parameters across all layers + total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops')) + total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params')) + + # Restore the model to training mode and remove hooks + model.train() + for handler in handler_collection: + handler.remove() + for m in model.modules(): + if hasattr(m, "total_ops"): del m._buffers['total_ops'] + if hasattr(m, "total_params"): del m._buffers['total_params'] + + return total_ops, total_params # Return the total number of ops and params + +# Updated profiling function with a different approach for hierarchical modules +def profile(model: nn.Module, inputs, custom_ops=None, verbose=True): + handler_collection = {} # Dictionary to store handlers + types_collection = set() # Store layer types that have been processed + custom_ops = custom_ops or {} # Custom operation handling + + def add_hooks(m: nn.Module): + # Add buffers for storing total ops and params + m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) + m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) + + # Find the appropriate counting function for this layer + fn = custom_ops.get(type(m), register_hooks.get(type(m), None)) + if fn: + # Register hooks for both operations and parameters + handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters)) + if type(m) not in types_collection and verbose: + print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.") + else: + # Warn if no rule is found for this layer + if type(m) not in types_collection and verbose: + prRed(f"[WARN] Cannot find rule for {type(m)}. 
Treat it as zero MACs and zero Params.") + types_collection.add(type(m)) + + # Set the model to evaluation mode + model.eval() + model.apply(add_hooks) + + # Run a forward pass with no gradients + with torch.no_grad(): + model(*inputs) + + # Recursive function to count ops and params for hierarchical models + def dfs_count(module: nn.Module) -> (int, int): + total_ops, total_params = 0, 0 + for m in module.children(): + if m in handler_collection: + m_ops, m_params = m.total_ops.item(), m.total_params.item() + else: + m_ops, m_params = dfs_count(m) + total_ops += m_ops + total_params += m_params + return total_ops, total_params + + total_ops, total_params = dfs_count(model) # Perform the depth-first count + + # Restore the model to training mode and remove hooks + model.train() + for m, (op_handler, params_handler) in handler_collection.items(): + op_handler.remove() + params_handler.remove() + del m._buffers['total_ops'] + del m._buffers['total_params'] + + return total_ops, total_params # Return the total ops and params diff --git a/EdgeFLite/train_EdgeFLite.py b/EdgeFLite/train_EdgeFLite.py new file mode 100644 index 0000000..c8cb6e8 --- /dev/null +++ b/EdgeFLite/train_EdgeFLite.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import argparse +import torch.nn as nn +from config import * # Import configuration +from params import train_params # Import training parameters +from model import coremodel, coremodelsl # Import models +from utils import ( # Import utility functions + label_smoothing, norm, metric, lr_scheduler, prefetch, + save_hp_to_json, profile, clever_format +) +from dataset import factory # Import dataset factory + +# Specify the GPU to be used +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Global variable for tracking the best accuracy +best_acc1 = 0 + +# Function to calculate the average of a list of values +def average(values): + """Calculate average of a list.""" + return sum(values) / len(values) + +# Function to aggregate the models from multiple clients into a global model +def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models): + """Aggregates weights of the models using simple mean.""" + # Get the state dictionaries for the global models + global_main_state = global_model_main.state_dict() + global_proxy_state = global_model_proxy.state_dict() + + # Aggregate the main client models by averaging the weights + for key in global_main_state.keys(): + global_main_state[key] = torch.stack([client.state_dict()[key].float() for client in client_main_models], 0).mean(0) + global_model_main.load_state_dict(global_main_state) + + # Aggregate the proxy client models similarly + for key in global_proxy_state.keys(): + global_proxy_state[key] = torch.stack([client.state_dict()[key].float() for client in client_proxy_models], 0).mean(0) + global_model_proxy.load_state_dict(global_proxy_state) + + # Synchronize the client models with the updated global model + for client in client_main_models: + client.load_state_dict(global_model_main.state_dict()) + for client in client_proxy_models: + client.load_state_dict(global_model_proxy.state_dict()) + +# Function to perform client-side training updates +def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy, optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None): + """Client-side training update.""" + main_model.train() + proxy_models.train() + + # Train for a given number of epochs + for epoch in 
range(epochs): + # Prefetch data for faster loading + prefetcher = prefetch.data_prefetcher(train_loader) + images, targets = prefetcher.next() + batch_idx = 0 + + # Zero the gradients + optimizers_main.zero_grad() + optimizers_proxy.zero_grad() + + # Process each batch of data + while images is not None: + # Adjust learning rates using the scheduler + schedulers_main(optimizers_main, batch_idx, round_idx) + schedulers_proxy(optimizers_proxy, batch_idx, round_idx) + + # Forward pass for the main model + outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams) + main_fx = [output.clone().detach().requires_grad_(True) for output in outputs] + + # Forward pass for the proxy model with outputs from the main model + ensemble_output, proxy_outputs, ce_loss, cot_loss = proxy_models(main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams) + + # Calculate total loss and perform backpropagation + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + total_loss.backward() + + # Backpropagate gradients for the main model + for j in range(len(main_fx)): + outputs[j].backward(main_fx[j].grad) + + # Update the model weights periodically + if batch_idx % args.iters_to_accumulate == 0 or batch_idx == len(train_loader): + optimizers_main.step() + optimizers_main.zero_grad() + optimizers_proxy.step() + optimizers_proxy.zero_grad() + + # Fetch the next batch of images + images, targets = prefetcher.next() + batch_idx += 1 + + return total_loss.item() + +# Function to validate the models on a validation set +def validate(val_loader, main_model, proxy_models, args, streams=None): + """Validation function to evaluate models.""" + main_model.eval() + proxy_models.eval() + + # Initialize metrics for accuracy tracking + top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)] + acc1_list, acc5_list, ce_loss_list = [], [], [] + + # Disable gradient computation for validation + with torch.no_grad(): + for images, targets in val_loader: + images, targets = images.cuda(), targets.cuda() + + # Forward pass for main model + outputs = main_model(images, target=targets, mode='val') + main_fx = [output.clone().detach().requires_grad_(True) for output in outputs] + + # Forward pass for proxy model + ensemble_output, proxy_outputs, ce_loss = proxy_models(main_fx, target=targets, mode='val') + + # Calculate accuracy + acc1, acc5 = metric.accuracy(ensemble_output, targets, topk=(1, 5)) + acc1_list.append(acc1) + acc5_list.append(acc5) + ce_loss_list.append(ce_loss) + + # Calculate average metrics over the validation set + avg_acc1 = average(acc1_list) + avg_acc5 = average(acc5_list) + avg_ce_loss = average(ce_loss_list) + + return avg_ce_loss, avg_acc1, top1_metrics + +# Main function to set up and start decentralized training +def main(args): + """Main function to set up decentralized training.""" + # Set loop factor based on training configuration + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + # Determine if decentralized training is needed + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized + + # Get the number of GPUs available + ngpus_per_node = torch.cuda.device_count() + args.ngpus_per_node = ngpus_per_node + + # If using decentralized training with multiprocessing + if args.multiprocessing_decentralized: + args.world_size *= ngpus_per_node + torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, 
args=(ngpus_per_node, args)) + else: + # If not using multiprocessing, proceed with a single GPU + args.gpu = 0 + execute_worker_process(args.gpu, ngpus_per_node, args) + +# Main worker function to handle training with multiple GPUs or single GPU +def execute_worker_process(gpu, ngpus_per_node, args): + """Main worker function for multi-GPU or single-GPU training.""" + global best_acc1 + args.gpu = gpu + + # Set process title + setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}") + + # Set the criterion for loss calculation + if args.is_label_smoothing: + criterion = label_smoothing.label_smoothing_CE(reduction='mean') + else: + criterion = nn.CrossEntropyLoss() + + # Create the main and proxy models for training + main_model = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() + proxy_model = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() + + # Initialize client models for federated learning + client_main_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)] + client_proxy_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)] + + # Synchronize client models with the global models + for client in client_main_models: + client.load_state_dict(main_model.state_dict()) + for client in client_proxy_models: + client.load_state_dict(proxy_model.state_dict()) + + # Load training and validation data + train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset, split="train", num_workers=args.workers) + val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset, split="val", num_workers=args.workers) + + # Loop over training rounds + for r in range(args.start_round, args.num_rounds + 1): + # Update client models with new training data + client_update(args, r, client_main_models, client_proxy_models, lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler, torch.optim.SGD, torch.optim.SGD, train_loader) + + # Validate the models + test_loss, acc, top1 = validate(val_loader, main_model, proxy_model, args) + + # Track the best accuracy achieved + best_acc1 = max(acc, best_acc1) + +# Entry point for the script +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Training EdgeFLite") + args = train_params.add_parser_params(parser) + main(args)
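
For reference, the profiling helpers added under EdgeFLite/thop are typically driven as sketched below. This is only a usage sketch: the import paths assume the EdgeFLite package directory is on PYTHONPATH, and torchvision's resnet18 merely stands in for the repository's own backbones, so the printed figures are illustrative rather than measured.

import torch
from torchvision.models import resnet18
from thop.profiling import profile
from thop.helper_utils import clever_format

model = resnet18(num_classes=100)        # stand-in for the CIFAR-100 backbones trained above
dummy = torch.randn(1, 3, 224, 224)      # a single input sample; its shape drives the MAC count
macs, params = profile(model, inputs=(dummy,), verbose=False)
macs, params = clever_format([macs, params], "%.2f")
print(f"MACs: {macs}, params: {params}")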