Train Classifier

train_binary_classifier(project_name, entity_name, run_name, dataset_repo='geekyrakshit/prompt-injection-dataset', model_name='distilbert/distilbert-base-uncased', prompt_column_name='prompt', id2label={0: 'SAFE', 1: 'INJECTION'}, label2id={'SAFE': 0, 'INJECTION': 1}, learning_rate=1e-05, batch_size=16, num_epochs=2, weight_decay=0.01, save_steps=1000, streamlit_mode=False)

Trains a binary classifier using a specified dataset and model architecture.

This function sets up and trains a binary sequence classification model using the Hugging Face Transformers library. It integrates with Weights & Biases for experiment tracking and optionally displays a progress bar in a Streamlit app.
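
A minimal invocation, assuming the package is installed and the function is importable from its source module (the project, entity, and run names below are placeholders):

```python
from safeguards.train.train_classifier import train_binary_classifier

# Placeholder Weights & Biases project, entity, and run names;
# every other argument keeps its default value.
training_output = train_binary_classifier(
    project_name="prompt-injection",
    entity_name="my-team",
    run_name="distilbert-baseline",
)
```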

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `project_name` | `str` | The name of the Weights & Biases project. | *required* |
| `entity_name` | `str` | The Weights & Biases entity (user or team). | *required* |
| `run_name` | `str` | The name of the Weights & Biases run. | *required* |
| `dataset_repo` | `str` | The Hugging Face dataset repository to load. | `'geekyrakshit/prompt-injection-dataset'` |
| `model_name` | `str` | The pre-trained model to fine-tune. | `'distilbert/distilbert-base-uncased'` |
| `prompt_column_name` | `str` | The dataset column containing the text prompts. | `'prompt'` |
| `id2label` | `dict[int, str]` | Mapping from label IDs to label names (see the example after this table). | `{0: 'SAFE', 1: 'INJECTION'}` |
| `label2id` | `dict[str, int]` | Mapping from label names to label IDs. | `{'SAFE': 0, 'INJECTION': 1}` |
| `learning_rate` | `float` | The learning rate for training. | `1e-05` |
| `batch_size` | `int` | The per-device batch size for training and evaluation. | `16` |
| `num_epochs` | `int` | The number of training epochs. | `2` |
| `weight_decay` | `float` | The weight decay for the optimizer. | `0.01` |
| `save_steps` | `int` | The number of steps between model checkpoints. | `1000` |
| `streamlit_mode` | `bool` | If True, displays a training progress bar in the Streamlit app. | `False` |
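
The two label mappings must be exact inverses of each other and must agree with the integer labels in the dataset. As a sketch, adapting the trainer to a hypothetical two-class dataset with different label names might look like this (the repository, column, and label names are illustrative, not real defaults):

```python
train_binary_classifier(
    project_name="toxicity-classifier",      # placeholder W&B project
    entity_name="my-team",                   # placeholder W&B entity
    run_name="distilbert-toxicity",
    dataset_repo="my-org/toxicity-dataset",  # hypothetical dataset repo
    prompt_column_name="text",               # hypothetical text column
    id2label={0: "BENIGN", 1: "TOXIC"},      # inverse of label2id
    label2id={"BENIGN": 0, "TOXIC": 1},
)
```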

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `training_output` | `TrainOutput` | The output of `trainer.train()`: a named tuple with the global step, the average training loss, and a dict of run metrics. |
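
Since the function returns the result of `trainer.train()`, the value follows transformers' `TrainOutput` named tuple and its fields can be read directly:

```python
result = train_binary_classifier(
    project_name="prompt-injection",  # placeholders, as above
    entity_name="my-team",
    run_name="distilbert-baseline",
)
print(result.global_step)    # total optimizer steps taken
print(result.training_loss)  # average training loss over the run
print(result.metrics)        # dict with loss, timing, and throughput metrics
```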

Raises:

| Type | Description |
| --- | --- |
| `Exception` | If an error occurs during training, the exception is re-raised after the Weights & Biases run is finished. |
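
Because the function closes the Weights & Biases run before re-raising, callers can handle failures normally (placeholder names again):

```python
try:
    train_binary_classifier(
        project_name="prompt-injection",
        entity_name="my-team",
        run_name="distilbert-baseline",
    )
except Exception as err:
    print(f"Training failed: {err}")  # the W&B run has already been finished
```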

Source code in safeguards/train/train_classifier.py

```python
def train_binary_classifier(
    project_name: str,
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    prompt_column_name: str = "prompt",
    id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
    label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
    learning_rate: float = 1e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    save_steps: int = 1000,
    streamlit_mode: bool = False,
):
    """
    Trains a binary classifier using a specified dataset and model architecture.

    This function sets up and trains a binary sequence classification model using
    the Hugging Face Transformers library. It integrates with Weights & Biases for
    experiment tracking and optionally displays a progress bar in a Streamlit app.

    Args:
        project_name (str): The name of the Weights & Biases project.
        entity_name (str): The Weights & Biases entity (user or team).
        run_name (str): The name of the Weights & Biases run.
        dataset_repo (str, optional): The Hugging Face dataset repository to load.
        model_name (str, optional): The pre-trained model to use.
        prompt_column_name (str, optional): The column name in the dataset containing
            the text prompts.
        id2label (dict[int, str], optional): Mapping from label IDs to label names.
        label2id (dict[str, int], optional): Mapping from label names to label IDs.
        learning_rate (float, optional): The learning rate for training.
        batch_size (int, optional): The batch size for training and evaluation.
        num_epochs (int, optional): The number of training epochs.
        weight_decay (float, optional): The weight decay for the optimizer.
        save_steps (int, optional): The number of steps between model checkpoints.
        streamlit_mode (bool, optional): If True, integrates with Streamlit to display
            a progress bar.

    Returns:
        TrainOutput: The output of `trainer.train()`, including the global step,
            the average training loss, and a dict of run metrics.

    Raises:
        Exception: If an error occurs during training, the exception is re-raised
            after the Weights & Biases run is finished.
    """
    wandb.init(
        project=project_name,
        entity=entity_name,
        name=run_name,
        job_type="train-binary-classifier",
    )
    if streamlit_mode:
        st.markdown(
            f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
        )
    dataset = load_dataset(dataset_repo)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokenized_datasets = dataset.map(
        lambda examples: tokenizer(examples[prompt_column_name], truncation=True),
        batched=True,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="binary-classifier",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            weight_decay=weight_decay,
            # load_best_model_at_end requires the eval and save schedules to
            # match, so evaluation runs on the same step interval as saving.
            eval_strategy="steps",
            eval_steps=save_steps,
            save_strategy="steps",
            save_steps=save_steps,
            load_best_model_at_end=True,
            push_to_hub=False,
            report_to="wandb",
            logging_strategy="steps",
            logging_steps=1,
        ),
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
    )
    try:
        training_output = trainer.train()
    except Exception as e:
        wandb.finish()
        raise e
    wandb.finish()
    return training_output
```
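
The listing references `StreamlitProgressbarCallback` but does not define it; it lives elsewhere in the safeguards package. A minimal sketch of such a callback, assuming only the standard `transformers.TrainerCallback` hooks and Streamlit's `st.progress` API (not the package's actual implementation):

```python
import streamlit as st
from transformers import TrainerCallback


class StreamlitProgressbarCallback(TrainerCallback):
    """Hypothetical sketch: mirror training progress in a Streamlit progress bar."""

    def __init__(self):
        self.progress_bar = st.progress(0.0, text="Training")

    def on_step_end(self, args, state, control, **kwargs):
        # state.max_steps is the total number of optimizer steps for the run.
        fraction = min(state.global_step / max(state.max_steps, 1), 1.0)
        self.progress_bar.progress(
            fraction,
            text=f"Step {state.global_step} of {state.max_steps}",
        )
```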