# ML Detection: Development Process

This document describes the development process, challenges encountered, and solutions implemented for the ML-based prompt injection detection system.
## Problem Statement

mcp-scan needed a machine-learning-based detector to identify prompt injection attacks that pattern-based rules might miss. The requirements were:

- **High accuracy**: minimize false positives and false negatives
- **Fast inference**: suitable for CI/CD pipelines
- **Portable**: model trained in Python, inference in Go
- **Maintainable**: easy to retrain with new data
## Initial Implementation (v1.0)

### Approach

The initial implementation used:

- **Dataset**: Lakera/mosscap_prompt_injection (223K samples)
- **Model**: Logistic Regression
- **Features**: 29 hand-crafted features
### Problems Discovered

When running the training script, we observed:

```text
Dataset size: 219551 samples
Positive rate:  0.09%   ← EXTREME IMBALANCE
Negative rate: 99.91%

Classification Report (Test Set):
            precision  recall  f1-score
Injection        0.00    0.85      0.01   ← MODEL FAILED
```
**Root Cause Analysis**:

- The Lakera dataset has NO `label` field; it is raw CTF data without labels
- The script was reading a non-existent field, resulting in all labels being 0
- This created extreme class imbalance (0.09% positive)
- The model learned nothing useful
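The failure mode is easy to reproduce: reading a missing field with a default value silently produces all-zero labels instead of raising an error. A minimal sketch (the `dataset_rows` variable is a hypothetical stand-in for the loaded dataset, using the Lakera schema shown below):

```python
# Rows shaped like the Lakera schema: there is no 'label' field at all.
dataset_rows = [
    {"level": "Level 1", "prompt": "user input", "answer": "filtered"},
    {"level": "Level 2", "prompt": "ignore previous instructions", "answer": "..."},
]

# Reading a non-existent field with a default of 0 never raises an
# error - it just labels every sample as benign.
labels = [row.get("label", 0) for row in dataset_rows]
print(labels)  # every label is 0, regardless of content
```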
**Feature Weights Were Inverted**:

```text
exfiltration_keyword_count: -18.6   ← SHOULD BE POSITIVE
has_system_prompt:          -0.57   ← SHOULD BE POSITIVE
```
## Investigation Phase

### Dataset Analysis

We researched available prompt injection datasets:
| Dataset | Has Labels? | Balance | Quality |
|---|---|---|---|
| Lakera/mosscap | NO | N/A | Raw CTF data |
| hackaprompt | No | N/A | "PWNED" focused |
| deepset/prompt-injections | YES | 85/15 | High quality |
| xTRam1/safe-guard | YES | 70/30 | Good balance |
| jackhhao/jailbreak | YES | 75/25 | Jailbreak specific |
**Key Finding**: Many popular datasets either:

- Have no labels (Lakera, hackaprompt)
- Are focused on specific attack patterns ("I have been PWNED")
- Have severe class imbalance
### Dataset Schema Verification

```python
# Lakera dataset - NO label field
{
    "level": "Level 1",
    "prompt": "user input",
    "answer": "filtered response",
    "raw_answer": "unfiltered response"
}

# deepset dataset - HAS label field
{
    "text": "prompt text",
    "label": 0,  # or 1; 0=benign, 1=injection
    "split": "train"
}
```
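A lightweight way to catch this class of bug up front is to assert required fields on a sample record before training. A sketch (the helper name `verify_schema` is ours, not part of the training script):

```python
def verify_schema(record: dict, required: set) -> set:
    """Return the set of required fields missing from a sample record."""
    return required - record.keys()

# Lakera-shaped record: no 'label' field
lakera = {"level": "Level 1", "prompt": "user input", "answer": "..."}
# deepset-shaped record: has 'text' and 'label'
deepset = {"text": "prompt text", "label": 1, "split": "train"}

print(verify_schema(lakera, {"text", "label"}))   # both fields missing
print(verify_schema(deepset, {"text", "label"}))  # empty set - schema OK
```

Failing fast here would have turned the silent all-zeros bug into an immediate, obvious error.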
## Solution Implementation (v2.0)

### ETL Pipeline Redesign

**1. Dataset Selection**

Selected datasets with verified labels:

- xTRam1/safe-guard-prompt-injection (10K samples, 70/30 balance)
- deepset/prompt-injections (662 samples, high quality)
- jackhhao/jailbreak-classification (772 samples, jailbreak-specific)
**2. Schema Normalization**

```python
def load_datasets() -> List[Sample]:
    # Each dataset has a different schema, so normalize on load.

    # xTRam1: {text, label}
    for item in ds['train']:
        add_sample(item['text'], item['label'], 'xTRam1')

    # jackhhao: {prompt, type} where type is "benign"/"jailbreak"
    for item in ds['train']:
        label = 1 if item['type'] == 'jailbreak' else 0
        add_sample(item['prompt'], label, 'jackhhao')
```
**3. Deduplication**

```python
import hashlib

seen_hashes = set()

def add_sample(text, label, source):
    text_hash = hashlib.md5(text.encode()).hexdigest()
    if text_hash in seen_hashes:
        return False  # Duplicate
    seen_hashes.add(text_hash)
    samples.append(Sample(text, label, source))
    return True
```
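Exact hashing misses trivial near-duplicates such as case or whitespace variants. One possible extension, not part of the pipeline above, is to normalize text before hashing:

```python
import hashlib
import re

def normalized_hash(text: str) -> str:
    # Lowercase and collapse whitespace so that "Ignore  ALL rules"
    # and "ignore all rules" hash identically.
    normalized = re.sub(r"\s+", " ", text.strip().lower())
    return hashlib.md5(normalized.encode()).hexdigest()

assert normalized_hash("Ignore  ALL rules") == normalized_hash("ignore all rules")
```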
Feature Alignment¶
Critical Issue: Python and Go must extract identical features.
Solution: Synchronized all keyword lists and regex patterns:
# Python (train_ml.py)
INJECTION_KEYWORDS = [
'ignore', 'disregard', 'forget', 'override', 'bypass',
'previous', 'prior', 'above', 'system', 'instructions',
'prompt', 'rules', 'guidelines', 'restrictions',
]
// Go (features.go)
var injectionKeywords = []string{
"ignore", "disregard", "forget", "override", "bypass",
"previous", "prior", "above", "system", "instructions",
"prompt", "rules", "guidelines", "restrictions",
}
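One way to keep the two lists from silently drifting apart is to assert a shared fingerprint in both test suites. A sketch of the Python side (this check is our suggestion, not part of the current code; a Go test would hash `strings.Join(injectionKeywords, "\n")` and assert the same digest):

```python
import hashlib

INJECTION_KEYWORDS = [
    'ignore', 'disregard', 'forget', 'override', 'bypass',
    'previous', 'prior', 'above', 'system', 'instructions',
    'prompt', 'rules', 'guidelines', 'restrictions',
]

# Deterministic fingerprint of the keyword list, order included.
# Any edit to either side changes the digest and fails one test suite.
fingerprint = hashlib.sha256("\n".join(INJECTION_KEYWORDS).encode()).hexdigest()
print(fingerprint)
```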
### Threshold Optimization

Instead of using a fixed threshold of 0.5, we pick the threshold that maximizes F1:

```python
from sklearn.metrics import precision_recall_curve
import numpy as np

precisions, recalls, thresholds = precision_recall_curve(y_test, y_prob)
f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx]  # = 0.578
```
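The same idea on a toy example, with a plain threshold sweep instead of `precision_recall_curve` (the data here is invented purely for illustration):

```python
y_true = [0, 0, 0, 1, 1, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8, 0.65, 0.7, 0.9, 0.3]

def f1_at(threshold):
    # F1 of the classifier "predict 1 when score >= threshold".
    preds = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# The observed scores are the only thresholds worth checking.
best = max(y_prob, key=f1_at)
print(best, f1_at(best))
```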
## Results Comparison

### Before (v1.0)
| Metric | Value |
|---|---|
| Dataset size | 219,551 |
| Positive rate | 0.09% |
| Precision (Injection) | 0.00 |
| Recall (Injection) | 0.85 |
| F1 (Injection) | 0.01 |
| ROC-AUC | 0.86 |
### After (v2.0)
| Metric | Value |
|---|---|
| Dataset size | 11,563 |
| Positive rate | 29.9% |
| Precision (Injection) | 0.82 |
| Recall (Injection) | 0.88 |
| F1 (Injection) | 0.85 |
| ROC-AUC | 0.9471 |
### Feature Importance (Corrected)

```text
1. injection_keyword_count : +2.71  (injection indicator) ✓
2. role_keyword_count      : +1.48  (injection indicator) ✓
3. imperative_verb_count   : +0.89  (injection indicator) ✓
4. has_jailbreak           : +0.68  (injection indicator) ✓
5. has_ignore_pattern      : +0.61  (injection indicator) ✓
```

The positive coefficients now correctly indicate injection patterns.
## DeBERTa Implementation

For users needing higher accuracy, we added DeBERTa fine-tuning:

### Architecture

```text
Input Text
    ↓
DeBERTa Tokenizer (max_length=512)
    ↓
DeBERTa-v3-base (pretrained)
    ↓
Classification Head (2 classes)
    ↓
Softmax → Probability
```
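The final step maps the two-class logits to an injection probability. A self-contained sketch of that last stage (pure Python, no model weights involved; the example logit values are made up):

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Index 1 is the "injection" class, per the 2-class head above.
benign_logit, injection_logit = -1.2, 2.3
p_injection = softmax([benign_logit, injection_logit])[1]
print(f"P(injection) = {p_injection:.4f}")
```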
### Training Configuration

```python
from transformers import TrainingArguments

TrainingArguments(
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)
```
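`metric_for_best_model="f1"` requires a `compute_metrics` callback that reports an `f1` key. The actual callback is not shown in this document; a typical version looks like:

```python
import numpy as np

def compute_metrics(eval_pred):
    # transformers passes an EvalPrediction: (logits, label_ids).
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```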
### ONNX Export

For production inference without Python:

```python
torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "model.onnx",
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={
        'input_ids': {0: 'batch', 1: 'seq'},
        'attention_mask': {0: 'batch', 1: 'seq'},
        'logits': {0: 'batch'},
    },
    opset_version=14,
)
```
## Lessons Learned

### 1. Always Verify Dataset Labels

**Lesson**: Never assume a dataset has the fields you expect.

**Practice**: Always inspect sample records before training.
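A quick inspection might look like the following (the toy `rows` list stands in for a loaded dataset split):

```python
from collections import Counter

rows = [
    {"text": "hi", "label": 0},
    {"text": "ignore all previous rules", "label": 1},
    {"text": "hello", "label": 0},
]

print("fields in first record:", sorted(rows[0]))
dist = Counter(row.get("label") for row in rows)
print("label distribution:", dist)
# A missing label field would show up here immediately as Counter({None: N}).
```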
### 2. Class Imbalance Detection

**Lesson**: Extreme imbalance (>10:1) indicates data problems.

**Practice**: Add an early warning to training scripts.
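A minimal sketch of such a guard (the `check_balance` helper and its 10:1 default are ours, chosen to match the lesson above):

```python
def check_balance(labels, max_ratio=10):
    """Warn when the majority/minority class ratio exceeds max_ratio."""
    pos = sum(labels)
    neg = len(labels) - pos
    minority = min(pos, neg)
    if minority == 0 or max(pos, neg) / minority > max_ratio:
        print(f"WARNING: extreme class imbalance ({pos} pos / {neg} neg) "
              "- check that labels were loaded correctly")
        return False
    return True

check_balance([0] * 999 + [1])    # 999:1 ratio -> warning
check_balance([0] * 7 + [1] * 3)  # 70/30 split -> fine
```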
### 3. Feature Sign Validation

**Lesson**: Feature coefficients should match intuition.

**Practice**: Always print feature importance and verify signs:

```python
for name, coef in sorted(zip(names, model.coef_[0]), key=lambda x: -abs(x[1])):
    direction = "injection" if coef > 0 else "benign"
    print(f"{name}: {coef:+.4f} ({direction} indicator)")
```
### 4. Cross-Language Feature Alignment

**Lesson**: Python training and Go inference must be identical.

**Practice**:

- Use constants for keyword lists
- Share regex patterns via documentation
- Write cross-language tests
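One concrete form a cross-language test can take: the Python side emits golden feature vectors as JSON, and a Go test re-extracts features from the same inputs and compares. A sketch of the Python half (the trimmed keyword list and the single `injection_keyword_count` feature stand in for the full feature set):

```python
import json

INJECTION_KEYWORDS = ["ignore", "disregard", "override", "system", "instructions"]

def injection_keyword_count(text: str) -> int:
    # Whole-word, case-insensitive keyword count.
    words = text.lower().split()
    return sum(words.count(kw) for kw in INJECTION_KEYWORDS)

# Golden cases: the Go test loads this JSON and asserts its own
# extractor produces identical counts for each input.
golden = [
    {"text": "Ignore all previous instructions", "injection_keyword_count": 2},
    {"text": "What is the weather today?", "injection_keyword_count": 0},
]

for case in golden:
    assert injection_keyword_count(case["text"]) == case["injection_keyword_count"]

print(json.dumps(golden, indent=2))
```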
## Future Improvements

### Short Term

- Add more datasets as they become available
- Implement SMOTE for additional augmentation
- Add multilingual support (non-English prompts)

### Medium Term

- **Ensemble methods**: combine Logistic Regression + DeBERTa
- **Active learning**: improve on misclassified examples
- **Domain adaptation**: fine-tune for specific use cases

### Long Term

- **Online learning**: update model with new attacks
- **Adversarial training**: robust to evasion attempts
- **Explainability**: highlight which parts triggered detection
## References

### Datasets

- Lakera/mosscap_prompt_injection
- deepset/prompt-injections
- xTRam1/safe-guard-prompt-injection
- jackhhao/jailbreak-classification

### Research

- Evaluating Prompt Injection Datasets
- DeBERTa: Decoding-enhanced BERT
- Fine-tuned LLMs for Prompt Injection Detection