This project builds an Artificial Neural Network (ANN) to predict customer churn using the "Churn_Modelling.csv" dataset. It includes complete preprocessing, model training, and a simple web app interface to make real-time predictions.
Customer churn prediction is vital for improving customer retention. Using a multi-layer ANN, this project learns patterns from customer data (such as credit score, geography, age, and balance) to classify whether a customer is likely to churn or stay.
.
├── Churn_Modelling.csv # Dataset
├── app.py # Web app for real-time predictions
├── experiments.ipynb # Data exploration, preprocessing, model training
├── prediction.ipynb # Manual prediction walkthrough using saved artifacts
├── model.h5 # Trained ANN model
├── scaler.pkl # StandardScaler for numeric features
├── label_encoder_gender.pkl # LabelEncoder for 'Gender'
├── onehot_encoder_geo.pkl # OneHotEncoder for 'Geography'
├── requirements.txt # Python dependencies
└── README.md # Project documentation
- Source: [Churn_Modelling.csv](https://www.kaggle.com/datasets/shubhendra7/customer-churn-prediction)
- Size: 10,000 rows × 14 columns
- Target variable: `Exited` (1 → churned, 0 → retained)
Key features used:
- CreditScore
- Geography (France, Spain, Germany)
- Gender
- Age
- Tenure
- Balance
- NumOfProducts
- HasCrCard
- IsActiveMember
- EstimatedSalary
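A minimal sketch of separating the listed features from the `Exited` target with pandas. It assumes the three remaining columns of the 14 are the identifiers `RowNumber`, `CustomerId` and `Surname`, which are dropped; the helper name `split_features_target` is hypothetical, not taken from the notebook.

```python
from typing import Tuple

import pandas as pd

# Identifier columns assumed to exist in Churn_Modelling.csv; they carry no
# predictive signal, so they are dropped before modelling.
ID_COLUMNS = ["RowNumber", "CustomerId", "Surname"]
TARGET = "Exited"

# The ten input features listed above.
FEATURES = [
    "CreditScore", "Geography", "Gender", "Age", "Tenure", "Balance",
    "NumOfProducts", "HasCrCard", "IsActiveMember", "EstimatedSalary",
]


def split_features_target(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
    """Return (X, y): the feature columns and the 0/1 Exited target."""
    df = df.drop(columns=[c for c in ID_COLUMNS if c in df.columns])
    return df[FEATURES].copy(), df[TARGET].astype(int)
```

Typical use: `X, y = split_features_target(pd.read_csv("Churn_Modelling.csv"))`.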
- Language: Python 3.x
- Libraries:
  - TensorFlow / Keras: model development
  - pandas, numpy: data processing
  - scikit-learn: preprocessing, encoders
  - joblib: model and encoder saving
  - Streamlit: web interface
git clone https://github.com/Github-Shashwat/ANN-Classification-churn.git
cd ANN-Classification-churn

It’s recommended to use a virtual environment (e.g., venv or conda).

pip install -r requirements.txt

If you want to retrain the model:
- Open `experiments.ipynb`.
- Follow the steps: EDA → preprocessing → training.
- The trained model is saved as an `.h5` file and the encoders and scaler as `.pkl` files.
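The preprocessing step can be sketched with the scikit-learn classes named in the project tree. This is an illustration under assumptions, not the notebook's code: only the numeric columns are scaled (as described under prediction), Gender becomes one integer column, Geography is one-hot encoded with `drop="first"` so the total comes to the 11 inputs of the architecture (keeping all three dummies would give 12), and objects are saved with `joblib` as the tech stack suggests. The column order and the function name `fit_preprocessors` are hypothetical.

```python
import os
from typing import Optional

import joblib
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler

# Numeric columns, in the order assumed for the model input.
NUMERIC = ["CreditScore", "Age", "Tenure", "Balance", "NumOfProducts",
           "HasCrCard", "IsActiveMember", "EstimatedSalary"]


def fit_preprocessors(X: pd.DataFrame, out_dir: Optional[str] = None):
    """Fit encoders and scaler on training features X.

    Returns the model-ready matrix (scaled numeric columns, then Gender,
    then Geography dummies) and the three fitted objects. If out_dir is
    given, each object is saved there under the file names from the tree.
    """
    # StandardScaler: zero mean, unit variance for the numeric columns only.
    scaler = StandardScaler()
    numeric = scaler.fit_transform(X[NUMERIC].to_numpy(dtype=float))

    # LabelEncoder: Gender strings -> one integer column
    # (classes sorted alphabetically, so Female=0, Male=1).
    gender_enc = LabelEncoder()
    gender = gender_enc.fit_transform(X["Gender"]).reshape(-1, 1)

    # OneHotEncoder: drop="first" keeps 2 dummies for the 3 countries
    # (assumption, so that 8 + 1 + 2 = 11 inputs).
    geo_enc = OneHotEncoder(drop="first")
    geo = geo_enc.fit_transform(X[["Geography"]]).toarray()

    features = np.hstack([numeric, gender, geo])

    if out_dir is not None:
        joblib.dump(scaler, os.path.join(out_dir, "scaler.pkl"))
        joblib.dump(gender_enc, os.path.join(out_dir, "label_encoder_gender.pkl"))
        joblib.dump(geo_enc, os.path.join(out_dir, "onehot_encoder_geo.pkl"))
    return features, scaler, gender_enc, geo_enc
```

Fit these on the training split only, then reuse the saved objects (with `transform`, not `fit_transform`) on test data and on app inputs.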
To launch the web app:

streamlit run app.py

Then visit http://localhost:8501 in your browser to access the interface.
Model architecture:

- Input layer: 11 features
- Hidden layers:
  - Dense(6) with ReLU
  - Dense(6) with ReLU
- Output layer:
  - Dense(1) with Sigmoid
- Loss: binary_crossentropy
- Optimizer: adam
- Metric: accuracy
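The architecture above maps directly onto a Keras `Sequential` model. This sketch takes the layer sizes, activations, loss, optimizer and metric from the list; the function name `build_model` is hypothetical.

```python
import tensorflow as tf


def build_model(n_features: int = 11) -> tf.keras.Model:
    """Sequential ANN matching the architecture listed above."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(n_features,)),
        tf.keras.layers.Dense(6, activation="relu"),     # hidden layer 1
        tf.keras.layers.Dense(6, activation="relu"),     # hidden layer 2
        tf.keras.layers.Dense(1, activation="sigmoid"),  # churn probability in (0, 1)
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model
```

The network is small: 11·6+6 + 6·6+6 + 6+1 = 121 trainable parameters. Training might then look like `history = model.fit(X_train, y_train, validation_split=0.2, epochs=50, batch_size=32)` followed by `model.save("model.h5")`; the split, epoch count and batch size here are illustrative values, not the notebook's.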
How prediction works:

- The user enters inputs via the web form.
- Categorical features (`Gender`, `Geography`) are encoded using the saved `.pkl` files.
- Numerical features are scaled using the stored `StandardScaler`.
- The ANN model makes a prediction: a churn probability and a class label (0 or 1).
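The steps above can be sketched as a function that turns one form submission into a prediction. Assumptions (not confirmed by this README): the artifacts were saved with `joblib`, the model input order is the scaled numeric columns, then encoded Gender, then the Geography dummies, and the class label uses a 0.5 probability threshold. `load_artifacts` and `predict_churn` are hypothetical names.

```python
import joblib
import numpy as np
import pandas as pd

# Numeric columns, in the order assumed at training time.
NUMERIC = ["CreditScore", "Age", "Tenure", "Balance", "NumOfProducts",
           "HasCrCard", "IsActiveMember", "EstimatedSalary"]


def load_artifacts(model_path: str = "model.h5"):
    """Load the saved model, scaler and encoders from the working directory."""
    from tensorflow.keras.models import load_model  # deferred: heavy import
    model = load_model(model_path)
    scaler = joblib.load("scaler.pkl")
    gender_enc = joblib.load("label_encoder_gender.pkl")
    geo_enc = joblib.load("onehot_encoder_geo.pkl")
    return model, scaler, gender_enc, geo_enc


def predict_churn(record: dict, model, scaler, gender_enc, geo_enc, threshold: float = 0.5):
    """Return (churn probability, class label) for a single customer record."""
    row = pd.DataFrame([record])
    numeric = scaler.transform(row[NUMERIC].to_numpy(dtype=float))  # reuse training stats
    gender = gender_enc.transform(row["Gender"]).reshape(-1, 1)
    geo = geo_enc.transform(row[["Geography"]]).toarray()
    x = np.hstack([numeric, gender, geo])
    prob = float(model.predict(x, verbose=0)[0, 0])
    return prob, int(prob >= threshold)
```

In the app, `model, scaler, gender_enc, geo_enc = load_artifacts()` would run once at startup, and `predict_churn` on each form submission. Note that the encoders only use `transform`; refitting them on a single row would silently change the encoding.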
Input:
- Geography: France
- Gender: Female
- Age: 40
- Balance: 50000
...
Output:
- Churn Probability: 0.76
- Prediction: High Risk of Churn (1)
Check experiments.ipynb for:
- Accuracy
- Confusion Matrix
- Precision, Recall, F1-score
- Training/Validation Loss curves
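The metrics listed above can be computed from the model's predicted probabilities with scikit-learn. This is a sketch, assuming a 0.5 threshold to convert probabilities into labels; `evaluate` is a hypothetical helper name. The loss curves come from the Keras `History` object returned by `fit`: `history.history["loss"]` and, when validation data is supplied, `history.history["val_loss"]`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


def evaluate(y_true, y_prob, threshold: float = 0.5) -> dict:
    """Threshold predicted churn probabilities and compute the listed metrics."""
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)
    # average="binary" reports precision/recall/F1 for the churn class (label 1).
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        # Rows are true labels [0, 1], columns are predicted labels [0, 1].
        "confusion_matrix": confusion_matrix(y_true, y_pred, labels=[0, 1]),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```

Because retained customers outnumber churners in this kind of dataset, precision, recall and F1 for the churn class are usually more informative than accuracy alone.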
Future improvements:

- Add SHAP/Explainable AI visualizations
- Deploy app using Docker or Cloud (AWS/GCP)
- Enable CSV upload for batch prediction
- Add cross-validation and hyperparameter tuning