Train Classifier
train_binary_classifier(project_name, entity_name, run_name, dataset_repo='geekyrakshit/prompt-injection-dataset', model_name='distilbert/distilbert-base-uncased', prompt_column_name='prompt', id2label={0: 'SAFE', 1: 'INJECTION'}, label2id={'SAFE': 0, 'INJECTION': 1}, learning_rate=1e-05, batch_size=16, num_epochs=2, weight_decay=0.01, save_steps=1000, streamlit_mode=False)
Trains a binary classifier using a specified dataset and model architecture.
This function sets up and trains a binary sequence classification model using the Hugging Face Transformers library. It integrates with Weights & Biases for experiment tracking and optionally displays a progress bar in a Streamlit app.
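This page does not show how the Streamlit progress bar is driven. One plausible approach is a Transformers-style callback whose `on_step_end(args, state, control, **kwargs)` hook turns `state.global_step / state.max_steps` into a fraction. The sketch below is an assumption, not the function's actual code. `ProgressBarCallback` and its `update` argument are hypothetical. The class deliberately does not subclass `transformers.TrainerCallback` so the example runs without Transformers installed, and `SimpleNamespace` stands in for the real `TrainerState`.

```python
from types import SimpleNamespace
from typing import Callable


class ProgressBarCallback:
    """Hypothetical callback that reports training progress as a fraction in [0, 1].

    In real use this would subclass transformers.TrainerCallback; it is kept
    dependency-free here so the sketch runs on its own.
    """

    def __init__(self, update: Callable[[float, str], None]) -> None:
        # `update` receives (fraction_done, label) after every optimizer step.
        self.update = update

    def on_step_end(self, args, state, control, **kwargs):
        # Skip reporting until the total number of steps is known.
        if state.max_steps > 0:
            fraction = min(state.global_step / state.max_steps, 1.0)
            self.update(fraction, f"Step {state.global_step}/{state.max_steps}")
        return control


# Simulate three step-end events of a 1,250-step run and record the updates.
updates = []
callback = ProgressBarCallback(lambda frac, text: updates.append((frac, text)))
for step in (1, 625, 1250):
    state = SimpleNamespace(global_step=step, max_steps=1250)
    callback.on_step_end(args=None, state=state, control=None)
print(updates[-1])  # (1.0, 'Step 1250/1250')
```

In a Streamlit app, `update` could wrap the element returned by `st.progress`. For example, `bar = st.progress(0.0)` followed by `ProgressBarCallback(lambda frac, text: bar.progress(frac, text=text))`. The callback would then be passed to the `Trainer` through its `callbacks` argument.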
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `project_name` | `str` | The name of the Weights & Biases project. | required |
| `entity_name` | `str` | The Weights & Biases entity (user or team). | required |
| `run_name` | `str` | The name of the Weights & Biases run. | required |
| `dataset_repo` | `str` | The Hugging Face dataset repository to load. | `'geekyrakshit/prompt-injection-dataset'` |
| `model_name` | `str` | The Hugging Face ID of the pre-trained model to fine-tune. | `'distilbert/distilbert-base-uncased'` |
| `prompt_column_name` | `str` | Name of the dataset column that holds the text prompts. | `'prompt'` |
| `id2label` | `dict[int, str]` | Mapping from label IDs to label names. | `{0: 'SAFE', 1: 'INJECTION'}` |
| `label2id` | `dict[str, int]` | Mapping from label names to label IDs; should be the inverse of `id2label`. | `{'SAFE': 0, 'INJECTION': 1}` |
| `learning_rate` | `float` | The learning rate for training. | `1e-05` |
| `batch_size` | `int` | The batch size for training and evaluation. | `16` |
| `num_epochs` | `int` | The number of training epochs. | `2` |
| `weight_decay` | `float` | The weight decay for the optimizer. | `0.01` |
| `save_steps` | `int` | The number of training steps between model checkpoints. | `1000` |
| `streamlit_mode` | `bool` | If `True`, displays a training progress bar in a Streamlit app. | `False` |
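`id2label` and `label2id` must describe the same mapping in opposite directions. `save_steps` is counted in optimizer steps rather than epochs. The sketch below checks the first property and estimates how many steps and periodic checkpoints a run produces. It assumes a single device, no gradient accumulation, and that the last partial batch is kept (the Hugging Face `Trainer` default). The helper names `check_label_maps` and `estimate_schedule` and the 10,000-example training split are hypothetical.

```python
import math


def check_label_maps(id2label: dict[int, str], label2id: dict[str, int]) -> None:
    """Raise ValueError unless label2id is the exact inverse of id2label."""
    inverse = {label: idx for idx, label in id2label.items()}
    if inverse != label2id:
        raise ValueError(
            f"label2id {label2id!r} is not the inverse of id2label {id2label!r}"
        )


def estimate_schedule(
    num_examples: int, batch_size: int, num_epochs: int, save_steps: int
) -> dict[str, int]:
    """Estimate optimizer steps and periodic checkpoints.

    Assumes one device, no gradient accumulation, and that the final partial
    batch is kept, so each epoch has ceil(num_examples / batch_size) steps.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # One periodic checkpoint each time the global step hits a multiple of save_steps.
        "periodic_checkpoints": total_steps // save_steps,
    }


# Defaults from the signature; the 10,000-example split size is made up.
check_label_maps({0: "SAFE", 1: "INJECTION"}, {"SAFE": 0, "INJECTION": 1})
schedule = estimate_schedule(num_examples=10_000, batch_size=16, num_epochs=2, save_steps=1000)
print(schedule)  # {'steps_per_epoch': 625, 'total_steps': 1250, 'periodic_checkpoints': 1}
```

Under these assumptions, the defaults give 625 steps per epoch and 1,250 steps in total on a 10,000-example split. Only one periodic checkpoint (at step 1,000) is written. On small datasets, a lower `save_steps` gives more frequent checkpoints.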
Returns:
| Type | Description |
|---|---|
| `dict` | The output of the training process, including metrics and model state. |
Raises:
| Type | Description |
|---|---|
| `Exception` | Any error raised during training is re-raised after the Weights & Biases run has been finished. |
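The behavior described under Raises follows a cleanup-then-re-raise pattern. The minimal sketch below assumes the run object exposes a `finish()` method, like the run returned by `wandb.init()`. `FakeRun`, `train_with_cleanup`, and `failing_train` are hypothetical stand-ins so the example runs without Weights & Biases or a GPU.

```python
from typing import Any, Callable


class FakeRun:
    """Hypothetical stand-in for a Weights & Biases run; records whether finish() was called."""

    def __init__(self) -> None:
        self.finished = False

    def finish(self) -> None:
        self.finished = True


def train_with_cleanup(run: FakeRun, train_fn: Callable[[], Any]) -> Any:
    """Run train_fn; on failure, finish the run and then re-raise the original error."""
    try:
        return train_fn()
    except Exception:
        # Close the run first so it is not left open in the dashboard,
        # then propagate the original exception (with its traceback) to the caller.
        run.finish()
        raise


def failing_train() -> dict:
    raise RuntimeError("simulated training failure")


run = FakeRun()
try:
    train_with_cleanup(run, failing_train)
except RuntimeError as err:
    print(f"re-raised: {err}; run finished: {run.finished}")
    # re-raised: simulated training failure; run finished: True
```

A bare `raise` inside the `except` block preserves the original traceback. Replacing `except` with `finally` would also close the run after a successful call. This page does not show which variant the source uses.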
Source code in safeguards/train/train_classifier.py