29xuan/NetworkADHD

NSCH ADHD Classification Project

This project investigates how Graph Neural Networks (GNNs) can improve the classification of Attention-Deficit/Hyperactivity Disorder (ADHD) in children using the National Survey of Children's Health (NSCH) 2022 dataset. The analysis compares multiple GNN architectures with traditional machine learning classifiers to assess their performance in a clinical classification setting.

Project Overview

This project aims to identify the factors most strongly associated with ADHD in children and to examine whether modeling relationships between children as a graph, with a Graph Neural Network (GNN) updating node features, can improve the accuracy of ADHD diagnosis prediction. Specifically, the study uses logistic regression to extract statistically significant features and compares the classification performance of three GNN architectures (GCN, GAT, and GraphSAGE) with three machine learning models: Random Forest (RF), XGBoost, and MLP.

Data from 37,291 children aged 3–17 were used, of whom 3,476 (9.32%) were diagnosed with ADHD. Because of this severe class imbalance, stratified sampling and bootstrapping were applied to construct balanced training sets. The study also assessed the impact of the similarity metric (Euclidean distance vs. cosine similarity) on graph construction.
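The two balancing strategies described above (keep all ADHD cases plus an equal-sized random draw of non-ADHD cases, or resample both classes with replacement) can be sketched in a few lines. This is an illustration only, not the repository's actual code; the function names are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

def balance_by_undersampling(X, y, rng):
    """Keep every positive (ADHD) case and an equal-sized random
    subset of negatives, mirroring the stratified-sampling step."""
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    keep_neg = rng.choice(neg, size=len(pos), replace=False)
    idx = np.concatenate([pos, keep_neg])
    rng.shuffle(idx)
    return X[idx], y[idx]

def bootstrap_balanced(X, y, n_sets, rng):
    """Yield several balanced training sets by sampling each class
    with replacement (the bootstrapping strategy)."""
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    for _ in range(n_sets):
        idx = np.concatenate([
            rng.choice(pos, size=len(pos), replace=True),
            rng.choice(neg, size=len(pos), replace=True),
        ])
        rng.shuffle(idx)
        yield X[idx], y[idx]

# toy data with roughly a 9% positive rate, like the NSCH ADHD label
X = rng.normal(size=(1000, 5))
y = (rng.random(1000) < 0.09).astype(int)
Xb, yb = balance_by_undersampling(X, y, rng)
```

After undersampling, `yb` is exactly half positive; each bootstrap draw is the same size but may repeat individual children.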

Repository Structure

network_project/
├── nsch_dataset/                    # Processed data and model results
│   └── 2022/
│       ├── balanced/                # Balanced dataset and results
│       └── bootstrap/               # Bootstrap dataset and results
│
├── NSCH_National_Survey_of_Children_Health/   # Original dataset
│   └── nsch_2022_topical_SAS/       # Original and cleaned data files
│
└── nsch_project/                    # All code files
    ├── Data_cleaning/               # Data preparation scripts
    ├── GNN/                         # Graph Neural Network implementation
    │   ├── GAT/                     # Graph Attention Network models
    │   ├── GCN/                     # Graph Convolutional Network models
    │   ├── graphSAGE/               # GraphSAGE models
    │   └── imbalanced/              # GNNs on imbalanced dataset (not used)
    ├── MLP/                         # Multilayer Perceptron models
    ├── RF/                          # Random Forest models
    └── XGBoost/                     # XGBoost models

Naming Conventions

Script filenames follow these conventions:

  • balanced: Uses balanced dataset (stratified sampling)
  • bootstrap: Uses bootstrap dataset (sampling with replacement)
  • Cosine/C: Uses cosine similarity as the similarity measure for graph construction
  • Euclidean/E: Uses Euclidean distance as the similarity measure for graph construction
  • pyg: Uses PyTorch Geometric library
  • cv: Uses cross-validation

Data Processing

The project workflow includes:

  1. Variable Selection: Logistic regression was used to identify 24 statistically significant variables (p < 0.05) from 48 potential variables related to Child Health, Parental Health, Household Environment, and Birth Conditions.
  2. Data Preprocessing:
    • Binary variables recoded (1=Yes, 0=No)
    • Numerical variables normalized using MinMaxScaler
    • Multicategorical variables converted to dummy variables
  3. Class Imbalance Management:
    • Stratified sampling: All 3,476 ADHD cases and an equal number of non-ADHD cases
    • Bootstrapping: Random sampling with replacement to create multiple balanced training sets
  4. Network Creation:
    • Each child represented as a node in the graph
    • Edges created based on feature similarity (Euclidean distance or Cosine similarity)
    • Top 20% of most similar nodes connected
    • Each node assigned a feature matrix of the 24 significant variables
  5. Model Training:
    • GNN models: GCN, GAT, GraphSAGE (3 convolutional layers with Dropout)
    • Traditional ML: RF, XGBoost, MLP (hyperparameters optimized via GridSearch)
  6. Performance Evaluation:
    • Metrics: Accuracy, Precision, Recall, F1-score
    • Special focus on Recall for ADHD cases to minimize false negatives
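The network-creation step above can be sketched as follows. This is a minimal illustration of the top-20% rule using cosine similarity (the repository's scripts also support Euclidean distance and wrap the result in PyTorch Geometric objects); `build_edges` is a hypothetical name:

```python
import numpy as np

def build_edges(X, keep_frac=0.20):
    """Connect the pairs of children whose cosine similarity falls in
    the top `keep_frac` of all pairwise similarities. Returns a (2, E)
    edge-index array in the layout PyTorch Geometric expects."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    sim = Xn @ Xn.T                       # pairwise cosine similarity
    iu = np.triu_indices_from(sim, k=1)   # unique pairs, no self-loops
    vals = sim[iu]
    thresh = np.quantile(vals, 1.0 - keep_frac)
    keep = vals >= thresh
    src, dst = iu[0][keep], iu[1][keep]
    # undirected graph: store both directions of every kept edge
    edge_index = np.concatenate(
        [np.stack([src, dst]), np.stack([dst, src])], axis=1)
    return edge_index

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 24))             # 24 significant variables per node
edge_index = build_edges(X)
```

Each row of `X` is one child's feature vector; the quantile threshold keeps roughly the top 20% most similar pairs, matching the construction described above.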

Getting Started

Prerequisites

Install required packages:

pip install -r requirements.txt

Running the Code

Follow these steps to reproduce the project's results:

  1. Data Cleaning and Variable Selection:

    cd nsch_project/Data_cleaning
    # Data selection and initial cleaning
    python data_select.py
    # Basic data preprocessing
    python data_process.py
    # Variable selection using logistic regression
    python variable_select_LR.py
    # Process selected variables
    python data_process_variable_select.py
  2. Generate Datasets:

    cd nsch_project/Data_cleaning
    # Create balanced dataset (stratified sampling)
    python generate_pyg_data.py
    # Create bootstrap datasets (sampling with replacement)
    python boostrap_pyg_data.py
    # Generate matrices for graph construction
    python generate_matrices.py
    # Generate datasets for ML models
    python generate_ml_data.py
  3. Train and Test GNN Models:

    cd nsch_project/GNN
    
    # GCN model
    cd GCN
    python train_gcn_pyg_balanced.py
    python test_gcn_pyg_balanced.py
    
    # GAT model
    cd ../GAT
    python train_gat_pyg_balanced.py
    python test_gat_pyg_balanced.py
    
    # GraphSAGE model
    cd ../graphSAGE
    python train_graphsage_pyg_balanced.py
    python test_graphsage_pyg_balanced.py
  4. Train and Test Traditional ML Models:

    # Random Forest models
    cd nsch_project/RF
    python rf_cv_balanced.py
    
    # XGBoost models
    cd ../XGBoost
    python xgboost_cv_balanced.py
    # On bootstrap datasets
    python xgboost_bootstrap.py
    
    # MLP models
    cd ../MLP
    python mlp_train_balanced.py
    python mlp_test_balanced.py
  5. Bootstrap:

    # GraphSAGE models on bootstrap datasets
    cd nsch_project/GNN/graphSAGE
    python train_graphsage_pyg_balanced_bootstrap.py
    
    # XGBoost models on bootstrap datasets
    cd nsch_project/XGBoost
    python xgboost_cv_balanced_boostrap.py

Each step should be executed in sequence, as outputs from earlier steps are required by later ones. Results are saved in the respective model output directories.
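The GNN training scripts in step 3 share the architecture described under Data Processing: three graph-convolution layers with dropout. A framework-free sketch of a GCN-style forward pass over a normalized adjacency matrix (the repository itself builds these layers with PyTorch Geometric; all names here are illustrative):

```python
import numpy as np

def normalize_adj(edge_index, n):
    """Symmetrically normalized adjacency with self-loops,
    D^{-1/2} (A + I) D^{-1/2}, as in the GCN propagation rule."""
    A = np.eye(n)
    A[edge_index[0], edge_index[1]] = 1.0
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_forward(X, A_hat, weights, rng, p_drop=0.5, train=True):
    """Three propagation layers with ReLU and dropout between them;
    the final layer outputs two class logits (ADHD / non-ADHD)."""
    H = X
    for i, W in enumerate(weights):
        H = A_hat @ H @ W
        if i < len(weights) - 1:
            H = np.maximum(H, 0.0)                # ReLU
            if train:
                mask = rng.random(H.shape) >= p_drop
                H = H * mask / (1.0 - p_drop)     # inverted dropout
    return H

rng = np.random.default_rng(0)
n, f, h = 30, 24, 16                              # nodes, features, hidden size
edge_index = np.stack([rng.integers(0, n, 100), rng.integers(0, n, 100)])
edges_sym = np.concatenate([edge_index, edge_index[::-1]], axis=1)
A_hat = normalize_adj(edges_sym, n)
weights = [rng.normal(scale=0.1, size=s) for s in [(f, h), (h, h), (h, 2)]]
logits = gcn_forward(rng.normal(size=(n, f)), A_hat, weights, rng)
```

GAT and GraphSAGE replace the fixed `A_hat` aggregation with attention weights and sampled-neighbor mean aggregation, respectively, but keep the same three-layer-with-dropout shape.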
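The traditional baselines in step 4 tune hyperparameters via grid search. A minimal sketch with scikit-learn's `GridSearchCV` on synthetic data; the grid values and scoring choice here are assumptions, not the repository's actual search space:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 24))            # 24 selected variables
y = rng.integers(0, 2, size=200)          # balanced toy labels

# hypothetical grid; the repository's scripts define their own
grid = {"n_estimators": [50, 100], "max_depth": [3, None]}
search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    grid,
    scoring="recall",                     # the study prioritizes ADHD recall
    cv=3,
)
search.fit(X, y)
print(search.best_params_)
```

Scoring on recall reflects the evaluation focus above: missing a true ADHD case (a false negative) is the costliest error in this setting.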

Results Summary

  • Table 1. Performance Comparison of GNN Models with Stratified Sampling (results table rendered as an image in the repository)

  • Table 2. Performance Comparison of ML Models with Stratified Sampling (results table rendered as an image in the repository)

  • Key findings:

    • GraphSAGE achieved the highest ADHD Recall (84.48%), making it the strongest model at identifying true ADHD cases.
    • XGBoost achieved the highest Precision for ADHD (86.63%).
    • MLP performed best at Non-ADHD detection (Recall of 95.02%).
    • No single model consistently outperformed the others across all metrics.

Acknowledgments

This research utilizes the National Survey of Children's Health (NSCH) 2022 public dataset, which is administered by the Health Resources and Services Administration (HRSA) Maternal and Child Health Bureau (MCHB).

License

This project uses the NSCH public dataset and follows their terms of use.

About

CS5891 - Network Analysis in Healthcare: course project (2024Fall)
