Skip to content

Prompt Safety Classification: Machine learning models for detecting safe vs unsafe AI prompts to enhance conversational AI security by filtering adversarial and jailbreak inputs.

Notifications You must be signed in to change notification settings

ananyapattaje/Prompt_Safety_Classifier

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 

Repository files navigation

🪼 Project Overview

This project aims to build a robust machine learning classifier to distinguish between safe and unsafe/malicious AI prompts. Unsafe prompts include adversarial "prompt injections" and jailbreak attacks designed to bypass AI safety measures. By leveraging diverse datasets, the goal is to enhance the security and reliability of conversational AI systems.

🪼 Datasets

  • Safe Prompts: Collected from AI-generated prompt datasets, prompt engineering examples, and community-curated ChatGPT prompts. Labeled as 0.
  • Unsafe Prompts: Includes forbidden question sets, jailbreak prompts, and malicious prompt collections from real-world adversarial sources. Labeled as 1.
  • Combined dataset contains over 80,000 unique prompts after cleaning with a class imbalance favoring safe prompts.

🪼 Data Preprocessing

  • Loaded and unified multiple datasets with consistent labeling and column naming.
  • Removed duplicate and null entries.
  • Analyzed prompt length distribution, finding a majority of short prompts.

🪼 Feature Extraction

  • Converted text prompts into numerical vectors using:
    • TF-IDF Vectorizer (max 5000 features, removing English stop words)
    • Bag of Words Vectorizer (max 5000 features)

🪼 Modeling

  • Used Logistic Regression and Random Forest classifiers.
  • Trained with an 80/20 train-test split.
  • Achieved excellent performance:
    • Accuracy ~99.5–99.8%
    • Precision and recall near 1.0 for both classes
    • AUC ROC of 1.00 for all models

🪼 Model Evaluation

  • Confusion matrices confirm very low misclassification rates.
  • ROC curves demonstrate near-perfect ability to separate safe from unsafe prompts.
  • Example predictions show consistent classification of benign and malicious inputs.

🪼 How to Use

  1. Preprocess your prompts using the vectorizers.
  2. Load a trained model.
  3. Use the predict_class(model, vectorizer, text) function to classify new prompts.

🪼 Future Improvements

  • Integrate transformer-based or deep learning classifiers.
  • Enrich the adversarial dataset with more complex prompt injections.
  • Implement ensemble methods and threshold optimization for enhanced detection.

🪼 Author

Ananya P S
20221CSD0106


For detailed code, dataset processing, and evaluation, please refer to the accompanying project notebook.

About

Prompt Safety Classification: Machine learning models for detecting safe vs unsafe AI prompts to enhance conversational AI security by filtering adversarial and jailbreak inputs.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published