This project investigates how Graph Neural Networks (GNNs) can improve the classification of Attention-Deficit/Hyperactivity Disorder (ADHD) in children using the National Survey of Children's Health (NSCH) 2022 dataset. The analysis compares multiple GNN architectures with traditional machine learning classifiers to assess their performance in a clinical classification setting.
This project aims to extract the most significant factors associated with ADHD in children and to examine whether using a Graph Neural Network (GNN) to model relationships between nodes and update node features can improve the accuracy of predicting ADHD diagnoses. Specifically, the study uses logistic regression to extract statistically significant features and compares the classification performance of three Graph Neural Network models (GCN, GAT, and GraphSAGE) with three machine learning models: Random Forest (RF), XGBoost, and a Multilayer Perceptron (MLP).
Data from 37,291 children aged 3–17 was used, with 3,476 (9.32%) diagnosed with ADHD. Due to the severe class imbalance, stratified sampling and bootstrapping techniques were applied to construct balanced training sets. The study also assessed the impact of similarity metrics (Euclidean vs. Cosine) on graph construction.
network_project/
├── nsch_dataset/ # Processed data and model results
│ └── 2022/
│ ├── balanced/ # Balanced dataset and results
│ └── bootstrap/ # Bootstrap dataset and results
│
├── NSCH_National_Survey_of_Children_Health/ # Original dataset
│ └── nsch_2022_topical_SAS/ # Original and cleaned data files
│
└── nsch_project/ # All code files
├── Data_cleaning/ # Data preparation scripts
├── GNN/ # Graph Neural Network implementation
│ ├── GAT/ # Graph Attention Network models
│ ├── GCN/ # Graph Convolutional Network models
│ ├── graphSAGE/ # GraphSAGE models
│ └── imbalanced/ # GNNs on imbalanced dataset (not used)
├── MLP/ # Multilayer Perceptron models
├── RF/ # Random Forest models
└── XGBoost/ # XGBoost models
Script filenames follow these conventions:
- balanced: Uses balanced dataset (stratified sampling)
- bootstrap: Uses bootstrap dataset (sampling with replacement)
- Cosine/C: Uses cosine similarity for distance metric
- Euclidean/E: Uses Euclidean distance for distance metric
- pyg: Uses PyTorch Geometric library
- cv: Uses cross-validation
The project workflow includes:
- Variable Selection: Logistic regression was used to identify 24 statistically significant variables (p < 0.05) from 48 potential variables related to Child Health, Parental Health, Household Environment, and Birth Conditions.
- Data Preprocessing:
- Binary variables recoded (1=Yes, 0=No)
- Numerical variables normalized using MinMaxScaler
- Multicategorical variables converted to dummy variables
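The preprocessing steps above can be sketched with pandas. This is a minimal illustration only: the column names below are made up for the example and are not the actual NSCH variable names.

```python
import pandas as pd

# Toy frame standing in for the survey data; column names are hypothetical.
df = pd.DataFrame({
    "breastfed": [1, 2, 1, 2],                      # survey coding: 1=Yes, 2=No
    "age": [3, 10, 17, 8],
    "income_group": ["low", "mid", "high", "mid"],
})

# Binary variables recoded to 1=Yes, 0=No
df["breastfed"] = df["breastfed"].map({1: 1, 2: 0})

# Min-max normalization of a numerical variable (same effect as
# sklearn's MinMaxScaler applied to one column)
df["age"] = (df["age"] - df["age"].min()) / (df["age"].max() - df["age"].min())

# Multicategorical variables expanded into dummy indicators
df = pd.get_dummies(df, columns=["income_group"])
```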
- Class Imbalance Management:
- Stratified sampling: All 3,476 ADHD cases and an equal number of non-ADHD cases
- Bootstrapping: Random sampling with replacement to create multiple balanced training sets
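Both balancing strategies can be expressed in a few lines of pandas. A minimal sketch on toy labels, assuming the exact sampling code in the repository may differ:

```python
import pandas as pd

# Toy labels: 1 = ADHD, 0 = non-ADHD (heavily imbalanced, as in NSCH)
df = pd.DataFrame({"id": range(100), "adhd": [1] * 10 + [0] * 90})

adhd = df[df["adhd"] == 1]
non_adhd = df[df["adhd"] == 0]

# Stratified sampling: keep every ADHD case, draw an equal number of
# non-ADHD cases without replacement
balanced = pd.concat([adhd, non_adhd.sample(n=len(adhd), random_state=0)])

# Bootstrapping: sample non-ADHD cases WITH replacement to build
# several balanced training sets
bootstraps = [
    pd.concat([adhd, non_adhd.sample(n=len(adhd), replace=True, random_state=s)])
    for s in range(5)
]
```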
- Network Creation:
- Each child represented as a node in the graph
- Edges created based on feature similarity (Euclidean distance or Cosine similarity)
- Top 20% of most similar nodes connected
- Each node assigned a feature matrix of the 24 significant variables
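One plausible reading of the edge rule above (connect the top 20% most similar node pairs) can be sketched in NumPy. The thresholding details of the actual scripts may differ; `build_edges` is a hypothetical helper for illustration.

```python
import numpy as np

def build_edges(X, metric="cosine", keep=0.2):
    """Connect the `keep` fraction of most similar node pairs.

    X: (n_nodes, n_features) matrix, one row per child.
    """
    n = X.shape[0]
    if metric == "cosine":
        unit = X / np.linalg.norm(X, axis=1, keepdims=True)
        sim = unit @ unit.T                              # cosine similarity
    else:
        d = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
        sim = -d                                          # smaller distance = more similar
    iu = np.triu_indices(n, k=1)                          # candidate pairs (i < j)
    pair_sims = sim[iu]
    cutoff = np.quantile(pair_sims, 1 - keep)             # keep the top `keep` fraction
    mask = pair_sims >= cutoff
    return np.stack([iu[0][mask], iu[1][mask]])           # (2, n_edges), PyG-style edge index

edges = build_edges(np.random.default_rng(0).random((10, 24)))
```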
- Model Training:
- GNN models: GCN, GAT, GraphSAGE (3 convolutional layers with Dropout)
- Traditional ML: RF, XGBoost, MLP (hyperparameters optimized via GridSearch)
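To illustrate what a GraphSAGE convolution does to the node features, here is a single mean-aggregator layer written in plain NumPy. The project itself uses PyTorch Geometric layers; this is only a conceptual sketch.

```python
import numpy as np

def sage_layer(H, edges, W_self, W_neigh):
    """One GraphSAGE layer with a mean aggregator.

    H: (n, d_in) node features; edges: (2, m) edge index, treated as undirected.
    Each node combines its own features with the mean of its neighbours'.
    """
    n = H.shape[0]
    agg = np.zeros_like(H)
    deg = np.zeros(n)
    for s, t in zip(*edges):              # accumulate in both directions
        agg[t] += H[s]; deg[t] += 1
        agg[s] += H[t]; deg[s] += 1
    deg[deg == 0] = 1                      # isolated nodes keep a zero aggregate
    mean_neigh = agg / deg[:, None]
    return np.maximum(0, H @ W_self + mean_neigh @ W_neigh)  # ReLU

rng = np.random.default_rng(0)
H = rng.random((4, 3))
edges = np.array([[0, 1, 2], [1, 2, 3]])
out = sage_layer(H, edges, rng.random((3, 2)), rng.random((3, 2)))
```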
- Performance Evaluation:
- Metrics: Accuracy, Precision, Recall, F1-score
- Special focus on Recall for ADHD cases to minimize false negatives
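The per-class metrics can be computed directly from the confusion counts; recall for the ADHD class is emphasized because a false negative means a missed diagnosis. A self-contained sketch:

```python
def prf(y_true, y_pred, positive=1):
    """Precision, recall, and F1 for one class (here, ADHD = 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0          # fraction of ADHD cases found
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

p, r, f1 = prf([1, 1, 0, 0, 1], [1, 0, 0, 1, 1])
```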
Install required packages:
pip install -r requirements.txt

Follow these steps to reproduce the project's results:
- Data Cleaning and Variable Selection:

  cd nsch_project/Data_cleaning

  # Data selection and initial cleaning
  python data_select.py

  # Basic data preprocessing
  python data_process.py

  # Variable selection using logistic regression
  python variable_select_LR.py

  # Process selected variables
  python data_process_variable_select.py
- Generate Datasets:

  cd nsch_project/Data_cleaning

  # Create balanced dataset (stratified sampling)
  python generate_pyg_data.py

  # Create bootstrap datasets (sampling with replacement)
  python boostrap_pyg_data.py

  # Generate matrices for graph construction
  python generate_matrices.py

  # Generate datasets for ML models
  python generate_ml_data.py
- Train and Test GNN Models:

  cd nsch_project/GNN

  # GCN model
  cd GCN
  python train_gcn_pyg_balanced.py
  python test_gcn_pyg_balanced.py

  # GAT model
  cd ../GAT
  python train_gat_pyg_balanced.py
  python test_gat_pyg_balanced.py

  # GraphSAGE model
  cd ../graphSAGE
  python train_graphsage_pyg_balanced.py
  python test_graphsage_pyg_balanced.py
- Train and Test Traditional ML Models:

  # Random Forest models
  cd nsch_project/RF
  python rf_cv_balanced.py

  # XGBoost models
  cd ../XGBoost
  python xgboost_cv_balanced.py

  # XGBoost on bootstrap datasets
  python xgboost_bootstrap.py

  # MLP models
  cd ../MLP
  python mlp_train_balanced.py
  python mlp_test_balanced.py
- Bootstrap:

  # GraphSAGE models on bootstrap datasets
  cd nsch_project/GNN/graphSAGE
  python train_graphsage_pyg_balanced_bootstrap.py

  # XGBoost models on bootstrap datasets
  cd nsch_project/XGBoost
  python xgboost_cv_balanced_boostrap.py
Each step should be executed in sequence as the outputs from earlier steps are required for the later ones. Results will be saved in the respective model output directories.
Table 1. Performance Comparison of GNN Models with Stratified Sampling
Table 2. Performance Comparison of ML Models with Stratified Sampling
Key findings:
- GraphSAGE outperformed other models in terms of ADHD Recall (84.48%), showcasing its ability to accurately identify true ADHD cases.
- XGBoost excelled in Precision for ADHD (86.63%).
- MLP showed exceptional performance for Non-ADHD detection (Recall of 95.02%).
- No single model consistently outperformed the others across all metrics.
This research utilizes the National Survey of Children's Health (NSCH) 2022 public dataset, which is administered by the Health Resources and Services Administration (HRSA) Maternal and Child Health Bureau (MCHB).
This project uses the NSCH public dataset and follows their terms of use.

