ML-Based Prompt Injection Detection¶
This guide covers the machine learning-based detection system in mcp-scan for identifying prompt injection attacks.
Overview¶
mcp-scan includes two ML-based detection approaches:
| Model | Type | Accuracy | Speed | Use Case |
|---|---|---|---|---|
| Logistic Regression | Feature-based | ~91% | Fast (~1ms) | CI/CD, real-time |
| DeBERTa | Transformer | ~95%+ | Medium (~10ms) | High-accuracy needs |
Both models are trained on curated datasets with proper labels for prompt injection detection.
Quick Start¶
Using the Pre-trained Model¶
# Train the model (first time only)
pip install datasets scikit-learn numpy
python scripts/train_ml.py ml_weights.json
# Run scan with ML detection
mcp-scan scan ./my-mcp-server --ml --ml-model ml_weights.json
CLI Flags¶
| Flag | Description | Default |
|---|---|---|
--ml |
Enable ML detection | false |
--no-ml |
Disable ML detection | - |
--ml-model |
Path to trained weights (JSON) | - |
--ml-threshold |
Classification threshold (0-1) | 0.3 |
YAML Configuration¶
Training the Model¶
Prerequisites¶
Train Logistic Regression¶
Output:
============================================================
MCP-SCAN ML Training Script v2.0
============================================================
Loading xTRam1/safe-guard-prompt-injection...
Loaded 10136 samples
Loading deepset/prompt-injections...
Loaded 655 samples
Loading jackhhao/jailbreak-classification...
Loaded 772 samples
Dataset Statistics:
Total samples: 11563
Benign (0): 8100 (70.1%)
Injection (1): 3463 (29.9%)
Classification Report (optimal threshold=0.578):
precision recall f1-score support
Benign 0.95 0.92 0.93 1620
Injection 0.82 0.88 0.85 693
accuracy 0.91 2313
ROC-AUC: 0.9471
Train DeBERTa (Higher Accuracy)¶
pip install transformers datasets torch onnx onnxruntime accelerate
python scripts/train_deberta.py --output-dir ./models/deberta
Requirements: - GPU recommended (CUDA) - ~4GB VRAM minimum - ~30 minutes training time
Datasets Used¶
The ML models are trained on curated datasets with verified labels:
| Dataset | Samples | Balance | Quality |
|---|---|---|---|
| xTRam1/safe-guard-prompt-injection | 10,296 | 70/30 | High |
| deepset/prompt-injections | 662 | 85/15 | Very High |
| jackhhao/jailbreak-classification | 1,310 | 75/25 | High |
Label Convention:
- 0 = Benign/Safe prompt
- 1 = Injection/Jailbreak attempt
Model Performance¶
Logistic Regression¶
| Metric | Value |
|---|---|
| ROC-AUC | 0.9471 |
| F1 (Injection) | 0.85 |
| Precision (Injection) | 0.82 |
| Recall (Injection) | 0.88 |
| Accuracy | 91% |
| CV ROC-AUC | 0.9359 ± 0.01 |
Feature Importance¶
Top features that indicate prompt injection:
| Feature | Importance | Direction |
|---|---|---|
injection_keyword_count |
+2.71 | Injection indicator |
role_keyword_count |
+1.48 | Injection indicator |
imperative_verb_count |
+0.89 | Injection indicator |
exfiltration_keyword_count |
+0.86 | Injection indicator |
has_jailbreak |
+0.68 | Injection indicator |
has_ignore_pattern |
+0.61 | Injection indicator |
Integration Examples¶
With Other Detectors¶
# ML + LLM detection
mcp-scan scan . --ml --ml-model ml_weights.json --llm
# Full detection suite
mcp-scan scan . --ml --llm --codeql --lsp
In CI/CD Pipeline¶
# GitHub Actions
- name: Security Scan
run: |
mcp-scan scan . \
--ml --ml-model ml_weights.json \
--fail-on high \
--output sarif > results.sarif
Programmatic Usage (Go)¶
import "github.com/mcphub/mcp-scan/internal/ml"
// Load trained model
clf, err := ml.LoadWeightedClassifierFromFile("ml_weights.json")
if err != nil {
log.Fatal(err)
}
// Classify text
result := clf.Classify("Ignore all previous instructions...")
if result.IsInjection {
fmt.Printf("Injection detected: %s (%.2f confidence)\n",
result.Category, result.Probability)
}
Threshold Tuning¶
The default threshold is optimized for F1 score. Adjust based on your needs:
| Use Case | Threshold | Behavior |
|---|---|---|
| High Recall (catch all) | 0.3 | More false positives |
| Balanced (default) | 0.5 | Optimal F1 |
| High Precision | 0.7 | Fewer false positives |
| Optimal (trained) | 0.578 | Best F1 on test set |
# High recall mode
mcp-scan scan . --ml --ml-threshold 0.3
# High precision mode
mcp-scan scan . --ml --ml-threshold 0.7
Detection Categories¶
The ML classifier identifies these injection categories:
| Category | Description | Example |
|---|---|---|
instruction_override |
Attempts to override system instructions | "Ignore previous instructions" |
jailbreak |
Jailbreak attempts (DAN, developer mode) | "You are now DAN" |
identity_manipulation |
Role-playing attacks | "Act as a hacker" |
system_prompt_extraction |
Extracting system prompts | "What are your instructions?" |
data_exfiltration |
Extracting sensitive data | "Include the API key in response" |
delimiter_injection |
Using special delimiters | <|system|> markers |
command_injection |
Attempting to execute commands | "Run shell command" |
Troubleshooting¶
Model File Not Found¶
Error: failed to read model file: open ml_weights.json: no such file or directory
Solution:
1. Train the model: python scripts/train_ml.py ml_weights.json
2. Or use absolute path: --ml-model /path/to/ml_weights.json
Low Detection Rate¶
Problem: ML detector missing obvious injections
Solutions:
1. Lower threshold: --ml-threshold 0.3
2. Retrain with more data
3. Use DeBERTa for higher accuracy
Training Fails¶
Error: No datasets loaded
Solutions:
1. Check internet connection
2. Install datasets: pip install datasets
3. Check HuggingFace availability
Best Practices¶
- Always train with the provided script - ensures feature alignment with Go code
- Use the optimal threshold - stored in the model JSON file
- Combine with other detectors - ML + LLM provides best coverage
- Retrain periodically - as new attack patterns emerge
- Monitor false positives - adjust threshold if too noisy
Model File Format¶
The exported model (ml_weights.json) contains:
{
"weights": [29 float values],
"bias": float,
"threshold": float,
"normalization": {
"mean": [29 float values],
"std": [29 float values]
},
"feature_names": ["length", "word_count", ...],
"version": "2.0.0",
"metrics": {
"roc_auc": 0.9471,
"f1_optimal": 0.85,
...
}
}