diff --git a/Task_1/FeTS_Challenge.py b/Task_1/FeTS_Challenge.py index 94d7598..e29cdb1 100644 --- a/Task_1/FeTS_Challenge.py +++ b/Task_1/FeTS_Challenge.py @@ -14,7 +14,10 @@ import os import numpy as np - +from fets_challenge import model_outputs_to_disc +from pathlib import Path +import shutil +import glob from fets_challenge import run_challenge_experiment @@ -333,15 +336,12 @@ def clipped_aggregation(local_tensors, clip_to_percentile = 80 # first, we need to determine how much each local update has changed the tensor from the previous value - # we'll use the tensor_db search function to find the - previous_tensor_value = tensor_db.search(tensor_name=tensor_name, fl_round=fl_round, tags=('model',), origin='aggregator') + # we'll use the tensor_db retrieve function to find the previous tensor value + previous_tensor_value = tensor_db.retrieve(tensor_name=tensor_name, origin='aggregator', fl_round=fl_round - 1, tags=('aggregated',)) - if previous_tensor_value.shape[0] > 1: - print(previous_tensor_value) - raise ValueError(f'found multiple matching tensors for {tensor_name}, tags=(model,), origin=aggregator') - - if previous_tensor_value.shape[0] < 1: + if previous_tensor_value is None: # no previous tensor, so just return the weighted average + logger.info(f"previous_tensor_value is None") return weighted_average_aggregation(local_tensors, tensor_db, tensor_name, @@ -349,8 +349,6 @@ def clipped_aggregation(local_tensors, collaborators_chosen_each_round, collaborator_times_per_round) - previous_tensor_value = previous_tensor_value.nparray.iloc[0] - # compute the deltas for each collaborator deltas = [t.tensor - previous_tensor_value for t in local_tensors] @@ -423,19 +421,20 @@ def FedAvgM_Selection(local_tensors, if tensor_name not in tensor_db.search(tags=('weight_speeds',))['tensor_name']: #weight_speeds[tensor_name] = np.zeros_like(local_tensors[0].tensor) # weight_speeds[tensor_name] = np.zeros(local_tensors[0].tensor.shape) tensor_db.store( - tensor_name=tensor_name, + tensor_name=tensor_name, tags=('weight_speeds',), nparray=np.zeros_like(local_tensors[0].tensor), ) + return new_tensor_weight else: if tensor_name.endswith("weight") or tensor_name.endswith("bias"): # Calculate aggregator's last value previous_tensor_value = None for _, record in tensor_db.iterrows(): - if (record['round'] == fl_round + if (record['round'] == fl_round - 1 # Fetching aggregated value for previous round and record["tensor_name"] == tensor_name - and record["tags"] == ("aggregated",)): + and record["tags"] == ('aggregated',)): previous_tensor_value = record['nparray'] break @@ -450,7 +449,7 @@ def FedAvgM_Selection(local_tensors, if tensor_name not in tensor_db.search(tags=('weight_speeds',))['tensor_name']: tensor_db.store( - tensor_name=tensor_name, + tensor_name=tensor_name, tags=('weight_speeds',), nparray=np.zeros_like(local_tensors[0].tensor), ) @@ -474,7 +473,7 @@ def FedAvgM_Selection(local_tensors, new_tensor_weight_speed = momentum * tensor_weight_speed + average_deltas # fix delete (1-momentum) tensor_db.store( - tensor_name=tensor_name, + tensor_name=tensor_name, tags=('weight_speeds',), nparray=new_tensor_weight_speed ) @@ -530,7 +529,7 @@ def FedAvgM_Selection(local_tensors, # increase this if you need a longer history for your algorithms # decrease this if you need to reduce system RAM consumption -db_store_rounds = 5 +db_store_rounds = 1 # this is passed to PyTorch, so set it accordingly for your system device = 'cpu' @@ -543,71 +542,91 @@ def FedAvgM_Selection(local_tensors, # The 
checkpoints can grow quite large (5-10GB) so only the latest will be saved when this parameter is enabled save_checkpoints = True +# (str) Determines the backend process to use for the experiment.(single_process, ray) +backend_process = 'single_process' + # path to previous checkpoint folder for experiment that was stopped before completion. -# Checkpoints are stored in ~/.local/workspace/checkpoint, and you should provide the experiment directory +# Checkpoints are stored in ~/.local/workspace/checkpoint, and you should provide the experiment directory # relative to this path (i.e. 'experiment_1'). Please note that if you restore from a checkpoint, # and save checkpoint is set to True, then the checkpoint you restore from will be subsequently overwritten. # restore_from_checkpoint_folder = 'experiment_1' restore_from_checkpoint_folder = None +# infer participant home folder +home = str(Path.home()) -# the scores are returned in a Pandas dataframe -scores_dataframe, checkpoint_folder = run_challenge_experiment( +#Creating working directory and copying the required csv files +working_directory= os.path.join(home, '.local/workspace/') +Path(working_directory).mkdir(parents=True, exist_ok=True) +source_dir=f'{Path.cwd()}/partitioning_data/' +pattern = "*.csv" +source_pattern = os.path.join(source_dir, pattern) +files_to_copy = glob.glob(source_pattern) + +if not files_to_copy: + logger.info(f"No files found matching pattern: {pattern}") + +for source_file in files_to_copy: + destination_file = os.path.join(working_directory, os.path.basename(source_file)) + shutil.copy2(source_file, destination_file) +try: + os.chdir(working_directory) + logger.info(f"Directory changed to : {os.getcwd()}") +except FileNotFoundError: + logger.info("Error: Directory not found.") +except PermissionError: + logger.info("Error: Permission denied") + +checkpoint_folder = run_challenge_experiment( aggregation_function=aggregation_function, choose_training_collaborators=choose_training_collaborators, training_hyper_parameters_for_round=training_hyper_parameters_for_round, - include_validation_with_hausdorff=include_validation_with_hausdorff, institution_split_csv_filename=institution_split_csv_filename, brats_training_data_parent_dir=brats_training_data_parent_dir, db_store_rounds=db_store_rounds, rounds_to_train=rounds_to_train, device=device, save_checkpoints=save_checkpoints, - restore_from_checkpoint_folder = restore_from_checkpoint_folder) - - -scores_dataframe + restore_from_checkpoint_folder = restore_from_checkpoint_folder, + include_validation_with_hausdorff=include_validation_with_hausdorff, + backend_process = backend_process) # ## Produce NIfTI files for best model outputs on the validation set # Now we will produce model outputs to submit to the leader board. # # At the end of every experiment, the best model (according to average ET, TC, WT DICE) -# is saved to disk at: ~/.local/workspace/checkpoint/\/best_model.pkl, +# is saved to disk at: ~/.local/workspace/checkpoint/checkpoint/\/best_model.pkl, # where \ is the one printed to stdout during the start of the # experiment (look for the log entry: "Created experiment folder experiment_##..." above). 
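The checkpoint_folder returned by run_challenge_experiment above already names this directory; if that value is lost (for example, when running inference in a fresh session and only the log entry is available), one fallback is sketched below. It is a minimal illustration, not part of the challenge API: it assumes the default ~/.local/workspace/checkpoint layout and experiment_<N> folder naming, and the helper name is hypothetical.

```python
from pathlib import Path

def latest_experiment_folder(checkpoint_root=Path.home() / '.local/workspace/checkpoint'):
    # Assumed layout: one experiment_<N> directory per run under the checkpoint root.
    candidates = sorted(Path(checkpoint_root).glob('experiment_*'),
                        key=lambda p: p.stat().st_mtime)
    # Return the most recently modified experiment folder name, or None if no run exists.
    return candidates[-1].name if candidates else None
```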
- -from fets_challenge import model_outputs_to_disc -from pathlib import Path - -# infer participant home folder -home = str(Path.home()) - # you will need to specify the correct experiment folder and the parent directory for # the data you want to run inference over (assumed to be the experiment that just completed) -#checkpoint_folder='experiment_1' #data_path = -data_path = '/home/brats/MICCAI_FeTS2022_ValidationData' +data_path = '/raid/datasets/FeTS22/MICCAI_FeTS2022_ValidationData' validation_csv_filename = 'validation.csv' # you can keep these the same if you wish -final_model_path = os.path.join(home, '.local/workspace/checkpoint', checkpoint_folder, 'best_model.pkl') +if checkpoint_folder is not None: + final_model_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'best_model.pkl') +else: + exit("No checkpoint folder found. Please provide a valid checkpoint folder. Exiting the experiment without inferencing") # If the experiment is only run for a single round, use the temp model instead if not Path(final_model_path).exists(): - final_model_path = os.path.join(home, '.local/workspace/checkpoint', checkpoint_folder, 'temp_model.pkl') + final_model_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'temp_model.pkl') -outputs_path = os.path.join(home, '.local/workspace/checkpoint', checkpoint_folder, 'model_outputs') +if not Path(final_model_path).exists(): + exit("No model found. Please provide a valid checkpoint folder. Exiting the experiment without inferencing") +outputs_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'model_outputs') # Using this best model, we can now produce NIfTI files for model outputs # using a provided data directory - model_outputs_to_disc(data_path=data_path, validation_csv=validation_csv_filename, output_path=outputs_path, native_model_path=final_model_path, outputtag='', - device=device) + device=device) \ No newline at end of file diff --git a/Task_1/README.md b/Task_1/README.md index ab80043..96972d4 100644 --- a/Task_1/README.md +++ b/Task_1/README.md @@ -20,17 +20,16 @@ Please ask any additional questions in our discussion pages on our github site a 2. ```git clone https://github.com/FETS-AI/Challenge.git``` 3. ```cd Challenge/Task_1``` 4. ```git lfs pull``` -5. Create virtual environment (python 3.6-3.8): using Anaconda, a new environment can be created and activated using the following commands: +5. Create virtual environment (python 3.10-3.13): using python venv, a new environment can be created and activated using the following commands: ```sh ## create venv in specific path - conda create -p ./venv python=3.7 -y - conda activate ./venv + python -m venv venv + source venv/bin/activate ``` 6. ```pip install --upgrade pip``` -7. Install Pytorch LTS (1.8.2) for your system (use CUDA 11): - ```pip3 install torch==1.8.2 torchvision==0.9.2 torchaudio==0.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111``` -*Note all previous versions of pytorch can be found in [these instructions]([https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/previous-versions/)) -9. Set the environment variable `SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True` (to avoid sklearn deprecation error) +7. Install dependent pip libraries: + ```pip install -r requirements.txt``` +9. Set the environment variable `export SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True` (to avoid sklearn deprecation error) 10. 
```pip install .``` > * _Note: if you run into ```ERROR: Failed building wheel for SimpleITK```, try running ```pip install SimpleITK --only-binary :all:``` then rerunning ```pip install .```_ 10. ```python FeTS_Challenge.py``` diff --git a/Task_1/fets_challenge/checkpoint_utils.py b/Task_1/fets_challenge/checkpoint_utils.py index 30d5706..9ee8c00 100644 --- a/Task_1/fets_challenge/checkpoint_utils.py +++ b/Task_1/fets_challenge/checkpoint_utils.py @@ -21,7 +21,7 @@ def setup_checkpoint_folder(): Path(checkpoint_folder).mkdir(parents=True, exist_ok=False) return experiment_folder -def save_checkpoint(checkpoint_folder, aggregator, +def save_checkpoint(checkpoint_folder, agg_tensor_db, collaborator_names, collaborators, round_num, collaborator_time_stats, total_simulated_time, best_dice, @@ -34,7 +34,7 @@ def save_checkpoint(checkpoint_folder, aggregator, Save latest checkpoint """ # Save aggregator tensor_db - aggregator.tensor_db.tensor_db.to_pickle(f'checkpoint/{checkpoint_folder}/aggregator_tensor_db.pkl') + agg_tensor_db.tensor_db.to_pickle(f'checkpoint/{checkpoint_folder}/aggregator_tensor_db.pkl') with open(f'checkpoint/{checkpoint_folder}/state.pkl', 'wb') as f: pickle.dump([collaborator_names, round_num, collaborator_time_stats, total_simulated_time, best_dice, best_dice_over_time_auc, collaborators_chosen_each_round, diff --git a/Task_1/fets_challenge/config/gandlf_config.yaml b/Task_1/fets_challenge/config/gandlf_config.yaml new file mode 100644 index 0000000..4be0b2a --- /dev/null +++ b/Task_1/fets_challenge/config/gandlf_config.yaml @@ -0,0 +1,66 @@ +batch_size: 1 +clip_grad: null +clip_mode: null +data_augmentation: {} +data_postprocessing: {} +data_preprocessing: + normalize: null +enable_padding: false +in_memory: false +inference_mechanism : + grid_aggregator_overlap: crop + patch_overlap: 0 +learning_rate: 0.001 +loss_function: dc +medcam_enabled: false +output_dir: '.' 
+metrics: +- dice +- dice_per_label +- hd95_per_label +model: + amp: true + architecture: resunet + base_filters: 32 + class_list: + - 0 + - 1 + - 2 + - 4 + dimension: 3 + final_layer: softmax + ignore_label_validation: null + norm_type: instance + num_channels: 4 +nested_training: + testing: 1 + validation: -5 +num_epochs: 1 +optimizer: + type: sgd +parallel_compute_command: '' +patch_sampler: label +patch_size: +- 64 +- 64 +- 64 +patience: 100 +pin_memory_dataloader: false +print_rgb_label_warning: true +q_max_length: 100 +q_num_workers: 0 +q_samples_per_volume: 40 +q_verbose: false +save_output: false +save_training: false +scaling_factor: 1 +scheduler: + type: triangle_modified +track_memory_usage: false +verbose: False +version: + maximum: 0.1.0 + minimum: 0.0.14 +weighted_loss: true +modality: rad +problem_type: segmentation \ No newline at end of file diff --git a/Task_1/fets_challenge/custom_aggregation_wrapper.py b/Task_1/fets_challenge/custom_aggregation_wrapper.py index ae7abc4..63472a6 100644 --- a/Task_1/fets_challenge/custom_aggregation_wrapper.py +++ b/Task_1/fets_challenge/custom_aggregation_wrapper.py @@ -1,4 +1,4 @@ -from openfl.component.aggregation_functions.experimental import PrivilegedAggregationFunction +from openfl.interface.aggregation_functions.experimental import PrivilegedAggregationFunction # extends the openfl agg func interface to include challenge-relevant information diff --git a/Task_1/fets_challenge/experiment.py b/Task_1/fets_challenge/experiment.py index f561e66..8d9a7f1 100644 --- a/Task_1/fets_challenge/experiment.py +++ b/Task_1/fets_challenge/experiment.py @@ -4,221 +4,45 @@ # Patrick Foley (Intel), Micah Sheller (Intel) import os +from copy import deepcopy import warnings -from collections import namedtuple -from copy import copy -import shutil from logging import getLogger from pathlib import Path - -import numpy as np -import pandas as pd -from openfl.utilities import split_tensor_dict_for_holdouts, TensorKey -from openfl.protocols import utils -import openfl.native as fx -import torch +from torch.utils.data import DataLoader from .gandlf_csv_adapter import construct_fedsim_csv, extract_csv_partitions from .custom_aggregation_wrapper import CustomAggregationWrapper -from .checkpoint_utils import setup_checkpoint_folder, save_checkpoint, load_checkpoint - -# one week -# MINUTE = 60 -# HOUR = 60 * MINUTE -# DAY = 24 * HOUR -# WEEK = 7 * DAY -MAX_SIMULATION_TIME = 7 * 24 * 60 * 60 - -## COLLABORATOR TIMING DISTRIBUTIONS -# These data are derived from the actual timing information in the real-world FeTS information -# They reflect a subset of the institutions involved. 
-# Tuples are (mean, stddev) in seconds -# time to train one patient -TRAINING_TIMES = [(6.710741331207654, 0.8726112813698301), - (2.7343911917098445, 0.023976155580152165), - (3.173076923076923, 0.04154320960517865), - (6.580379746835443, 0.22461890673025595), - (3.452046783625731, 0.47136389322749656), - (6.090788461700995, 0.08541499003440205), - (3.206933911159263, 0.1927067498514361), - (3.3358208955223883, 0.2950567549663471), - (4.391304347826087, 0.37464538999161057), - (6.324805129494594, 0.1413885448869165), - (7.415133477633478, 1.1198881747151301), - (5.806410256410255, 0.029926699295169234), - (6.300204918032787, 0.24932319729777577), - (5.886317567567567, 0.018627858809133223), - (5.478184991273998, 0.04902740607167421), - (6.32440159574468, 0.15838847558954935), - (20.661918328585003, 6.085405543890793), - (3.197901325478645, 0.07049966132127056), - (6.523963730569948, 0.2533266757118492), - (2.6540077569489338, 0.025503099659276184), - (1.8025746183640918, 0.06805805332403576)] +from .fets_flow import FeTSFederatedFlow +from .fets_challenge_model import FeTSChallengeModel +from .fets_data_loader import FeTSDataLoader -# time to validate one patient -VALIDATION_TIMES = [(23.129135113591072, 2.5975116854269507), - (12.965544041450777, 0.3476297824941513), - (14.782051282051283, 0.5262660449172765), - (16.444936708860762, 0.42613177203005187), - (15.728654970760235, 4.327559980390658), - (12.946098012884802, 0.2449927822869217), - (15.335950126991456, 1.1587597276712558), - (24.024875621890544, 3.087348297794285), - (38.361702127659576, 2.240113332190875), - (16.320970580839827, 0.4995108101783225), - (30.805555555555554, 3.1836337269688237), - (12.100899742930592, 0.41122386959584895), - (13.099897540983607, 0.6693132795197584), - (9.690202702702702, 0.17513593019922968), - (10.06980802792321, 0.7947848617875114), - (14.605333333333334, 0.6012305898922827), - (36.30294396961064, 9.24123672148819), - (16.9130060292851, 0.7452868131028928), - (40.244078460399706, 3.7700993678269037), - (13.161603102779575, 0.1975347910041472), - (11.222161868549701, 0.7021223062972527)] +from openfl.experimental.workflow.interface import Aggregator, Collaborator +from openfl.experimental.workflow.runtime import LocalRuntime -# time to download the model -DOWNLOAD_TIMES = [(112.42869743589742, 14.456734719659513), - (117.26870618556701, 12.549951446132013), - (13.059666666666667, 4.8700489616521185), - (47.50220338983051, 14.92128656898884), - (162.27864210526315, 32.562113378948396), - (99.46072058823529, 13.808785580783224), - (33.6347090909091, 25.00299299660141), - (216.25489393939392, 19.176465340447848), - (217.4117230769231, 20.757673955585453), - (98.38857297297298, 13.205048376808929), - (88.87509473684209, 23.152936862511545), - (66.96994262295081, 16.682497150763503), - (36.668852040816326, 13.759109844677598), - (149.31716326530614, 26.018185409516104), - (139.847, 80.04755583050091), - (54.97624444444445, 16.645170929316794)] - -# time to upload the model -UPLOAD_TIMES = [(192.28497409326425, 21.537450985376967), - (194.60103626943004, 24.194406902237056), - (20.0, 0.0), - (52.43859649122807, 5.047207127169352), - (182.82417582417582, 14.793519078918195), - (143.38059701492537, 7.910690646792151), - (30.695652173913043, 9.668122350904568), - (430.95360824742266, 54.97790476867727), - (348.3174603174603, 30.14347985347738), - (141.43715846994536, 5.271340868190727), - (158.7433155080214, 64.87526819391198), - (81.06086956521739, 7.003461202082419), - (32.60621761658031, 
5.0418315093016615), - (281.5388601036269, 90.60338778706557), - (194.34065934065933, 36.6519776778435), - (66.53787878787878, 16.456280602190606)] +from GANDLF.config_manager import ConfigManager logger = getLogger(__name__) # This catches PyTorch UserWarnings for CPU warnings.filterwarnings("ignore", category=UserWarning) -CollaboratorTimeStats = namedtuple('CollaboratorTimeStats', - [ - 'validation_mean', - 'training_mean', - 'download_speed_mean', - 'upload_speed_mean', - 'validation_std', - 'training_std', - 'download_speed_std', - 'upload_speed_std', - ] - ) - -def gen_collaborator_time_stats(collaborator_names, seed=0xFEEDFACE): - - np.random.seed(seed) - - stats = {} - for col in collaborator_names: - ml_index = np.random.randint(len(VALIDATION_TIMES)) - validation = VALIDATION_TIMES[ml_index] - training = TRAINING_TIMES[ml_index] - net_index = np.random.randint(len(DOWNLOAD_TIMES)) - download = DOWNLOAD_TIMES[net_index] - upload = UPLOAD_TIMES[net_index] - - stats[col] = CollaboratorTimeStats(validation_mean=validation[0], - training_mean=training[0], - download_speed_mean=download[0], - upload_speed_mean=upload[0], - validation_std=validation[1], - training_std=training[1], - download_speed_std=download[1], - upload_speed_std=upload[1]) - return stats - -def compute_times_per_collaborator(collaborator_names, - training_collaborators, - epochs_per_round, - collaborator_data, - collaborator_time_stats, - round_num): - np.random.seed(round_num) - times = {} - for col in collaborator_names: - time = 0 - - # stats - stats = collaborator_time_stats[col] - - # download time - download_time = np.random.normal(loc=stats.download_speed_mean, - scale=stats.download_speed_std) - download_time = max(1, download_time) - time += download_time - - # data loader - data = collaborator_data[col] - - # validation time - data_size = data.get_valid_data_size() - validation_time_per = np.random.normal(loc=stats.validation_mean, - scale=stats.validation_std) - validation_time_per = max(1, validation_time_per) - time += data_size * validation_time_per - - # only if training - if col in training_collaborators: - # training time - data_size = data.get_train_data_size() - training_time_per = np.random.normal(loc=stats.training_mean, - scale=stats.training_std) - training_time_per = max(1, training_time_per) - - # training data size depends on the hparams - data_size *= epochs_per_round - time += data_size * training_time_per - - # if training, we also validate the locally updated model - data_size = data.get_valid_data_size() - validation_time_per = np.random.normal(loc=stats.validation_mean, - scale=stats.validation_std) - validation_time_per = max(1, validation_time_per) - time += data_size * validation_time_per +def aggregator_private_attributes(aggregation_type, collaborator_names, db_store_rounds): + return { + "aggregation_type" : aggregation_type, + "collaborator_names": collaborator_names, + "checkpoint_folder":None, + "db_store_rounds":db_store_rounds, + "agg_tensor_dict":{} + } - # upload time - upload_time = np.random.normal(loc=stats.upload_speed_mean, - scale=stats.upload_speed_std) - upload_time = max(1, upload_time) - time += upload_time - - times[col] = time - return times +def collaborator_private_attributes(index, train_csv_path, val_csv_path): + return { + "index": index, + "train_csv_path": train_csv_path, + "val_csv_path": val_csv_path + } -def get_metric(metric, fl_round, tensor_db): - metric_name = metric - target_tags = ('metric', 'validate_agg') - return 
float(tensor_db.tensor_db.query("tensor_name == @metric_name and round == @fl_round and tags == @target_tags").nparray) def run_challenge_experiment(aggregation_function, choose_training_collaborators, @@ -231,333 +55,94 @@ def run_challenge_experiment(aggregation_function, save_checkpoints=True, restore_from_checkpoint_folder=None, include_validation_with_hausdorff=True, - use_pretrained_model=True): - - fx.init('fets_challenge_workspace') - - from sys import path, exit + use_pretrained_model=False, + backend_process='single_process'): file = Path(__file__).resolve() root = file.parent.resolve() # interface root, containing command modules work = Path.cwd().resolve() - - path.append(str(root)) - path.insert(0, str(work)) + gandlf_config_path = os.path.join(root, 'config', 'gandlf_config.yaml') # create gandlf_csv and get collaborator names gandlf_csv_path = os.path.join(work, 'gandlf_paths.csv') - # split_csv_path = os.path.join(work, institution_split_csv_filename) + split_csv_path = os.path.join(work, institution_split_csv_filename) collaborator_names = construct_fedsim_csv(brats_training_data_parent_dir, - institution_split_csv_filename, + split_csv_path, 0.8, gandlf_csv_path) + + logger.info(f'Collaborator names for experiment : {collaborator_names}') aggregation_wrapper = CustomAggregationWrapper(aggregation_function) - overrides = { - 'aggregator.settings.rounds_to_train': rounds_to_train, - 'aggregator.settings.db_store_rounds': db_store_rounds, - 'tasks.train.aggregation_type': aggregation_wrapper, - 'task_runner.settings.device': device, - } - - - # Update the plan if necessary - plan = fx.update_plan(overrides) - - if not include_validation_with_hausdorff: - plan.config['task_runner']['settings']['fets_config_dict']['metrics'] = ['dice','dice_per_label'] - - # Overwrite collaborator names - plan.authorized_cols = collaborator_names - # overwrite datapath values with the collaborator name itself - for col in collaborator_names: - plan.cols_data_paths[col] = col - - # get the data loaders for each collaborator - collaborator_data_loaders = {col: copy(plan).get_data_loader(col) for col in collaborator_names} - transformed_csv_dict = extract_csv_partitions(os.path.join(work, 'gandlf_paths.csv')) - # get the task runner, passing the first data loader - for col in collaborator_data_loaders: - #Insert logic to serialize train / val CSVs here - transformed_csv_dict[col]['train'].to_csv(os.path.join(work, 'seg_test_train.csv')) - transformed_csv_dict[col]['val'].to_csv(os.path.join(work, 'seg_test_val.csv')) - task_runner = copy(plan).get_task_runner(collaborator_data_loaders[col]) - - if use_pretrained_model: - print('Loading pretrained model...') - if device == 'cpu': - checkpoint = torch.load(f'{root}/pretrained_model/resunet_pretrained.pth',map_location=torch.device('cpu')) - task_runner.model.load_state_dict(checkpoint['model_state_dict']) - task_runner.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - else: - checkpoint = torch.load(f'{root}/pretrained_model/resunet_pretrained.pth') - task_runner.model.load_state_dict(checkpoint['model_state_dict']) - task_runner.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - tensor_pipe = plan.get_tensor_pipe() - - # Initialize model weights - init_state_path = plan.config['aggregator']['settings']['init_state_path'] - tensor_dict, _ = split_tensor_dict_for_holdouts(logger, task_runner.get_tensor_dict(False)) - - model_snap = utils.construct_model_proto(tensor_dict=tensor_dict, - round_number=0, - 
tensor_pipe=tensor_pipe) - - utils.dump_proto(model_proto=model_snap, fpath=init_state_path) - - # get the aggregator, now that we have the initial weights file set up - logger.info('Creating aggregator...') - aggregator = plan.get_aggregator() - # manually override the aggregator UUID (for checkpoint resume when rounds change) - aggregator.uuid = 'aggregator' - aggregator._load_initial_tensors() - - # create our collaborators - logger.info('Creating collaborators...') - collaborators = {col: copy(plan).get_collaborator(col, task_runner=task_runner, client=aggregator) for col in collaborator_names} - - collaborator_time_stats = gen_collaborator_time_stats(plan.authorized_cols) - collaborators_chosen_each_round = {} - collaborator_times_per_round = {} - - logger.info('Starting experiment') - - total_simulated_time = 0 - best_dice = -1.0 - best_dice_over_time_auc = 0 - - # results dataframe data - experiment_results = { - 'round':[], - 'time': [], - 'convergence_score': [], - 'round_dice': [], - 'dice_label_0': [], - 'dice_label_1': [], - 'dice_label_2': [], - 'dice_label_4': [], - } - if include_validation_with_hausdorff: - experiment_results.update({ - 'hausdorff95_label_0': [], - 'hausdorff95_label_1': [], - 'hausdorff95_label_2': [], - 'hausdorff95_label_4': [], - }) - - - if restore_from_checkpoint_folder is None: - checkpoint_folder = setup_checkpoint_folder() - logger.info(f'\nCreated experiment folder {checkpoint_folder}...') - starting_round_num = 0 + gandlf_conf = {} + if isinstance(gandlf_config_path, str) and os.path.exists(gandlf_config_path): + gandlf_conf = ConfigManager(gandlf_config_path) + elif isinstance(gandlf_config_path, dict): + gandlf_conf = gandlf_config_path else: - if not Path(f'checkpoint/{restore_from_checkpoint_folder}').exists(): - logger.warning(f'Could not find provided checkpoint folder: {restore_from_checkpoint_folder}. 
Exiting...') - exit(1) - else: - logger.info(f'Attempting to load last completed round from {restore_from_checkpoint_folder}') - state = load_checkpoint(restore_from_checkpoint_folder) - checkpoint_folder = restore_from_checkpoint_folder - - [loaded_collaborator_names, starting_round_num, collaborator_time_stats, - total_simulated_time, best_dice, best_dice_over_time_auc, - collaborators_chosen_each_round, collaborator_times_per_round, - experiment_results, summary, agg_tensor_db] = state - - if loaded_collaborator_names != collaborator_names: - logger.error(f'Collaborator names found in checkpoint ({loaded_collaborator_names}) ' - f'do not match provided collaborators ({collaborator_names})') - exit(1) - - logger.info(f'Previous summary for round {starting_round_num}') - logger.info(summary) - - starting_round_num += 1 - aggregator.tensor_db.tensor_db = agg_tensor_db - aggregator.round_number = starting_round_num - - - for round_num in range(starting_round_num, rounds_to_train): - # pick collaborators to train for the round - training_collaborators = choose_training_collaborators(collaborator_names, - aggregator.tensor_db._iterate(), - round_num, - collaborators_chosen_each_round, - collaborator_times_per_round) - - logger.info('Collaborators chosen to train for round {}:\n\t{}'.format(round_num, training_collaborators)) - - # save the collaborators chosen this round - collaborators_chosen_each_round[round_num] = training_collaborators - - # get the hyper-parameters from the competitor - hparams = training_hyper_parameters_for_round(collaborator_names, - aggregator.tensor_db._iterate(), - round_num, - collaborators_chosen_each_round, - collaborator_times_per_round) - - learning_rate, epochs_per_round = hparams - - if (epochs_per_round is None): - logger.warning('Hyper-parameter function warning: function returned None for "epochs_per_round". Setting "epochs_per_round" to 1') - epochs_per_round = 1 - - hparam_message = "\n\tlearning rate: {}".format(learning_rate) - - hparam_message += "\n\tepochs_per_round: {}".format(epochs_per_round) - - logger.info("Hyper-parameters for round {}:{}".format(round_num, hparam_message)) - - # cache each tensor in the aggregator tensor_db - hparam_dict = {} - tk = TensorKey(tensor_name='learning_rate', - origin=aggregator.uuid, - round_number=round_num, - report=False, - tags=('hparam', 'model')) - hparam_dict[tk] = np.array(learning_rate) - tk = TensorKey(tensor_name='epochs_per_round', - origin=aggregator.uuid, - round_number=round_num, - report=False, - tags=('hparam', 'model')) - hparam_dict[tk] = np.array(epochs_per_round) - aggregator.tensor_db.cache_tensor(hparam_dict) - - # pre-compute the times for each collaborator - times_per_collaborator = compute_times_per_collaborator(collaborator_names, - training_collaborators, - epochs_per_round, - collaborator_data_loaders, - collaborator_time_stats, - round_num) - collaborator_times_per_round[round_num] = times_per_collaborator - - aggregator.assigner.set_training_collaborators(training_collaborators) - - # update the state in the aggregation wrapper - aggregation_wrapper.set_state_data_for_round(collaborators_chosen_each_round, collaborator_times_per_round) - - # turn the times list into a list of tuples and sort it - times_list = [(t, col) for col, t in times_per_collaborator.items()] - times_list = sorted(times_list) - - # now call each collaborator in order of time - # FIXME: this doesn't break up each task. 
We need this if we're doing straggler handling - for t, col in times_list: - # set the task_runner data loader - task_runner.data_loader = collaborator_data_loaders[col] - - # run the collaborator - collaborators[col].run_simulation() - - logger.info("Collaborator {} took simulated time: {} minutes".format(col, round(t / 60, 2))) - - # the round time is the max of the times_list - round_time = max([t for t, _ in times_list]) - total_simulated_time += round_time - - - # get the performace validation scores for the round - round_dice = get_metric('valid_dice', round_num, aggregator.tensor_db) - dice_label_0 = get_metric('valid_dice_per_label_0', round_num, aggregator.tensor_db) - dice_label_1 = get_metric('valid_dice_per_label_1', round_num, aggregator.tensor_db) - dice_label_2 = get_metric('valid_dice_per_label_2', round_num, aggregator.tensor_db) - dice_label_4 = get_metric('valid_dice_per_label_4', round_num, aggregator.tensor_db) - if include_validation_with_hausdorff: - hausdorff95_label_0 = get_metric('valid_hd95_per_label_0', round_num, aggregator.tensor_db) - hausdorff95_label_1 = get_metric('valid_hd95_per_label_1', round_num, aggregator.tensor_db) - hausdorff95_label_2 = get_metric('valid_hd95_per_label_2', round_num, aggregator.tensor_db) - hausdorff95_label_4 = get_metric('valid_hd95_per_label_4', round_num, aggregator.tensor_db) - - # update best score - if best_dice < round_dice: - best_dice = round_dice - # Set the weights for the final model - if round_num == 0: - # here the initial model was validated (temp model does not exist) - logger.info(f'Skipping best model saving to disk as it is a random initialization.') - elif not os.path.exists(f'checkpoint/{checkpoint_folder}/temp_model.pkl'): - raise ValueError(f'Expected temporary model at: checkpoint/{checkpoint_folder}/temp_model.pkl to exist but it was not found.') - else: - # here the temp model was the one validated - shutil.copyfile(src=f'checkpoint/{checkpoint_folder}/temp_model.pkl',dst=f'checkpoint/{checkpoint_folder}/best_model.pkl') - logger.info(f'Saved model with best average binary DICE: {best_dice} to ~/.local/workspace/checkpoint/{checkpoint_folder}/best_model.pkl') - - ## RUN VALIDATION ON INTERMEDIATE CONSENSUS MODEL - # set the task_runner data loader - # task_runner.data_loader = collaborator_data_loaders[col] - - ## CONVERGENCE METRIC COMPUTATION - # update the auc score - best_dice_over_time_auc += best_dice * round_time - - # project the auc score as remaining time * best dice - # this projection assumes that the current best score is carried forward for the entire week - projected_auc = (MAX_SIMULATION_TIME - total_simulated_time) * best_dice + best_dice_over_time_auc - projected_auc /= MAX_SIMULATION_TIME - - # End of round summary - summary = '"**** END OF ROUND {} SUMMARY *****"'.format(round_num) - summary += "\n\tSimulation Time: {} minutes".format(round(total_simulated_time / 60, 2)) - summary += "\n\t(Projected) Convergence Score: {}".format(projected_auc) - summary += "\n\tDICE Label 0: {}".format(dice_label_0) - summary += "\n\tDICE Label 1: {}".format(dice_label_1) - summary += "\n\tDICE Label 2: {}".format(dice_label_2) - summary += "\n\tDICE Label 4: {}".format(dice_label_4) - if include_validation_with_hausdorff: - summary += "\n\tHausdorff95 Label 0: {}".format(hausdorff95_label_0) - summary += "\n\tHausdorff95 Label 1: {}".format(hausdorff95_label_1) - summary += "\n\tHausdorff95 Label 2: {}".format(hausdorff95_label_2) - summary += "\n\tHausdorff95 Label 4: 
{}".format(hausdorff95_label_4) - - - experiment_results['round'].append(round_num) - experiment_results['time'].append(total_simulated_time) - experiment_results['convergence_score'].append(projected_auc) - experiment_results['round_dice'].append(round_dice) - experiment_results['dice_label_0'].append(dice_label_0) - experiment_results['dice_label_1'].append(dice_label_1) - experiment_results['dice_label_2'].append(dice_label_2) - experiment_results['dice_label_4'].append(dice_label_4) - if include_validation_with_hausdorff: - experiment_results['hausdorff95_label_0'].append(hausdorff95_label_0) - experiment_results['hausdorff95_label_1'].append(hausdorff95_label_1) - experiment_results['hausdorff95_label_2'].append(hausdorff95_label_2) - experiment_results['hausdorff95_label_4'].append(hausdorff95_label_4) - logger.info(summary) - - if save_checkpoints: - logger.info(f'Saving checkpoint for round {round_num}') - logger.info(f'To resume from this checkpoint, set the restore_from_checkpoint_folder parameter to \'{checkpoint_folder}\'') - save_checkpoint(checkpoint_folder, aggregator, - collaborator_names, collaborators, - round_num, collaborator_time_stats, - total_simulated_time, best_dice, - best_dice_over_time_auc, - collaborators_chosen_each_round, - collaborator_times_per_round, - experiment_results, - summary) - - # if the total_simulated_time has exceeded the maximum time, we break - # in practice, this means that the previous round's model is the last model scored, - # so a long final round should not actually benefit the competitor, since that final - # model is never globally validated - if total_simulated_time > MAX_SIMULATION_TIME: - logger.info("Simulation time exceeded. Ending Experiment") - break - - # save the most recent aggregated model in native format to be copied over as best when appropriate - # (note this model has not been validated by the collaborators yet) - task_runner.rebuild_model(round_num, aggregator.last_tensor_dict, validation=True) - task_runner.save_native(f'checkpoint/{checkpoint_folder}/temp_model.pkl') - - - - return pd.DataFrame.from_dict(experiment_results), checkpoint_folder + exit("GANDLF config file not found. 
Exiting...") + + collaborators = [] + for idx, col in enumerate(collaborator_names): + col_dir = os.path.join(work, 'data', str(col)) + os.makedirs(col_dir, exist_ok=True) + + train_csv_path = os.path.join(col_dir, 'train.csv') + val_csv_path = os.path.join(col_dir, 'valid.csv') + + transformed_csv_dict[col]['train'].to_csv(train_csv_path) + transformed_csv_dict[col]['val'].to_csv(val_csv_path) + + collaborators.append( + Collaborator( + name=col, + private_attributes_callable=collaborator_private_attributes, + # If 1 GPU is available in the machine + # Set `num_gpus=0.0` to `num_gpus=0.3` to run on GPU + # with ray backend with 2 collaborators + num_cpus=4.0, + num_gpus=0.0, + # private arguments required to pass to callable + index=idx, + train_csv_path=train_csv_path, + val_csv_path=val_csv_path + ) + ) + + aggregator = Aggregator(name="aggregator", + private_attributes_callable=aggregator_private_attributes, + num_cpus=4.0, + num_gpus=0.0, + # private arguments required to pass to callable + collaborator_names=collaborator_names, + aggregation_type=aggregation_wrapper, + db_store_rounds=db_store_rounds) + + local_runtime = LocalRuntime( + aggregator=aggregator, collaborators=collaborators, backend=backend_process, num_actors=1 + ) + + logger.info(f"Local runtime collaborators = {local_runtime.collaborators}") + + params_dict = {"include_validation_with_hausdorff": include_validation_with_hausdorff, + "use_pretrained_model": use_pretrained_model, + "gandlf_config": gandlf_conf, + "choose_training_collaborators": choose_training_collaborators, + "training_hyper_parameters_for_round": training_hyper_parameters_for_round, + "restore_from_checkpoint_folder": restore_from_checkpoint_folder, + "save_checkpoints": save_checkpoints} + + model = FeTSChallengeModel() + flflow = FeTSFederatedFlow( + model, + params_dict, + rounds_to_train, + device + ) + + flflow.runtime = local_runtime + flflow.run() + return aggregator.private_attributes["checkpoint_folder"] \ No newline at end of file diff --git a/Task_1/fets_challenge/fets_challenge_model.py b/Task_1/fets_challenge/fets_challenge_model.py new file mode 100644 index 0000000..5767e10 --- /dev/null +++ b/Task_1/fets_challenge/fets_challenge_model.py @@ -0,0 +1,678 @@ +# Copyright 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +"""FeTS Challenge Model class for Federated Learning.""" + +import os +from copy import deepcopy +from typing import Union + +import numpy as np +import torch as pt +import yaml +from GANDLF.compute.forward_pass import validate_network +from GANDLF.compute.generic import create_pytorch_objects +from GANDLF.compute.training_loop import train_network +from GANDLF.config_manager import ConfigManager + +from openfl.federated.task.runner import TaskRunner +from openfl.utilities import TensorKey +from openfl.utilities.split import split_tensor_dict_for_holdouts +from logging import getLogger + +class FeTSChallengeModel(): + """FeTS Challenge Model class for Federated Learning. + + This class provides methods to manage and manipulate GaNDLF models in a + federated learning context. + + Attributes: + model (Model): The built model. + optimizer (Optimizer): Optimizer for the model. + scheduler (Scheduler): Scheduler for the model. + params (Parameters): Parameters for the model. + device (str): Device for the model. + training_round_completed (bool): Whether the training round has been + completed. + tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor + dict split function. 
+ """ + + def __init__(self): + """Initializes the GaNDLFTaskRunner object. + + Sets up the initial state of the GaNDLFTaskRunner object, initializing + various components needed for the federated model. + """ + + self.model = None + self.optimizer = None + self.scheduler = None + self.params = None + self.device = None + + self.opt_treatment = "RESET" + + self.training_round_completed = False + self.logger = getLogger(__name__) + + # overwrite attribute to account for one optimizer param (in every + # child model that does not overwrite get and set tensordict) that is + # not a numpy array + self.tensor_dict_split_fn_kwargs = {} + self.tensor_dict_split_fn_kwargs.update({"holdout_tensor_names": ["__opt_state_needed"]}) + + def rebuild_model(self, input_tensor_dict, validation=False): + """Parse tensor names and update weights of model. Handles the + optimizer treatment. + + Args: + input_tensor_dict (dict): The input tensor dictionary used to + update the weights of the model. + validation (bool, optional): A flag indicating whether the model + is in validation. Defaults to False. + + Returns: + None + """ + if self.opt_treatment == "RESET": + self.reset_opt_vars() + self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) + elif ( + self.training_round_completed + and self.opt_treatment == "CONTINUE_GLOBAL" + and not validation + ): + self.set_tensor_dict(input_tensor_dict, with_opt_vars=True) + else: + self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) + + def validate(self, col_name, round_num, input_tensor_dict, val_loader, use_tqdm=False, **kwargs): + """Validate. + Run validation of the model on the local data. + Args: + col_name (str): Name of the collaborator. + round_num (int): Current round number. + input_tensor_dict (dict): Required input tensors (for model). + val_loader (DataLoader): Validation data loader. + use_tqdm (bool, optional): Use tqdm to print a progress bar. + Defaults to False. + **kwargs: Key word arguments passed to GaNDLF main_run. + + Returns: + output_tensor_dict (dict): Tensors to send back to the aggregator. + {} (dict): Tensors to maintain in the local TensorDB. + """ + self.rebuild_model(input_tensor_dict, validation=True) + self.model.eval() + + epoch_valid_loss, epoch_valid_metric = validate_network( + self.model, + val_loader, + self.scheduler, + self.params, + round_num, + mode="validation", + ) + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + + output_tensor_dict = {} + output_tensor_dict[TensorKey('valid_loss', origin, round_num, True, tags)] = np.array(epoch_valid_loss) + for k, v in epoch_valid_metric.items(): + if isinstance(v, str): + v = list(map(float, v.split('_'))) + + if np.array(v).size == 1: + output_tensor_dict[TensorKey(f'valid_{k}', origin, round_num, True, tags)] = np.array(v) + else: + for idx,label in enumerate([0,1,2,4]): + output_tensor_dict[TensorKey(f'valid_{k}_{label}', origin, round_num, True, tags)] = np.array(v[idx]) + + # Empty list represents metrics that should only be stored locally + return output_tensor_dict, {} + + def inference(self, col_name, round_num, val_loader, use_tqdm=False, **kwargs): + """Inference. + Run validation of the model on the local data. + Args: + col_name (str): Name of the collaborator. + round_num (int): Current round number. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool, optional): Use tqdm to print a progress bar. + Defaults to False. 
+ **kwargs: Key word arguments passed to GaNDLF main_run. + + Returns: + output_tensor_dict (dict): Tensors to send back to the aggregator. + {} (dict): Tensors to maintain in the local TensorDB. + """ + self.model.eval() + + epoch_inference_loss, epoch_inference_metric = validate_network( + self.model, + val_loader, + self.scheduler, + self.params, + round_num, + mode="inference", + ) + + origin = col_name + suffix = 'inference' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + + output_tensor_dict = {} + output_tensor_dict[TensorKey('inference_loss', origin, round_num, True, tags)] = np.array(epoch_inference_loss) + for k, v in epoch_inference_metric.items(): + if isinstance(v, str): + v = list(map(float, v.split('_'))) + + if np.array(v).size == 1: + output_tensor_dict[TensorKey(f'inference_{k}', origin, round_num, True, tags)] = np.array(v) + else: + for idx,label in enumerate([0,1,2,4]): + output_tensor_dict[TensorKey(f'inference_{k}_{label}', origin, round_num, True, tags)] = np.array(v[idx]) + + # Empty list represents metrics that should only be stored locally + return output_tensor_dict, {} + + def train(self, col_name, round_num, input_tensor_dict, hparams_dict, train_loader, use_tqdm=False, **kwargs): + """Train batches. + Train the model on the requested number of batches. + Args: + col_name (str): Name of the collaborator. + round_num (int): Current round number. + input_tensor_dict (dict): Required input tensors (for model). + use_tqdm (bool, optional): Use tqdm to print a progress bar. + Defaults to False. + epochs (int, optional): The number of epochs to train. Defaults to 1. + **kwargs: Key word arguments passed to GaNDLF main_run. + + Returns: + global_tensor_dict (dict): Tensors to send back to the aggregator. + local_tensor_dict (dict): Tensors to maintain in the local + TensorDB. + """ + # handle the hparams + epochs_per_round = int(hparams_dict.pop('epochs_per_round')) + learning_rate = float(hparams_dict.pop('learning_rate')) + + # set to "training" mode + self.model.train() + self.rebuild_model(input_tensor_dict) + + # Set the learning rate + self.logger.info(f"Setting learning rate to {learning_rate}") + for group in self.optimizer.param_groups: + group['lr'] = learning_rate + + for epoch in range(epochs_per_round): + print(f"Run {epoch} of {round_num}") + # FIXME: do we want to capture these in an array + # rather than simply taking the last value? + epoch_train_loss, epoch_train_metric = train_network( + self.model, + train_loader, + self.optimizer, + self.params, + ) + + # output model tensors (Doesn't include TensorKey) + tensor_dict = self.get_tensor_dict(self.model, with_opt_vars=True) + + metric_dict = {'loss': epoch_train_loss} + for k, v in epoch_train_metric.items(): + if isinstance(v, str): + v = list(map(float, v.split('_'))) + if np.array(v).size == 1: + metric_dict[f'train_{k}'] = np.array(v) + else: + for idx,label in enumerate([0,1,2,4]): + metric_dict[f'train_{k}_{label}'] = np.array(v[idx]) + + # Return global_tensor_dict, local_tensor_dict + # is this even pt-specific really? 
+ global_tensor_dict, local_tensor_dict = create_tensorkey_dicts( + tensor_dict, + metric_dict, + col_name, + round_num, + self.logger, + self.tensor_dict_split_fn_kwargs, + ) + + # This will signal that the optimizer values are now present, + # and can be loaded when the model is rebuilt + self.training_round_completed = True + + # Return global_tensor_dict, local_tensor_dict + return global_tensor_dict, local_tensor_dict + + def get_tensor_dict(self, model=None, with_opt_vars=False): + """Return the tensor dictionary. + + Args: + with_opt_vars (bool): Return the tensor dictionary including the + optimizer tensors (Default=False). + + Returns: + state (dict): Tensor dictionary {**dict, **optimizer_dict} + """ + # Gets information regarding tensor model layers and optimizer state. + # FIXME: self.parameters() instead? Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names which is good + # FIXME: do both and sanity check each time? + + if model is None: + model = self.model + + state = to_cpu_numpy(model.state_dict()) + + if with_opt_vars: + opt_state = _get_optimizer_state(self.optimizer) + state = {**state, **opt_state} + + return state + + def _get_weights_names(self, with_opt_vars=False): + """Get the names of the weights. + + Args: + with_opt_vars (bool, optional): Include the optimizer variables. + Defaults to False. + + Returns: + list: List of weight names. + """ + # Gets information regarding tensor model layers and optimizer state. + # FIXME: self.parameters() instead? Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names which is good + # FIXME: do both and sanity check each time? + + state = self.model.state_dict().keys() + + if with_opt_vars: + opt_state = _get_optimizer_state(self.model.optimizer) + state += opt_state.keys() + + return state + + def set_tensor_dict(self, tensor_dict, with_opt_vars=False): + """Set the tensor dictionary. + + Args: + tensor_dict (dict): The tensor dictionary. + with_opt_vars (bool, optional): Include the optimizer tensors. + Defaults to False. + """ + set_pt_model_from_tensor_dict(self.model, tensor_dict, self.device, with_opt_vars) + + def get_optimizer(self): + """Get the optimizer of this instance. + + Returns: + Optimizer: The optimizer of this instance. + """ + return self.optimizer + + def load_native( + self, + filepath, + model_state_dict_key="model_state_dict", + optimizer_state_dict_key="optimizer_state_dict", + **kwargs, + ): + """ + Load model and optimizer states from a pickled file specified by \ + filepath. model_/optimizer_state_dict args can be specified if needed. \ + Uses pt.load(). + + Args: + filepath (str): Path to pickle file created by pt.save(). + model_state_dict_key (str, optional): Key for model state dict in + pickled file. Defaults to 'model_state_dict'. + optimizer_state_dict_key (str, optional): Key for optimizer state + dict in picked file. Defaults to 'optimizer_state_dict'. + **kwargs: Additional keyword arguments. + """ + pickle_dict = pt.load(filepath) + self.model.load_state_dict(pickle_dict[model_state_dict_key]) + self.optimizer.load_state_dict(pickle_dict[optimizer_state_dict_key]) + + def save_native( + self, + filepath, + model_state_dict_key="model_state_dict", + optimizer_state_dict_key="optimizer_state_dict", + **kwargs, + ): + """ + Save model and optimizer states in a picked file specified by the \ + filepath. model_/optimizer_state_dicts are stored in the keys provided. \ + Uses pt.save(). 
+ + Args: + filepath (str): Path to pickle file to be created by pt.save(). + model_state_dict_key (str, optional): Key for model state dict in + pickled file. Defaults to 'model_state_dict'. + optimizer_state_dict_key (str, optional): Key for optimizer state + dict in picked file. Defaults to 'optimizer_state_dict'. + **kwargs: Additional keyword arguments. + """ + + pickle_dict = { + model_state_dict_key: self.model.state_dict(), + optimizer_state_dict_key: self.optimizer.state_dict(), + } + pt.save(pickle_dict, filepath) + + def reset_opt_vars(self): + """Reset optimizer variables.""" + pass + + +def create_tensorkey_dicts( + tensor_dict, + metric_dict, + col_name, + round_num, + logger, + tensor_dict_split_fn_kwargs, +): + """Create dictionaries of TensorKeys for global and local tensors. + + Args: + tensor_dict (dict): Dictionary of tensors. + metric_dict (dict): Dictionary of metrics. + col_name (str): Name of the collaborator. + round_num (int): Current round number. + logger (Logger): Logger instance. + tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor + dict split function. + + Returns: + global_tensor_dict (dict): Dictionary of global TensorKeys. + local_tensor_dict (dict): Dictionary of local TensorKeys. + """ + origin = col_name + tags = ("trained",) + output_metric_dict = {} + for k, v in metric_dict.items(): + tk = TensorKey(k, origin, round_num, True, ("metric",)) + output_metric_dict[tk] = np.array(v) + + global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( + logger, tensor_dict, **tensor_dict_split_fn_kwargs + ) + + # Create global tensorkeys + global_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in global_model_dict.items() + } + # Create tensorkeys that should stay local + local_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num, False, tags): nparray + for tensor_name, nparray in local_model_dict.items() + } + # The train/validate aggregated function of the next round will look + # for the updated model parameters. + # This ensures they will be resolved locally + next_local_tensorkey_model_dict = { + TensorKey(tensor_name, origin, round_num + 1, False, ("model",)): nparray + for tensor_name, nparray in local_model_dict.items() + } + + global_tensor_dict = {**output_metric_dict, **global_tensorkey_model_dict} + local_tensor_dict = { + **local_tensorkey_model_dict, + **next_local_tensorkey_model_dict, + } + + return global_tensor_dict, local_tensor_dict + + +def set_pt_model_from_tensor_dict(model, tensor_dict, device, with_opt_vars=False): + """Set the tensor dictionary for the PyTorch model. + + Args: + model (Model): The PyTorch model. + tensor_dict (dict): Tensor dictionary. + device (str): Device for the model. + with_opt_vars (bool, optional): Include the optimizer tensors. + Defaults to False. + """ + # Sets tensors for model layers and optimizer state. + # FIXME: model.parameters() instead? Unclear if load_state_dict() or + # simple assignment is better + # for now, state dict gives us names, which is good + # FIXME: do both and sanity check each time? 
+ + new_state = {} + # Grabbing keys from model's state_dict helps to confirm we have + # everything + for k in model.state_dict(): + #print(f" Fetching state for key = {k} Value : {tensor_dict[k]}") + new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device) + + # set model state + model.load_state_dict(new_state) + + if with_opt_vars: + # see if there is state to restore first + if tensor_dict.pop("__opt_state_needed") == "true": + _set_optimizer_state(model.get_optimizer(), device, tensor_dict) + + # sanity check that we did not record any state that was not used + assert len(tensor_dict) == 0 + + +def _derive_opt_state_dict(opt_state_dict): + """Separate optimizer tensors from the tensor dictionary. + + Flattens the optimizer state dict so as to have key, value pairs with + values as numpy arrays. + The keys have sufficient info to restore opt_state_dict using + expand_derived_opt_state_dict. + + Args: + opt_state_dict (dict): Optimizer state dictionary. + + Returns: + derived_opt_state_dict (dict): Optimizer state dictionary. + """ + derived_opt_state_dict = {} + + # Determine if state is needed for this optimizer. + if len(opt_state_dict["state"]) == 0: + derived_opt_state_dict["__opt_state_needed"] = "false" + return derived_opt_state_dict + + derived_opt_state_dict["__opt_state_needed"] = "true" + + # Using one example state key, we collect keys for the corresponding + # dictionary value. + example_state_key = opt_state_dict["param_groups"][0]["params"][0] + example_state_subkeys = set(opt_state_dict["state"][example_state_key].keys()) + + # We assume that the state collected for all params in all param groups is + # the same. + # We also assume that whether or not the associated values to these state + # subkeys is a tensor depends only on the subkey. + # Using assert statements to break the routine if these assumptions are + # incorrect. + for state_key in opt_state_dict["state"].keys(): + assert example_state_subkeys == set(opt_state_dict["state"][state_key].keys()) + for state_subkey in example_state_subkeys: + assert isinstance( + opt_state_dict["state"][example_state_key][state_subkey], + pt.Tensor, + ) == isinstance(opt_state_dict["state"][state_key][state_subkey], pt.Tensor) + + state_subkeys = list(opt_state_dict["state"][example_state_key].keys()) + + # Tags will record whether the value associated to the subkey is a + # tensor or not. + state_subkey_tags = [] + for state_subkey in state_subkeys: + if isinstance(opt_state_dict["state"][example_state_key][state_subkey], pt.Tensor): + state_subkey_tags.append("istensor") + else: + state_subkey_tags.append("") + state_subkeys_and_tags = list(zip(state_subkeys, state_subkey_tags)) + + # Forming the flattened dict, using a concatenation of group index, + # subindex, tag, and subkey inserted into the flattened dict key - + # needed for reconstruction. 
+ nb_params_per_group = [] + for group_idx, group in enumerate(opt_state_dict["param_groups"]): + for idx, param_id in enumerate(group["params"]): + for subkey, tag in state_subkeys_and_tags: + if tag == "istensor": + new_v = opt_state_dict["state"][param_id][subkey].cpu().numpy() + else: + new_v = np.array([opt_state_dict["state"][param_id][subkey]]) + derived_opt_state_dict[f"__opt_state_{group_idx}_{idx}_{tag}_{subkey}"] = new_v + nb_params_per_group.append(idx + 1) + # group lengths are also helpful for reconstructing + # original opt_state_dict structure + derived_opt_state_dict["__opt_group_lengths"] = np.array(nb_params_per_group) + + return derived_opt_state_dict + + +def expand_derived_opt_state_dict(derived_opt_state_dict, device): + """Expand the optimizer state dictionary. + + Takes a derived opt_state_dict and creates an opt_state_dict suitable as + input for load_state_dict for restoring optimizer state. + Reconstructing state_subkeys_and_tags using the example key prefix, + "__opt_state_0_0_", certain to be present. + + Args: + derived_opt_state_dict (dict): Derived optimizer state dictionary. + device (str): Device for the model. + + Returns: + opt_state_dict (dict): Expanded optimizer state dictionary. + """ + state_subkeys_and_tags = [] + for key in derived_opt_state_dict: + if key.startswith("__opt_state_0_0_"): + stripped_key = key[16:] + if stripped_key.startswith("istensor_"): + this_tag = "istensor" + subkey = stripped_key[9:] + else: + this_tag = "" + subkey = stripped_key[1:] + state_subkeys_and_tags.append((subkey, this_tag)) + + opt_state_dict = {"param_groups": [], "state": {}} + nb_params_per_group = list(derived_opt_state_dict.pop("__opt_group_lengths").astype(np.int32)) + + # Construct the expanded dict. + for group_idx, nb_params in enumerate(nb_params_per_group): + these_group_ids = [f"{group_idx}_{idx}" for idx in range(nb_params)] + opt_state_dict["param_groups"].append({"params": these_group_ids}) + for this_id in these_group_ids: + opt_state_dict["state"][this_id] = {} + for subkey, tag in state_subkeys_and_tags: + flat_key = f"__opt_state_{this_id}_{tag}_{subkey}" + if tag == "istensor": + new_v = pt.from_numpy(derived_opt_state_dict.pop(flat_key)) + else: + # Here (for currrently supported optimizers) the subkey + # should be 'step' and the length of array should be one. + assert subkey == "step" + assert len(derived_opt_state_dict[flat_key]) == 1 + new_v = int(derived_opt_state_dict.pop(flat_key)) + opt_state_dict["state"][this_id][subkey] = new_v + + # sanity check that we did not miss any optimizer state + assert len(derived_opt_state_dict) == 0 + + return opt_state_dict + + +def _get_optimizer_state(optimizer): + """Get the state of the optimizer. + + Args: + optimizer (Optimizer): Optimizer. + + Returns: + derived_opt_state_dict (dict): State of the optimizer. + """ + opt_state_dict = deepcopy(optimizer.state_dict()) + + # Optimizer state might not have some parts representing frozen parameters + # So we do not synchronize them + param_keys_with_state = set(opt_state_dict["state"].keys()) + for group in opt_state_dict["param_groups"]: + local_param_set = set(group["params"]) + params_to_sync = local_param_set & param_keys_with_state + group["params"] = sorted(params_to_sync) + + derived_opt_state_dict = _derive_opt_state_dict(opt_state_dict) + + return derived_opt_state_dict + + +def _set_optimizer_state(optimizer, device, derived_opt_state_dict): + """Set the state of the optimizer. + + Args: + optimizer (Optimizer): Optimizer. 
+ device (str): Device for the model. + derived_opt_state_dict (dict): Derived optimizer state dictionary. + """ + temp_state_dict = expand_derived_opt_state_dict(derived_opt_state_dict, device) + + # FIXME: Figure out whether or not this breaks learning rate + # scheduling and the like. + # Setting default values. + # All optimizer.defaults are considered as not changing over course of + # training. + for group in temp_state_dict["param_groups"]: + for k, v in optimizer.defaults.items(): + group[k] = v + + optimizer.load_state_dict(temp_state_dict) + + +def to_cpu_numpy(state): + """Convert state to CPU as Numpy array. + + Args: + state (State): State to be converted. + + Returns: + state (dict): State as Numpy array. + """ + # deep copy so as to decouple from active model + state = deepcopy(state) + + for k, v in state.items(): + # When restoring, we currently assume all values are tensors. + if not pt.is_tensor(v): + raise ValueError( + "We do not currently support non-tensors " "coming from model.state_dict()" + ) + # get as a numpy array, making sure is on cpu + state[k] = v.cpu().numpy() + return state diff --git a/Task_1/fets_challenge/fets_data_loader.py b/Task_1/fets_challenge/fets_data_loader.py new file mode 100644 index 0000000..e69e004 --- /dev/null +++ b/Task_1/fets_challenge/fets_data_loader.py @@ -0,0 +1,55 @@ +class FeTSDataLoader(): + """ + A data loader class for the FeTS challenge that handles training and validation data loaders. + + Attributes: + train_loader (DataLoader): The data loader for the training dataset. + valid_loader (DataLoader): The data loader for the validation dataset. + """ + + def __init__(self, train_loader, valid_loader): + """ + Initializes the FeTSDataLoader with training and validation data loaders. + + Args: + train_loader (DataLoader): The data loader for the training dataset. + valid_loader (DataLoader): The data loader for the validation dataset. + """ + self.train_loader = train_loader + self.valid_loader = valid_loader + + def get_train_loader(self): + """ + Returns the data loader for the training dataset. + + Returns: + DataLoader: The data loader for the training dataset. + """ + return self.train_loader + + def get_valid_loader(self): + """ + Returns the data loader for the validation dataset. + + Returns: + DataLoader: The data loader for the validation dataset. + """ + return self.valid_loader + + def get_train_data_size(self): + """ + Returns the size of the training dataset. + + Returns: + int: The number of samples in the training dataset. + """ + return len(self.train_loader.dataset) + + def get_valid_data_size(self): + """ + Returns the size of the validation dataset. + + Returns: + int: The number of samples in the validation dataset. 
+ """ + return len(self.valid_loader.dataset) \ No newline at end of file diff --git a/Task_1/fets_challenge/fets_flow.py b/Task_1/fets_challenge/fets_flow.py new file mode 100644 index 0000000..208753d --- /dev/null +++ b/Task_1/fets_challenge/fets_flow.py @@ -0,0 +1,521 @@ + +"""FeTS Federated Flow.""" + +import os +import shutil +import time +import logging +from copy import deepcopy +import pandas as pd +from pathlib import Path +import torch + +from openfl.experimental.workflow.interface import FLSpec +from openfl.experimental.workflow.placement import aggregator, collaborator +from openfl.databases import TensorDB +from openfl.utilities import TensorKey, change_tags + +from .fets_data_loader import FeTSDataLoader + +from .checkpoint_utils import setup_checkpoint_folder, save_checkpoint, load_checkpoint +from .time_utils import gen_collaborator_time_stats, compute_times_per_collaborator, MAX_SIMULATION_TIME + +from GANDLF.compute.generic import create_pytorch_objects + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# [TODO] - FixMe Dataloaders cannot be passed as private attributes of collaborator. +# This is a temporary workaround to store the dataloaders in a global variable. +collaborator_data_loaders = {} + +class FeTSFederatedFlow(FLSpec): + def __init__(self, fets_model, params_dict, rounds=5 , device="cpu", **kwargs): + super().__init__(**kwargs) + self.fets_model = fets_model + self.n_rounds = rounds + self.device = device + self.current_round = 0 + self.total_simulated_time = 0 + self.best_dice = -1.0 + self.best_dice_over_time_auc = 0 + self.collaborators_chosen_each_round = {} + self.collaborator_times_per_round = {} + self.agg_tensor_dict = {} + self.restored = False + + self.include_validation_with_hausdorff = params_dict.get('include_validation_with_hausdorff', False) + self.use_pretrained_model = params_dict.get('use_pretrained_model', False) + self.choose_training_collaborators = params_dict.get('choose_training_collaborators', None) + self.training_hyper_parameters_for_round = params_dict.get('training_hyper_parameters_for_round', None) + self.restore_from_checkpoint_folder = params_dict.get('restore_from_checkpoint_folder', None) + self.save_checkpoints = params_dict.get('save_checkpoints', False) + + # GaNDLF config + self.gandlf_config = params_dict.get('gandlf_config', None) + + self.experiment_results = { + 'round':[], + 'time': [], + 'convergence_score': [], + 'round_dice': [], + 'dice_label_0': [], + 'dice_label_1': [], + 'dice_label_2': [], + 'dice_label_4': [], + } + + def _get_metric(self, metric_name, fl_round, agg_tensor_db): + tensor_key = TensorKey(metric_name, 'aggregator', fl_round, True, ('metric', 'validate_agg')) + return agg_tensor_db.get_tensor_from_cache(tensor_key).item() + + def _cache_tensor_dict(self, tensor_dict, agg_tensor_db, idx, agg_out_dict): + agg_out_dict.update({ + TensorKey( + tensor_name=key.tensor_name, + origin="aggregator", + round_number=key.round_number, + report=key.report, + tags=change_tags(key.tags, add_field=str(idx + 1)) + ): value + for key, value in tensor_dict.items() + }) + # Cache the updated dictionary in agg_tensor_db + agg_tensor_db.cache_tensor(agg_out_dict) + + def _get_aggregated_dict_with_tensorname(self, agg_tensor_dict, current_round=0, lookup_tags='aggregated'): + return { + tensor_key.tensor_name: value + for tensor_key, value in agg_tensor_dict.items() + if lookup_tags in tensor_key.tags + } + + def _update_metrics(self, current_round, agg_tensor_db, experiment_results, 
include_validation_with_hausdorff, + total_simulated_time, projected_auc): + + dice_metrics = [ + 'valid_loss', 'valid_dice', + 'valid_dice_per_label_0', 'valid_dice_per_label_1', + 'valid_dice_per_label_2', 'valid_dice_per_label_4' + ] + hausdorff_metrics = [ + 'valid_hd95_per_label_0', 'valid_hd95_per_label_1', + 'valid_hd95_per_label_2', 'valid_hd95_per_label_4' + ] + + # Fetch dice metrics + dice_values = {metric: self._get_metric(metric, current_round, agg_tensor_db) for metric in dice_metrics} + + # Fetch Hausdorff metrics if required + hausdorff_values = {} + if include_validation_with_hausdorff: + hausdorff_values = {metric: self._get_metric(metric, current_round, agg_tensor_db) for metric in hausdorff_metrics} + + # # End of round summary + summary = '"**** END OF ROUND {} SUMMARY *****"'.format(current_round) + summary += "\n\tSimulation Time: {} minutes".format(round(total_simulated_time / 60, 2)) + summary += "\n\t(Projected) Convergence Score: {}".format(projected_auc) + summary += "\n\tRound Loss: {}".format(dice_values['valid_loss']) + summary += "\n\tRound Dice: {}".format(dice_values['valid_dice']) + summary += "\n\tDICE Label 0: {}".format(dice_values['valid_dice_per_label_0']) + summary += "\n\tDICE Label 1: {}".format(dice_values['valid_dice_per_label_1']) + summary += "\n\tDICE Label 2: {}".format(dice_values['valid_dice_per_label_2']) + summary += "\n\tDICE Label 4: {}".format(dice_values['valid_dice_per_label_4']) + if include_validation_with_hausdorff: + summary += "\n\tHausdorff95 Label 0: {}".format(hausdorff_values['valid_hd95_per_label_0']) + summary += "\n\tHausdorff95 Label 1: {}".format(hausdorff_values['valid_hd95_per_label_1']) + summary += "\n\tHausdorff95 Label 2: {}".format(hausdorff_values['valid_hd95_per_label_2']) + summary += "\n\tHausdorff95 Label 4: {}".format(hausdorff_values['valid_hd95_per_label_4']) + logger.info(summary) + + experiment_results['round'].append(current_round) + experiment_results['time'].append(total_simulated_time) + experiment_results['convergence_score'].append(projected_auc) + experiment_results['round_dice'].append(dice_values['valid_dice']) + experiment_results['dice_label_0'].append(dice_values['valid_dice_per_label_0']) + experiment_results['dice_label_1'].append(dice_values['valid_dice_per_label_1']) + experiment_results['dice_label_2'].append(dice_values['valid_dice_per_label_2']) + experiment_results['dice_label_4'].append(dice_values['valid_dice_per_label_4']) + if include_validation_with_hausdorff: + experiment_results['hausdorff95_label_0'].append(hausdorff_values['valid_hd95_per_label_0']) + experiment_results['hausdorff95_label_1'].append(hausdorff_values['valid_hd95_per_label_1']) + experiment_results['hausdorff95_label_2'].append(hausdorff_values['valid_hd95_per_label_2']) + experiment_results['hausdorff95_label_4'].append(hausdorff_values['valid_hd95_per_label_4']) + + return summary, dice_values['valid_dice'] + + def _initialize_aggregator_model(self): + """Initialize the aggregator model and its components.""" + model, optimizer, _, _, scheduler, params = create_pytorch_objects( + self.gandlf_config, None, None, device=self.device + ) + self.fets_model.model = model + self.fets_model.optimizer = optimizer + self.fets_model.scheduler = scheduler + self.fets_model.params = params + + def _restore_from_checkpoint(self): + """Restore the experiment state from a checkpoint.""" + checkpoint_path = Path(f'checkpoint/{self.restore_from_checkpoint_folder}') + if not checkpoint_path.exists(): + logger.warning(f'Could 
not find provided checkpoint folder: {self.restore_from_checkpoint_folder}. Exiting...') + exit(1) + + logger.info(f'Attempting to load last completed round from {self.restore_from_checkpoint_folder}') + state = load_checkpoint(self.restore_from_checkpoint_folder) + self.checkpoint_folder = self.restore_from_checkpoint_folder + + ( + loaded_collaborator_names, starting_round_num, self.collaborator_time_stats, + self.total_simulated_time, self.best_dice, self.best_dice_over_time_auc, + self.collaborators_chosen_each_round, self.collaborator_times_per_round, + self.experiment_results, summary, agg_tensor_db + ) = state + + if loaded_collaborator_names != self.collaborator_names: + logger.error(f'Collaborator names found in checkpoint ({loaded_collaborator_names}) ' + f'do not match provided collaborators ({self.collaborator_names})') + exit(1) + + self.restored = True + logger.info(f'Previous summary for round {starting_round_num}') + logger.info(summary) + + # Update the agg_tensor_dict from stored tensor_db + self.current_round = starting_round_num + self._load_agg_tensor_dict(agg_tensor_db) + + def _setup_new_experiment(self): + """Set up a new experiment folder and initialize the tensor dictionary.""" + self.checkpoint_folder = setup_checkpoint_folder() + logger.info(f'\nCreated experiment folder {self.checkpoint_folder}...') + self.current_round = 0 + + # Initialize the tensor dictionary for the first round + tensor_dict = self.fets_model.get_tensor_dict() + self.agg_tensor_dict.update({ + TensorKey( + tensor_name=key, + origin='aggregator', + round_number=self.current_round, + report=False, + tags=('aggregated',) + ): value + for key, value in tensor_dict.items() + }) + + def _load_agg_tensor_dict(self, agg_tensor_db): + """Load the agg_tensor_dict from the stored tensor_db.""" + for _, record in agg_tensor_db.iterrows(): + tensor_key = TensorKey( + record["tensor_name"], record["origin"], record["round"], + record["report"], record["tags"] + ) + self.agg_tensor_dict[tensor_key] = record["nparray"] + + def _aggregate_tensors(self, agg_tensor_db, tensor_keys_per_col, collaborator_weight_dict): + """Aggregate tensors and cache the results.""" + self.aggregation_type.set_state_data_for_round(self.collaborators_chosen_each_round, self.collaborator_times_per_round) + for col, tensor_keys in tensor_keys_per_col.items(): + for tensor_key in tensor_keys: + tensor_name, origin, round_number, report, tags = tensor_key + if col in tags: + new_tags = change_tags(tags, remove_field=col) + agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags) + if agg_tensor_db.get_tensor_from_cache(agg_tensor_key) is None: + agg_results = agg_tensor_db.get_aggregated_tensor( + agg_tensor_key, + collaborator_weight_dict, + aggregation_function=self.aggregation_type, + ) + agg_tag_tk = TensorKey(tensor_name, origin, round_number, report, ('aggregated',)) + agg_tensor_db.cache_tensor({agg_tag_tk: agg_results}) + + def _process_collaborators(self, inputs, agg_tensor_db, collaborator_weights_unnormalized, times_per_collaborator): + """Process tensors for each collaborator and cache them.""" + tensor_keys_per_col = {} + for idx, col in enumerate(inputs): + agg_out_dict = {} + self._cache_tensor_dict(col.local_valid_dict, agg_tensor_db, idx, agg_out_dict) + self._cache_tensor_dict(col.agg_valid_dict, agg_tensor_db, idx, agg_out_dict) + self._cache_tensor_dict(col.global_output_tensor_dict, agg_tensor_db, idx, agg_out_dict) + + # Store the keys for each collaborator + tensor_keys_per_col[str(idx + 
1)] = list(agg_out_dict.keys()) + collaborator_weights_unnormalized[col.input] = col.collaborator_task_weight + times_per_collaborator[col.input] = col.times_per_collaborator + return tensor_keys_per_col + + def _update_best_model(self, round_dice): + """Update the best model if the current round's dice score is better.""" + if self.best_dice < round_dice: + self.best_dice = round_dice + if self.current_round == 0: + logger.info(f'Skipping best model saving to disk as it is a random initialization.') + elif not os.path.exists(f'checkpoint/{self.checkpoint_folder}/temp_model.pkl'): + raise ValueError(f'Expected temporary model at: checkpoint/{self.checkpoint_folder}/temp_model.pkl to exist but it was not found.') + else: + shutil.copyfile( + src=f'checkpoint/{self.checkpoint_folder}/temp_model.pkl', + dst=f'checkpoint/{self.checkpoint_folder}/best_model.pkl' + ) + logger.info(f'Saved model with best average binary DICE: {self.best_dice} to checkpoint/{self.checkpoint_folder}/best_model.pkl') + + + def _update_aggregator_model(self, inputs): + """Update the aggregator model with the aggregated tensors.""" + logger.info(f'Aggregator Model updated for round {self.current_round}') + self.fets_model.model = inputs[0].fets_model.model + self.fets_model.optimizer = inputs[0].fets_model.optimizer + self.fets_model.scheduler = inputs[0].fets_model.scheduler + self.fets_model.params = inputs[0].fets_model.params + + # Rebuild the model with the aggregated tensor_dict + local_tensor_dict = self._get_aggregated_dict_with_tensorname(self.agg_tensor_dict, self.current_round) + self.fets_model.rebuild_model(local_tensor_dict) + self.fets_model.save_native(f'checkpoint/{self.checkpoint_folder}/temp_model.pkl') + + @aggregator + def start(self): + # Update experiment results if validation with Hausdorff is included + if self.include_validation_with_hausdorff: + self.experiment_results.update({ + f'hausdorff95_label_{label}': [] for label in [0, 1, 2, 4] + }) + + # Initialize the aggregator model + self._initialize_aggregator_model() + + self.collaborators = self.runtime.collaborators + # Handle checkpoint restoration or setup a new experiment folder + if self.restore_from_checkpoint_folder: + self._restore_from_checkpoint() + else: + self._setup_new_experiment() + + # Check if the experiment is already completed + if self.current_round >= self.n_rounds: + logger.info("Experiment already completed. Exiting...") + self.next(self.internal_loop) + return + + if self.restore_from_checkpoint_folder: + self.current_round += 1 + + # Proceed to the next step + self.collaborator_time_stats = gen_collaborator_time_stats(self.collaborator_names) + self.next(self.fetch_parameters_for_colls) + + @aggregator + def fetch_parameters_for_colls(self): + print("*" * 40) + print("Starting round {}".format(self.current_round)) + print("*" * 40) + hparams = self.training_hyper_parameters_for_round(self.collaborators, + None, + self.current_round, + self.collaborators_chosen_each_round, + self.collaborator_times_per_round) + + learning_rate, epochs_per_round = hparams + + if (epochs_per_round is None): + logger.warning('Hyper-parameter function warning: function returned None for "epochs_per_round". 
Setting "epochs_per_round" to 1') + epochs_per_round = 1 + + self.hparam_dict = {} + self.hparam_dict['learning_rate'] = learning_rate + self.hparam_dict['epochs_per_round'] = epochs_per_round + + logger.info(f'Hyperparameters for round {self.current_round}: {self.hparam_dict}') + + # pick collaborators to train for the round + self.training_collaborators = self.choose_training_collaborators(self.collaborator_names, + None, + self.current_round, + self.collaborators_chosen_each_round, + self.collaborator_times_per_round) + + logger.info('Collaborators chosen to train for round {}:\n\t{}'.format(self.current_round, self.training_collaborators)) + self.collaborators_chosen_each_round[self.current_round] = self.training_collaborators + + # Fetch the aggregated tensor dict for the current round + self.input_tensor_dict = self._get_aggregated_dict_with_tensorname(self.agg_tensor_dict, self.current_round) + if self.current_round == 0 or self.restored is True: + self.next(self.initialize_colls, foreach='collaborators') + self.restored = False + else: + self.next(self.aggregated_model_validation, foreach='training_collaborators') + + @collaborator + def initialize_colls(self): + if not self.include_validation_with_hausdorff: + self.gandlf_config['metrics'] = ['dice','dice_per_label'] + + logger.info(f'Initializing collaborator {self.input}') + ( + model, + optimizer, + train_loader, + val_loader, + scheduler, + params, + ) = create_pytorch_objects( + self.gandlf_config, train_csv=self.train_csv_path, val_csv=self.val_csv_path, device=self.device + ) + + self.fets_model.device = self.device + self.fets_model.model = model + self.fets_model.optimizer = optimizer + self.fets_model.scheduler = scheduler + self.fets_model.params = params + logger.info(f'Initializing dataloaders for collaborator {self.input}') + collaborator_data_loaders[self.input] = FeTSDataLoader(train_loader, val_loader) + + self.times_per_collaborator = compute_times_per_collaborator(self.input, + self.training_collaborators, + self.hparam_dict['epochs_per_round'], + collaborator_data_loaders[self.input], + self.collaborator_time_stats, + self.current_round) + + # [TODO] - FIX using Pretrained model + if self.use_pretrained_model: + if self.device == 'cpu': + checkpoint = torch.load(f'checkpoint/pretrained_model/resunet_pretrained.pth',map_location=torch.device('cpu')) + self.fets_model.model.load_state_dict(checkpoint['model_state_dict']) + self.fets_model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + else: + checkpoint = torch.load(f'checkpoint/pretrained_model/resunet_pretrained.pth') + self.fets_model.model.load_state_dict(checkpoint['model_state_dict']) + self.fets_model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + self.next(self.aggregated_model_validation) + + @collaborator + def aggregated_model_validation(self): + logger.info(f'Performing aggregated model validation for collaborator {self.input}') + input_tensor_dict = deepcopy(self.input_tensor_dict) + val_loader = collaborator_data_loaders[self.input].get_valid_loader() + self.agg_valid_dict, _ = self.fets_model.validate(self.input, self.current_round, input_tensor_dict, val_loader, apply="global") + self.next(self.train) + + @collaborator + def train(self): + logger.info(f'Performing training for collaborator {self.input}') + train_loader = collaborator_data_loaders[self.input].get_train_loader() + input_tensor_dict = deepcopy(self.input_tensor_dict) + self.global_output_tensor_dict, _ = self.fets_model.train(self.input, 
self.current_round, input_tensor_dict, self.hparam_dict, train_loader)
+        self.collaborator_task_weight = collaborator_data_loaders[self.input].get_train_data_size()
+        self.next(self.local_model_validation)
+
+    @collaborator
+    def local_model_validation(self):
+        logger.info(f'Performing local model validation for collaborator {self.input}')
+        val_loader = collaborator_data_loaders[self.input].get_valid_loader()
+        # Update the model with the trained tensors for local validation of this round.
+        input_tensor_dict = self._get_aggregated_dict_with_tensorname(self.global_output_tensor_dict, self.current_round, 'trained')
+        self.local_valid_dict, _ = self.fets_model.validate(self.input, self.current_round, input_tensor_dict, val_loader, apply="local")
+        self.next(self.join)
+
+    @aggregator
+    def join(self, inputs):
+        logger.info(f'Aggregating results for round {self.current_round}')
+        agg_tensor_db = TensorDB()  # Used for aggregating and persisting tensors
+        collaborator_weights_unnormalized = {}
+        times_per_collaborator = {}
+        tensor_keys_per_col = {}
+
+        # Cache the aggregator tensor dict in tensor_db so that tensor_db has updated tensor values.
+        agg_tensor_db.cache_tensor(self.agg_tensor_dict)
+
+        # Process each collaborator's tensors
+        tensor_keys_per_col = self._process_collaborators(inputs, agg_tensor_db, collaborator_weights_unnormalized, times_per_collaborator)
+
+        self.collaborator_times_per_round[self.current_round] = times_per_collaborator
+        weight_total = sum(collaborator_weights_unnormalized.values())
+        collaborator_weight_dict = {
+            k: v / weight_total for k, v in collaborator_weights_unnormalized.items()
+        }
+        logger.info(f'Calculated Collaborator weights: {collaborator_weight_dict} and times: {times_per_collaborator}')
+
+        # Perform aggregation
+        self._aggregate_tensors(agg_tensor_db, tensor_keys_per_col, collaborator_weight_dict)
+
+        # Clean up tensor_db entries older than db_store_rounds rounds
+        agg_tensor_db.clean_up(self.db_store_rounds)
+
+        times_list = [(t, col) for col, t in times_per_collaborator.items()]
+        times_list = sorted(times_list)
+
+        # the round time is the max of the times_list
+        round_time = max([t for t, _ in times_list])
+        self.total_simulated_time += round_time
+
+        ## CONVERGENCE METRIC COMPUTATION
+        # update the auc score
+        self.best_dice_over_time_auc += self.best_dice * round_time
+
+        # project the auc score as remaining time * best dice
+        # this projection assumes that the current best score is carried forward for the entire week
+        projected_auc = (MAX_SIMULATION_TIME - self.total_simulated_time) * self.best_dice + self.best_dice_over_time_auc
+        projected_auc /= MAX_SIMULATION_TIME
+
+        # update metrics and results
+        summary, round_dice = self._update_metrics(
+            self.current_round, agg_tensor_db, self.experiment_results,
+            self.include_validation_with_hausdorff, self.total_simulated_time, projected_auc
+        )
+
+        # Update the best model if necessary
+        self._update_best_model(round_dice)
+
+        # Update the agg_tensor_dict for subsequent rounds with the aggregated tensor_db
+        self.agg_tensor_dict.clear()
+        self.agg_tensor_dict = {
+            TensorKey(record["tensor_name"], record["origin"], record["round"], record["report"], record["tags"]): record["nparray"]
+            for _, record in agg_tensor_db.tensor_db.iterrows()
+        }
+
+        if self.save_checkpoints:
+            logger.info(f'Saving checkpoint for round {self.current_round}: checkpoint folder {self.checkpoint_folder}')
+            logger.info(f'To resume from this checkpoint, set the restore_from_checkpoint_folder parameter to
{self.checkpoint_folder}') + save_checkpoint(self.checkpoint_folder, agg_tensor_db, + self.collaborator_names, self.runtime.collaborators, + self.current_round, self.collaborator_time_stats, + self.total_simulated_time, self.best_dice, + self.best_dice_over_time_auc, + self.collaborators_chosen_each_round, + self.collaborator_times_per_round, + self.experiment_results, + summary) + + # if the total_simulated_time has exceeded the maximum time, we break + # in practice, this means that the previous round's model is the last model scored, + # so a long final round should not actually benefit the competitor, since that final + # model is never globally validated + if self.total_simulated_time > MAX_SIMULATION_TIME: + logger.info("Simulation time exceeded. Ending Experiment") + self.next(self.end) + return + + # Update the aggregator model and rebuild it with aggregated tensors + self._update_aggregator_model(inputs) + self.next(self.internal_loop) + + @aggregator + def internal_loop(self): + if self.current_round >= self.n_rounds: + print('************* EXPERIMENT COMPLETED *************') + print('Experiment results:') + print(pd.DataFrame.from_dict(self.experiment_results)) + self.next(self.end) + else: + self.current_round += 1 + self.next(self.fetch_parameters_for_colls) + + @aggregator + def end(self): + logger.info('********************************') + logger.info('End of flow') + logger.info('********************************') \ No newline at end of file diff --git a/Task_1/fets_challenge/inference.py b/Task_1/fets_challenge/inference.py index 13f0680..35b0706 100644 --- a/Task_1/fets_challenge/inference.py +++ b/Task_1/fets_challenge/inference.py @@ -15,6 +15,9 @@ import openfl.native as fx from .gandlf_csv_adapter import construct_fedsim_csv +from GANDLF.compute.generic import create_pytorch_objects +from GANDLF.config_manager import ConfigManager +from .fets_challenge_model import FeTSChallengeModel logger = getLogger(__name__) @@ -206,8 +209,6 @@ def model_outputs_to_disc(data_path, native_model_path, outputtag='', device='cpu'): - - fx.init('fets_challenge_workspace') from sys import path, exit @@ -215,37 +216,35 @@ def model_outputs_to_disc(data_path, root = file.parent.resolve() # interface root, containing command modules work = Path.cwd().resolve() - path.append(str(root)) - path.insert(0, str(work)) - generate_validation_csv(data_path,validation_csv, working_dir=work) - - overrides = { - 'task_runner.settings.device': device, - 'task_runner.settings.val_csv': 'validation_paths.csv', - 'task_runner.settings.train_csv': None, - } - - # Update the plan if necessary - plan = fx.update_plan(overrides) - plan.config['task_runner']['settings']['fets_config_dict']['save_output'] = True - plan.config['task_runner']['settings']['fets_config_dict']['output_dir'] = output_path + gandlf_config_path = os.path.join(root, 'config', 'gandlf_config.yaml') + fets_model = FeTSChallengeModel() + val_csv_path = os.path.join(work, 'validation_paths.csv') + gandlf_conf = ConfigManager(gandlf_config_path) + ( + model, + optimizer, + _, + val_loader, + scheduler, + params, + ) = create_pytorch_objects( + gandlf_conf, train_csv=None, val_csv=val_csv_path, device=device + ) + gandlf_conf['output_dir'] = output_path + gandlf_conf['save_output'] = True + fets_model.model = model + fets_model.optimizer = optimizer + fets_model.scheduler = scheduler + fets_model.params = params + fets_model.device = device - # overwrite datapath value for a single 'InferenceCol' collaborator - 
plan.cols_data_paths['InferenceCol'] = data_path - - # get the inference data loader - data_loader = copy(plan).get_data_loader('InferenceCol') - - # get the task runner, passing the data loader - task_runner = copy(plan).get_task_runner(data_loader) - # Populate model weights device = torch.device(device) - task_runner.load_native(filepath=native_model_path, map_location=device) - task_runner.opt_treatment = 'RESET' + fets_model.load_native(filepath=native_model_path, map_location=device) + #task_runner.opt_treatment = 'RESET' logger.info('Starting inference using data from {}\n'.format(data_path)) - task_runner.inference('aggregator',-1,task_runner.get_tensor_dict(),apply='global') + fets_model.inference('aggregator',-1,val_loader,apply='global') logger.info(f"\nFinished generating predictions to output folder {output_path}") diff --git a/Task_1/fets_challenge/time_utils.py b/Task_1/fets_challenge/time_utils.py new file mode 100644 index 0000000..95428ba --- /dev/null +++ b/Task_1/fets_challenge/time_utils.py @@ -0,0 +1,191 @@ +from collections import namedtuple +from logging import getLogger +import warnings + +import numpy as np +import pandas as pd + +## COLLABORATOR TIMING DISTRIBUTIONS +# These data are derived from the actual timing information in the real-world FeTS information +# They reflect a subset of the institutions involved. +# Tuples are (mean, stddev) in seconds + +# time to train one patient +TRAINING_TIMES = [(6.710741331207654, 0.8726112813698301), + (2.7343911917098445, 0.023976155580152165), + (3.173076923076923, 0.04154320960517865), + (6.580379746835443, 0.22461890673025595), + (3.452046783625731, 0.47136389322749656), + (6.090788461700995, 0.08541499003440205), + (3.206933911159263, 0.1927067498514361), + (3.3358208955223883, 0.2950567549663471), + (4.391304347826087, 0.37464538999161057), + (6.324805129494594, 0.1413885448869165), + (7.415133477633478, 1.1198881747151301), + (5.806410256410255, 0.029926699295169234), + (6.300204918032787, 0.24932319729777577), + (5.886317567567567, 0.018627858809133223), + (5.478184991273998, 0.04902740607167421), + (6.32440159574468, 0.15838847558954935), + (20.661918328585003, 6.085405543890793), + (3.197901325478645, 0.07049966132127056), + (6.523963730569948, 0.2533266757118492), + (2.6540077569489338, 0.025503099659276184), + (1.8025746183640918, 0.06805805332403576)] + +# time to validate one patient +VALIDATION_TIMES = [(23.129135113591072, 2.5975116854269507), + (12.965544041450777, 0.3476297824941513), + (14.782051282051283, 0.5262660449172765), + (16.444936708860762, 0.42613177203005187), + (15.728654970760235, 4.327559980390658), + (12.946098012884802, 0.2449927822869217), + (15.335950126991456, 1.1587597276712558), + (24.024875621890544, 3.087348297794285), + (38.361702127659576, 2.240113332190875), + (16.320970580839827, 0.4995108101783225), + (30.805555555555554, 3.1836337269688237), + (12.100899742930592, 0.41122386959584895), + (13.099897540983607, 0.6693132795197584), + (9.690202702702702, 0.17513593019922968), + (10.06980802792321, 0.7947848617875114), + (14.605333333333334, 0.6012305898922827), + (36.30294396961064, 9.24123672148819), + (16.9130060292851, 0.7452868131028928), + (40.244078460399706, 3.7700993678269037), + (13.161603102779575, 0.1975347910041472), + (11.222161868549701, 0.7021223062972527)] + +# time to download the model +DOWNLOAD_TIMES = [(112.42869743589742, 14.456734719659513), + (117.26870618556701, 12.549951446132013), + (13.059666666666667, 4.8700489616521185), + (47.50220338983051, 
14.92128656898884), + (162.27864210526315, 32.562113378948396), + (99.46072058823529, 13.808785580783224), + (33.6347090909091, 25.00299299660141), + (216.25489393939392, 19.176465340447848), + (217.4117230769231, 20.757673955585453), + (98.38857297297298, 13.205048376808929), + (88.87509473684209, 23.152936862511545), + (66.96994262295081, 16.682497150763503), + (36.668852040816326, 13.759109844677598), + (149.31716326530614, 26.018185409516104), + (139.847, 80.04755583050091), + (54.97624444444445, 16.645170929316794)] + +# time to upload the model +UPLOAD_TIMES = [(192.28497409326425, 21.537450985376967), + (194.60103626943004, 24.194406902237056), + (20.0, 0.0), + (52.43859649122807, 5.047207127169352), + (182.82417582417582, 14.793519078918195), + (143.38059701492537, 7.910690646792151), + (30.695652173913043, 9.668122350904568), + (430.95360824742266, 54.97790476867727), + (348.3174603174603, 30.14347985347738), + (141.43715846994536, 5.271340868190727), + (158.7433155080214, 64.87526819391198), + (81.06086956521739, 7.003461202082419), + (32.60621761658031, 5.0418315093016615), + (281.5388601036269, 90.60338778706557), + (194.34065934065933, 36.6519776778435), + (66.53787878787878, 16.456280602190606)] + +logger = getLogger(__name__) +# This catches PyTorch UserWarnings for CPU +warnings.filterwarnings("ignore", category=UserWarning) + +# one week +# MINUTE = 60 +# HOUR = 60 * MINUTE +# DAY = 24 * HOUR +# WEEK = 7 * DAY +MAX_SIMULATION_TIME = 7 * 24 * 60 * 60 #TODO check if this can be move to time_utils.py file + +CollaboratorTimeStats = namedtuple('CollaboratorTimeStats', + [ + 'validation_mean', + 'training_mean', + 'download_speed_mean', + 'upload_speed_mean', + 'validation_std', + 'training_std', + 'download_speed_std', + 'upload_speed_std', + ] + ) + +def gen_collaborator_time_stats(collaborator_names, seed=0xFEEDFACE): + + np.random.seed(seed) + + stats = {} + for col in collaborator_names: + ml_index = np.random.randint(len(VALIDATION_TIMES)) + validation = VALIDATION_TIMES[ml_index] + training = TRAINING_TIMES[ml_index] + net_index = np.random.randint(len(DOWNLOAD_TIMES)) + download = DOWNLOAD_TIMES[net_index] + upload = UPLOAD_TIMES[net_index] + + stats[col] = CollaboratorTimeStats(validation_mean=validation[0], + training_mean=training[0], + download_speed_mean=download[0], + upload_speed_mean=upload[0], + validation_std=validation[1], + training_std=training[1], + download_speed_std=download[1], + upload_speed_std=upload[1]) + return stats + +def compute_times_per_collaborator(collaborator_name, + training_collaborators, + epochs_per_round, + collaborator_data, + collaborator_time_stats, + round_num): + np.random.seed(round_num) + time = 0 + + # stats + stats = collaborator_time_stats[collaborator_name] + + # download time + download_time = np.random.normal(loc=stats.download_speed_mean, + scale=stats.download_speed_std) + download_time = max(1, download_time) + time += download_time + + # validation time + data_size = collaborator_data.get_valid_data_size() + validation_time_per = np.random.normal(loc=stats.validation_mean, + scale=stats.validation_std) + validation_time_per = max(1, validation_time_per) + time += data_size * validation_time_per + + # only if training + if collaborator_name in training_collaborators: + # training time + data_size = collaborator_data.get_train_data_size() + training_time_per = np.random.normal(loc=stats.training_mean, + scale=stats.training_std) + training_time_per = max(1, training_time_per) + + # training data size depends on the 
hparams + data_size *= epochs_per_round + time += data_size * training_time_per + + # if training, we also validate the locally updated model + data_size = collaborator_data.get_valid_data_size() + validation_time_per = np.random.normal(loc=stats.validation_mean, + scale=stats.validation_std) + validation_time_per = max(1, validation_time_per) + time += data_size * validation_time_per + + # upload time + upload_time = np.random.normal(loc=stats.upload_speed_mean, + scale=stats.upload_speed_std) + upload_time = max(1, upload_time) + time += upload_time + return time \ No newline at end of file diff --git a/Task_1/generate_predictions.py b/Task_1/generate_predictions.py index 872a62a..b6d1900 100644 --- a/Task_1/generate_predictions.py +++ b/Task_1/generate_predictions.py @@ -12,10 +12,11 @@ from pathlib import Path import os from sys import path +from logging import getLogger from fets_challenge.gandlf_csv_adapter import construct_fedsim_csv, extract_csv_partitions device='cpu' - +logger = getLogger(__name__) # infer participant home folder home = str(Path.home()) @@ -23,13 +24,33 @@ # the data you want to run inference over checkpoint_folder='experiment_1' #data_path = -data_path = '/raid/datasets/FeTS22/MICCAI_FeTS2022_ValidationData' +data_path = '/home/brats/MICCAI_FeTS2022_ValidationData' + +working_directory= os.path.join(home, '.local/workspace/') + +try: + os.chdir(working_directory) + logger.info(f"Directory changed to : {os.getcwd()}") +except FileNotFoundError: + logger.info("Error: Directory not found.") +except PermissionError: + logger.info("Error: Permission denied") + +if checkpoint_folder is not None: + best_model_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'best_model.pkl') +else: + exit("No checkpoint folder found. Please provide a valid checkpoint folder. Exiting the experiment without inferencing") + +# If the experiment is only run for a single round, use the temp model instead +if not Path(best_model_path).exists(): + best_model_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'temp_model.pkl') + +if not Path(best_model_path).exists(): + exit("No model found. Please provide a valid checkpoint folder. 
Exiting the experiment without inferencing") -# you can keep these the same if you wish -best_model_path = os.path.join(home, '.local/workspace/checkpoint', checkpoint_folder, 'best_model.pkl') -outputs_path = os.path.join(home, '.local/workspace/checkpoint', checkpoint_folder, 'model_outputs') +outputs_path = os.path.join(working_directory, 'checkpoint', checkpoint_folder, 'model_outputs') -validation_csv_filename='validation.csv' +validation_csv_filename=os.path.join(home, '.local/workspace/', 'validation.csv') # Using this best model, we can now produce NIfTI files for model outputs diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/plan/cols.yaml b/Task_1/openfl-workspace/fets_challenge_workspace/plan/cols.yaml deleted file mode 100644 index ebd5bec..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/plan/cols.yaml +++ /dev/null @@ -1,3 +0,0 @@ -# Provided by the FeTS Initiative (www.fets.ai) as part of the FeTS Challenge 2021 - -collaborators: \ No newline at end of file diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/plan/data.yaml b/Task_1/openfl-workspace/fets_challenge_workspace/plan/data.yaml deleted file mode 100644 index 93c8816..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/plan/data.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# Provided by the FeTS Initiative (www.fets.ai) as part of the FeTS Challenge 2021 - -one,1 -two,2 diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/plan/defaults b/Task_1/openfl-workspace/fets_challenge_workspace/plan/defaults deleted file mode 100644 index fb82f9c..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/plan/defaults +++ /dev/null @@ -1,2 +0,0 @@ -../../workspace/plan/defaults - diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/plan/plan.yaml b/Task_1/openfl-workspace/fets_challenge_workspace/plan/plan.yaml deleted file mode 100644 index ca4476c..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/plan/plan.yaml +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. - -aggregator : - defaults : plan/defaults/aggregator.yaml - template : openfl.component.Aggregator - settings : - init_state_path : save/fets_seg_test_init.pbuf - best_state_path : save/fets_seg_test_best.pbuf - last_state_path : save/fets_seg_test_last.pbuf - rounds_to_train : 3 - write_logs : true - - -collaborator : - defaults : plan/defaults/collaborator.yaml - template : openfl.component.Collaborator - settings : - delta_updates : false - opt_treatment : RESET - -data_loader : - defaults : plan/defaults/data_loader.yaml - template : openfl.federated.data.loader_fets_challenge.FeTSChallengeDataLoaderWrapper - settings : - feature_shape : [32, 32, 32] - -task_runner : - template : src.fets_challenge_model.FeTSChallengeModel - settings : - train_csv : seg_test_train.csv - val_csv : seg_test_val.csv - device : cpu - fets_config_dict : - batch_size: 1 - clip_grad: null - clip_mode: null - data_augmentation: {} - data_postprocessing: {} - data_preprocessing: - normalize: null - enable_padding: false - in_memory: false - inference_mechanism : - grid_aggregator_overlap: crop - patch_overlap: 0 - learning_rate: 0.001 - loss_function: dc - medcam_enabled: false - output_dir: '.' 
- metrics: - - dice - - dice_per_label - - hd95_per_label - model: - amp: true - architecture: resunet - base_filters: 32 - class_list: - - 0 - - 1 - - 2 - - 4 - dimension: 3 - final_layer: softmax - ignore_label_validation: null - norm_type: instance - nested_training: - testing: 1 - validation: -5 - num_epochs: 1 - optimizer: - type: sgd - parallel_compute_command: '' - patch_sampler: label - patch_size: - - 64 - - 64 - - 64 - patience: 100 - pin_memory_dataloader: false - print_rgb_label_warning: true - q_max_length: 100 - q_num_workers: 0 - q_samples_per_volume: 40 - q_verbose: false - save_output: false - save_training: false - scaling_factor: 1 - scheduler: - type: triangle_modified - track_memory_usage: false - verbose: false - version: - maximum: 0.0.14 - minimum: 0.0.14 - weighted_loss: true - - -network : - defaults : plan/defaults/network.yaml - -assigner: - template : src.challenge_assigner.FeTSChallengeAssigner - settings : - training_tasks : - - aggregated_model_validation - - train - - locally_tuned_model_validation - validation_tasks : - - aggregated_model_validation - -tasks : - aggregated_model_validation: - function : validate - kwargs : - apply : global - metrics : - - valid_loss - - valid_dice - - locally_tuned_model_validation: - function : validate - kwargs : - apply: local - metrics : - - valid_loss - - valid_dice - - train: - function : train - kwargs : - metrics : - - loss - - train_dice - epochs : 1 - - -compression_pipeline : - defaults : plan/defaults/compression_pipeline.yaml diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/requirements.txt b/Task_1/openfl-workspace/fets_challenge_workspace/requirements.txt deleted file mode 100644 index 9a7d57c..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torchvision -torch diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/src/__init__.py b/Task_1/openfl-workspace/fets_challenge_workspace/src/__init__.py deleted file mode 100644 index 1c5a549..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/src/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Provided by the FeTS Initiative (www.fets.ai) as part of the FeTS Challenge 2021 - -# Contributing Authors (alphabetical): -# Patrick Foley (Intel) -# Micah Sheller (Intel) - -TRAINING_HPARAMS = [ - 'epochs_per_round', - 'learning_rate', -] diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/src/challenge_assigner.py b/Task_1/openfl-workspace/fets_challenge_workspace/src/challenge_assigner.py deleted file mode 100644 index 46e847f..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/src/challenge_assigner.py +++ /dev/null @@ -1,40 +0,0 @@ -# Provided by the FeTS Initiative (www.fets.ai) as part of the FeTS Challenge 2022 - -# Contributing Authors (alphabetical): -# Micah Sheller (Intel) - -class FeTSChallengeAssigner: - def __init__(self, tasks, authorized_cols, training_tasks, validation_tasks, **kwargs): - """Initialize.""" - self.training_collaborators = [] - self.tasks = tasks - self.training_tasks = training_tasks - self.validation_tasks = validation_tasks - self.collaborators = authorized_cols - - def set_training_collaborators(self, training_collaborators): - self.training_collaborators = training_collaborators - - - def get_tasks_for_collaborator(self, collaborator_name, round_number): - """Get tasks for the collaborator specified.""" - if collaborator_name in self.training_collaborators: - return self.training_tasks - else: - return self.validation_tasks - 
- def get_collaborators_for_task(self, task_name, round_number): - """Get collaborators for the task specified.""" - if task_name in self.validation_tasks: - return self.collaborators - else: - return self.training_collaborators - - def get_all_tasks_for_round(self, round_number): - return self.training_tasks - - def get_aggregation_type_for_task(self, task_name): - """Extract aggregation type from self.tasks.""" - if 'aggregation_type' not in self.tasks[task_name]: - return None - return self.tasks[task_name]['aggregation_type'] diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/src/fets_challenge_model.py b/Task_1/openfl-workspace/fets_challenge_workspace/src/fets_challenge_model.py deleted file mode 100644 index 3794be6..0000000 --- a/Task_1/openfl-workspace/fets_challenge_workspace/src/fets_challenge_model.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""GaNDLFTaskRunner module.""" - -from copy import deepcopy - -import numpy as np -import torch as pt - -from openfl.utilities import split_tensor_dict_for_holdouts -from openfl.utilities import TensorKey - -from openfl.federated.task.runner_fets_challenge import * - -from GANDLF.compute.generic import create_pytorch_objects -from GANDLF.compute.training_loop import train_network -from GANDLF.compute.forward_pass import validate_network - -from . import TRAINING_HPARAMS - -class FeTSChallengeModel(FeTSChallengeTaskRunner): - """FeTSChallenge Model class for Federated Learning.""" - - def validate(self, col_name, round_num, input_tensor_dict, - use_tqdm=False, **kwargs): - """Validate. - Run validation of the model on the local data. - Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) - kwargs: Key word arguments passed to GaNDLF main_run - Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB - """ - self.rebuild_model(round_num, input_tensor_dict, validation=True) - self.model.eval() - # self.model.to(self.device) - - epoch_valid_loss, epoch_valid_metric = validate_network(self.model, - self.data_loader.val_dataloader, - self.scheduler, - self.params, - round_num, - mode="validation") - - self.logger.info(epoch_valid_loss) - self.logger.info(epoch_valid_metric) - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - - output_tensor_dict = {} - output_tensor_dict[TensorKey('valid_loss', origin, round_num, True, tags)] = np.array(epoch_valid_loss) - for k, v in epoch_valid_metric.items(): - if np.array(v).size == 1: - output_tensor_dict[TensorKey(f'valid_{k}', origin, round_num, True, tags)] = np.array(v) - else: - for idx,label in enumerate([0,1,2,4]): - output_tensor_dict[TensorKey(f'valid_{k}_{label}', origin, round_num, True, tags)] = np.array(v[idx]) - - return output_tensor_dict, {} - - def inference(self, col_name, round_num, input_tensor_dict, - use_tqdm=False, **kwargs): - """Inference. 
- Run inference of the model on the local data (used for final validation) - Args: - col_name: Name of the collaborator - round_num: What round is it - input_tensor_dict: Required input tensors (for model) - use_tqdm (bool): Use tqdm to print a progress bar (Default=True) - kwargs: Key word arguments passed to GaNDLF main_run - Returns: - global_output_dict: Tensors to send back to the aggregator - local_output_dict: Tensors to maintain in the local TensorDB - """ - self.rebuild_model(round_num, input_tensor_dict, validation=True) - self.model.eval() - # self.model.to(self.device) - - epoch_valid_loss, epoch_valid_metric = validate_network(self.model, - self.data_loader.val_dataloader, - self.scheduler, - self.params, - round_num, - mode="inference") - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - - output_tensor_dict = {} - output_tensor_dict[TensorKey('valid_loss', origin, round_num, True, tags)] = np.array(epoch_valid_loss) - for k, v in epoch_valid_metric.items(): - if np.array(v).size == 1: - output_tensor_dict[TensorKey(f'valid_{k}', origin, round_num, True, tags)] = np.array(v) - else: - for idx,label in enumerate([0,1,2,4]): - output_tensor_dict[TensorKey(f'valid_{k}_{label}', origin, round_num, True, tags)] = np.array(v[idx]) - - return output_tensor_dict, {} - - - def train(self, col_name, round_num, input_tensor_dict, use_tqdm=False, epochs=1, **kwargs): - """Train batches. - Train the model on the requested number of batches. - Args: - col_name : Name of the collaborator - round_num : What round is it - input_tensor_dict : Required input tensors (for model) - use_tqdm (bool) : Use tqdm to print a progress bar (Default=True) - epochs : The number of epochs to train - crossfold_test : Whether or not to use cross fold trainval/test - to evaluate the quality of the model under fine tuning - (this uses a separate prameter to pass in the data and - config used) - crossfold_test_data_csv : Data csv used to define data used in crossfold test. - This csv does not itself define the folds, just - defines the total data to be used. - crossfold_val_n : number of folds to use for the train,val level of the nested crossfold. - corssfold_test_n : number of folds to use for the trainval,test level of the nested crossfold. - kwargs : Key word arguments passed to GaNDLF main_run - Returns: - global_output_dict : Tensors to send back to the aggregator - local_output_dict : Tensors to maintain in the local TensorDB - """ - - # handle the hparams - epochs_per_round = int(input_tensor_dict.pop('epochs_per_round')) - learning_rate = float(input_tensor_dict.pop('learning_rate')) - - self.rebuild_model(round_num, input_tensor_dict) - # set to "training" mode - self.model.train() - - # Set the learning rate - for group in self.optimizer.param_groups: - group['lr'] = learning_rate - - for epoch in range(epochs_per_round): - self.logger.info(f'Run {epoch} epoch of {round_num} round') - # FIXME: do we want to capture these in an array rather than simply taking the last value? 
- epoch_train_loss, epoch_train_metric = train_network(self.model, - self.data_loader.train_dataloader, - self.optimizer, - self.params) - - # output model tensors (Doesn't include TensorKey) - tensor_dict = self.get_tensor_dict(with_opt_vars=True) - - metric_dict = {'loss': epoch_train_loss} - for k, v in epoch_train_metric.items(): - if np.array(v).size == 1: - metric_dict[f'train_{k}'] = np.array(v) - else: - for idx,label in enumerate([0,1,2,4]): - metric_dict[f'train_{k}_{label}'] = np.array(v[idx]) - - - # Return global_tensor_dict, local_tensor_dict - # is this even pt-specific really? - global_tensor_dict, local_tensor_dict = create_tensorkey_dicts(tensor_dict, - metric_dict, - col_name, - round_num, - self.logger, - self.tensor_dict_split_fn_kwargs) - - # Update the required tensors if they need to be pulled from the - # aggregator - # TODO this logic can break if different collaborators have different - # roles between rounds. - # For example, if a collaborator only performs validation in the first - # round but training in the second, it has no way of knowing the - # optimizer state tensor names to request from the aggregator because - # these are only created after training occurs. A work around could - # involve doing a single epoch of training on random data to get the - # optimizer names, and then throwing away the model. - if self.opt_treatment == 'CONTINUE_GLOBAL': - self.initialize_tensorkeys_for_functions(with_opt_vars=True) - - # This will signal that the optimizer values are now present, - # and can be loaded when the model is rebuilt - self.train_round_completed = True - - # Return global_tensor_dict, local_tensor_dict - return global_tensor_dict, local_tensor_dict - - def get_required_tensorkeys_for_function(self, func_name, **kwargs): - required = super().get_required_tensorkeys_for_function(func_name, **kwargs) - if func_name == 'train': - round_number = required[0].round_number - for hparam in TRAINING_HPARAMS: - required.append(TensorKey(tensor_name=hparam, origin='GLOBAL', round_number=round_number, report=False, tags=('hparam', 'model'))) - return required diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/partitioning_1.csv b/Task_1/partitioning_data/partitioning_1.csv similarity index 100% rename from Task_1/openfl-workspace/fets_challenge_workspace/partitioning_1.csv rename to Task_1/partitioning_data/partitioning_1.csv diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/partitioning_2.csv b/Task_1/partitioning_data/partitioning_2.csv similarity index 100% rename from Task_1/openfl-workspace/fets_challenge_workspace/partitioning_2.csv rename to Task_1/partitioning_data/partitioning_2.csv diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/small_split.csv b/Task_1/partitioning_data/small_split.csv similarity index 100% rename from Task_1/openfl-workspace/fets_challenge_workspace/small_split.csv rename to Task_1/partitioning_data/small_split.csv diff --git a/Task_1/openfl-workspace/fets_challenge_workspace/validation.csv b/Task_1/partitioning_data/validation.csv similarity index 100% rename from Task_1/openfl-workspace/fets_challenge_workspace/validation.csv rename to Task_1/partitioning_data/validation.csv diff --git a/Task_1/requirements.txt b/Task_1/requirements.txt new file mode 100644 index 0000000..cd27f40 --- /dev/null +++ b/Task_1/requirements.txt @@ -0,0 +1,12 @@ +chardet +charset-normalizer +dill==0.3.6 +matplotlib>=2.0.0 +metaflow==2.7.15 +nbdev==2.3.12 +nbformat==5.10.4 +ray==2.9.2 +tabulate==0.9.0 +torch==2.3.1 
+torchvision==0.18.1
+fastcore==1.5.29
\ No newline at end of file
diff --git a/Task_1/setup.py b/Task_1/setup.py
index 1ff561d..0f38375 100644
--- a/Task_1/setup.py
+++ b/Task_1/setup.py
@@ -24,15 +24,14 @@
     url='https://github.com/FETS-AI/Challenge',
     packages=[
         'fets_challenge',
-        'openfl-workspace',
     ],
     include_package_data=True,
     install_requires=[
-        'openfl @ git+https://github.com/intel/openfl.git@f4b28d710e2be31cdfa7487fdb4e8cb3a1387a5f',
-        'GANDLF @ git+https://github.com/CBICA/GaNDLF.git@e4d0d4bfdf4076130817001a98dfb90189956278',
+        'openfl @ git+https://github.com/securefederatedai/openfl.git@v1.7.1',
+        'GANDLF @ git+https://github.com/CBICA/GaNDLF.git@4d614fe1de550ea4035b543b4c712ad564248106',
         'fets @ git+https://github.com/FETS-AI/Algorithms.git@fets_challenge',
     ],
-    python_requires='>=3.6, <3.9',
+    python_requires='>=3.10, <3.13',
     classifiers=[
         'Environment :: Console',
         # How mature is this project? Common values are
@@ -46,9 +45,9 @@
         'License :: OSI Approved :: FETS UI License',
         # Specify the Python versions you support here. In particular, ensure
         # that you indicate whether you support Python 2, Python 3 or both.
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
+        'Programming Language :: Python :: 3.12',
     ]
 )
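The setup.py and requirements.txt changes above pin the environment to Python >=3.10,<3.13 with openfl v1.7.1 and its workflow API. A minimal pre-flight sketch (assumptions: the package was installed from Task_1 with 'pip install .', and the module paths are the ones imported by the new fets_challenge/fets_flow.py):

# Pre-flight environment check (illustrative sketch; assumes 'pip install .' was run in Task_1)
import sys

# setup.py above now declares python_requires='>=3.10, <3.13'
assert (3, 10) <= sys.version_info[:2] < (3, 13), "Unsupported Python version for FeTS Task_1"

# fets_flow.py above imports the openfl workflow API; confirm it resolves under openfl v1.7.1
from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator

print("openfl workflow API available:", FLSpec.__name__)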