From 08208417a5653e93db796538785519f2d92c2478 Mon Sep 17 00:00:00 2001 From: Jorge Fabila Date: Wed, 18 Jun 2025 08:53:17 +0200 Subject: [PATCH] client id eliminado y n_features automatico --- client_cmd.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/client_cmd.py b/client_cmd.py index f130446..b783761 100644 --- a/client_cmd.py +++ b/client_cmd.py @@ -18,7 +18,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Reads parameters from command line.") - parser.add_argument("--client_id", type=int, default="Client Id", help="Number of client") + # parser.add_argument("--client_id", type=int, default="Client Id", help="Number of client") parser.add_argument("--dataset", type=str, default="dt4h_format", help="Dataloader to use") parser.add_argument("--metadata_file", type=str, default="metadata.json", help="Json file with metadata") parser.add_argument("--data_file", type=str, default="data.parquet" , help="parquet o csv file with actual data") @@ -53,9 +53,9 @@ config = vars(args) if config["model"] in ("logistic_regression", "elastic_net", "lsvc"): - print("LINEAR", config["model"], config["n_features"]) config["linear_models"] = {} - config['linear_models']['n_features'] = config["n_features"] + n_feats = len(config["train_labels"]) + config['linear_models']['n_features'] = n_feats # config["n_features"] config["held_out_center_id"] = -1 # Create sandbox log file path @@ -111,7 +111,6 @@ def flush(self): if config["production_mode"] == "True": node_name = os.getenv("NODE_NAME") # num_client = int(node_name.split("_")[-1]) - num_client = config["client_id"] data_path = os.getenv("DATA_PATH") ca_cert = Path(os.path.join(config["certs_path"],"rootCA_cert.pem")) root_certificate = Path(f"{ca_cert}").read_bytes() @@ -139,14 +138,11 @@ def flush(self): root_certificate = None central_ip = "LOCALHOST" central_port = config["local_port"] - num_client = config["client_id"] # if len(sys.argv) == 1: # raise ValueError("Please provide the client id when running in simulation mode") # num_client = int(sys.argv[1]) - - print("Client id:" + str(num_client)) - +num_client = 0 # config["client_id"] (X_train, y_train), (X_test, y_test) = datasets.load_dataset(config, num_client) data = (X_train, y_train), (X_test, y_test)