Skip to content

ML Detection: Development Process

This document describes the development process, challenges encountered, and solutions implemented for the ML-based prompt injection detection system.

Problem Statement

mcp-scan needed a machine learning-based detector to identify prompt injection attacks that pattern-based rules might miss. The requirements were:

  1. High accuracy - minimize false positives and negatives
  2. Fast inference - suitable for CI/CD pipelines
  3. Portable - model trained in Python, inference in Go
  4. Maintainable - easy to retrain with new data

Initial Implementation (v1.0)

Approach

The initial implementation used: - Dataset: Lakera/mosscap_prompt_injection (223K samples) - Model: Logistic Regression - Features: 29 hand-crafted features

Problems Discovered

When running the training script, we observed:

Dataset size: 219551 samples
Positive rate: 0.09%      ← EXTREME IMBALANCE
Negative rate: 99.91%

Classification Report (Test Set):
              precision    recall  f1-score
   Injection       0.00      0.85      0.01  ← MODEL FAILED

Root Cause Analysis:

  1. The Lakera dataset has NO label field - it's raw CTF data without labels
  2. The script was reading a non-existent field, resulting in all 0s
  3. This created extreme class imbalance (0.09% positive)
  4. The model learned nothing useful

Feature Weights Were Inverted:

exfiltration_keyword_count: -18.6  ← SHOULD BE POSITIVE
has_system_prompt: -0.57          ← SHOULD BE POSITIVE


Investigation Phase

Dataset Analysis

We researched available prompt injection datasets:

Dataset Has Labels? Balance Quality
Lakera/mosscap NO N/A Raw CTF data
hackaprompt No N/A "PWNED" focused
deepset/prompt-injections YES 85/15 High quality
xTRam1/safe-guard YES 70/30 Good balance
jackhhao/jailbreak YES 75/25 Jailbreak specific

Key Finding: Many popular datasets either: - Have no labels (Lakera, hackaprompt) - Are focused on specific attack patterns ("I have been PWNED") - Have severe class imbalance

Dataset Schema Verification

# Lakera dataset - NO label field
{
    "level": "Level 1",
    "prompt": "user input",
    "answer": "filtered response",
    "raw_answer": "unfiltered response"
}

# deepset dataset - HAS label field
{
    "text": "prompt text",
    "label": 0 or 1,  # 0=benign, 1=injection
    "split": "train"
}

Solution Implementation (v2.0)

ETL Pipeline Redesign

1. Dataset Selection

Selected datasets with verified labels: - xTRam1/safe-guard-prompt-injection (10K, 70/30 balance) - deepset/prompt-injections (662, high quality) - jackhhao/jailbreak-classification (772, jailbreak-specific)

2. Schema Normalization

def load_datasets() -> List[Sample]:
    # Each dataset has different schema

    # xTRam1: {text, label}
    for item in ds['train']:
        add_sample(item['text'], item['label'], 'xTRam1')

    # jackhhao: {prompt, type} where type is "benign"/"jailbreak"
    for item in ds['train']:
        label = 1 if item['type'] == 'jailbreak' else 0
        add_sample(item['prompt'], label, 'jackhhao')

3. Deduplication

seen_hashes = set()

def add_sample(text, label, source):
    text_hash = hashlib.md5(text.encode()).hexdigest()
    if text_hash in seen_hashes:
        return False  # Duplicate
    seen_hashes.add(text_hash)
    samples.append(Sample(text, label, source))
    return True

Feature Alignment

Critical Issue: Python and Go must extract identical features.

Solution: Synchronized all keyword lists and regex patterns:

# Python (train_ml.py)
INJECTION_KEYWORDS = [
    'ignore', 'disregard', 'forget', 'override', 'bypass',
    'previous', 'prior', 'above', 'system', 'instructions',
    'prompt', 'rules', 'guidelines', 'restrictions',
]
// Go (features.go)
var injectionKeywords = []string{
    "ignore", "disregard", "forget", "override", "bypass",
    "previous", "prior", "above", "system", "instructions",
    "prompt", "rules", "guidelines", "restrictions",
}

Threshold Optimization

Instead of using fixed threshold 0.5, we optimize for F1:

precisions, recalls, thresholds = precision_recall_curve(y_test, y_prob)
f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx]  # = 0.578

Results Comparison

Before (v1.0)

Metric Value
Dataset size 219,551
Positive rate 0.09%
Precision (Injection) 0.00
Recall (Injection) 0.85
F1 (Injection) 0.01
ROC-AUC 0.86

After (v2.0)

Metric Value
Dataset size 11,563
Positive rate 29.9%
Precision (Injection) 0.82
Recall (Injection) 0.88
F1 (Injection) 0.85
ROC-AUC 0.9471

Feature Importance (Corrected)

1. injection_keyword_count  : +2.71 (injection indicator) ✓
2. role_keyword_count       : +1.48 (injection indicator) ✓
3. imperative_verb_count    : +0.89 (injection indicator) ✓
4. has_jailbreak            : +0.68 (injection indicator) ✓
5. has_ignore_pattern       : +0.61 (injection indicator) ✓

The positive coefficients now correctly indicate injection patterns.


DeBERTa Implementation

For users needing higher accuracy, we added DeBERTa fine-tuning:

Architecture

Input Text
DeBERTa Tokenizer (max_length=512)
DeBERTa-v3-base (pretrained)
Classification Head (2 classes)
Softmax → Probability

Training Configuration

TrainingArguments(
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

ONNX Export

For production inference without Python:

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "model.onnx",
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={'input_ids': {0: 'batch', 1: 'seq'}},
    opset_version=14,
)

Lessons Learned

1. Always Verify Dataset Labels

Lesson: Never assume a dataset has the fields you expect.

Practice: Always inspect sample records before training:

ds = load_dataset("some/dataset")
print(ds['train'][0])  # Check actual schema

2. Class Imbalance Detection

Lesson: Extreme imbalance (>10:1) indicates data problems.

Practice: Add early warning in training scripts:

if y.mean() < 0.1 or y.mean() > 0.9:
    print("WARNING: Significant class imbalance detected!")

3. Feature Sign Validation

Lesson: Feature coefficients should match intuition.

Practice: Always print feature importance and verify signs:

for name, coef in sorted(zip(names, model.coef_[0]), key=lambda x: -abs(x[1])):
    direction = "injection" if coef > 0 else "benign"
    print(f"{name}: {coef:+.4f} ({direction} indicator)")

4. Cross-Language Feature Alignment

Lesson: Python training and Go inference must be identical.

Practice: - Use constants for keyword lists - Share regex patterns via documentation - Write cross-language tests


Future Improvements

Short Term

  1. Add more datasets as they become available
  2. Implement SMOTE for additional augmentation
  3. Add multilingual support (non-English prompts)

Medium Term

  1. Ensemble methods - combine Logistic Regression + DeBERTa
  2. Active learning - improve on misclassified examples
  3. Domain adaptation - fine-tune for specific use cases

Long Term

  1. Online learning - update model with new attacks
  2. Adversarial training - robust to evasion attempts
  3. Explainability - highlight which parts triggered detection

References

Datasets

Research

Tools