Skip to content

Train Classifier

StreamlitProgressbarCallback

Bases: TrainerCallback

StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer that integrates a progress bar into a Streamlit application. This class updates the progress bar at each training step, providing real-time feedback on the training process within the Streamlit interface.

Attributes:

Name Type Description
progress_bar DeltaGenerator

A Streamlit progress bar object initialized to 0 with the text "Training".

Methods:

Name Description
on_step_begin

Updates the progress bar at the beginning of each training step. The progress is calculated as the percentage of completed steps out of the total steps. The progress bar text is updated to show the current step and the total steps.

Source code in guardrails_genie/train_classifier.py
class StreamlitProgressbarCallback(TrainerCallback):
    """
    StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
    that integrates a progress bar into a Streamlit application. This class updates
    the progress bar at each training step, providing real-time feedback on the
    training process within the Streamlit interface.

    Attributes:
        progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
            bar object initialized to 0 with the text "Training".

    Methods:
        on_step_begin(args, state, control, **kwargs):
            Updates the progress bar at the beginning of each training step. The progress
            is calculated as the percentage of completed steps out of the total steps.
            The progress bar text is updated to show the current step and the total steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = st.progress(0, text="Training")

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        super().on_step_begin(args, state, control, **kwargs)
        self.progress_bar.progress(
            (state.global_step * 100 // state.max_steps) + 1,
            text=f"Training {state.global_step} / {state.max_steps}",
        )

train_binary_classifier(project_name, entity_name, run_name, dataset_repo='geekyrakshit/prompt-injection-dataset', model_name='distilbert/distilbert-base-uncased', prompt_column_name='prompt', id2label={0: 'SAFE', 1: 'INJECTION'}, label2id={'SAFE': 0, 'INJECTION': 1}, learning_rate=1e-05, batch_size=16, num_epochs=2, weight_decay=0.01, save_steps=1000, streamlit_mode=False)

Trains a binary classifier using a specified dataset and model architecture.

This function sets up and trains a binary sequence classification model using the Hugging Face Transformers library. It integrates with Weights & Biases for experiment tracking and optionally displays a progress bar in a Streamlit app.

Parameters:

Name Type Description Default
project_name str

The name of the Weights & Biases project.

required
entity_name str

The Weights & Biases entity (user or team).

required
run_name str

The name of the Weights & Biases run.

required
dataset_repo str

The Hugging Face dataset repository to load.

'geekyrakshit/prompt-injection-dataset'
model_name str

The pre-trained model to use.

'distilbert/distilbert-base-uncased'
prompt_column_name str

The column name in the dataset containing the text prompts.

'prompt'
id2label dict[int, str]

Mapping from label IDs to label names.

{0: 'SAFE', 1: 'INJECTION'}
label2id dict[str, int]

Mapping from label names to label IDs.

{'SAFE': 0, 'INJECTION': 1}
learning_rate float

The learning rate for training.

1e-05
batch_size int

The batch size for training and evaluation.

16
num_epochs int

The number of training epochs.

2
weight_decay float

The weight decay for the optimizer.

0.01
save_steps int

The number of steps between model checkpoints.

1000
streamlit_mode bool

If True, integrates with Streamlit to display a progress bar.

False

Returns:

Name Type Description
dict

The output of the training process, including metrics and model state.

Raises:

Type Description
Exception

If an error occurs during training, the exception is raised after ensuring Weights & Biases run is finished.

Source code in guardrails_genie/train_classifier.py
def train_binary_classifier(
    project_name: str,
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    prompt_column_name: str = "prompt",
    id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
    label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
    learning_rate: float = 1e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    save_steps: int = 1000,
    streamlit_mode: bool = False,
):
    """
    Trains a binary classifier using a specified dataset and model architecture.

    This function sets up and trains a binary sequence classification model using
    the Hugging Face Transformers library. It integrates with Weights & Biases for
    experiment tracking and optionally displays a progress bar in a Streamlit app.

    Args:
        project_name (str): The name of the Weights & Biases project.
        entity_name (str): The Weights & Biases entity (user or team).
        run_name (str): The name of the Weights & Biases run.
        dataset_repo (str, optional): The Hugging Face dataset repository to load.
        model_name (str, optional): The pre-trained model to use.
        prompt_column_name (str, optional): The column name in the dataset containing
            the text prompts.
        id2label (dict[int, str], optional): Mapping from label IDs to label names.
        label2id (dict[str, int], optional): Mapping from label names to label IDs.
        learning_rate (float, optional): The learning rate for training.
        batch_size (int, optional): The batch size for training and evaluation.
        num_epochs (int, optional): The number of training epochs.
        weight_decay (float, optional): The weight decay for the optimizer.
        save_steps (int, optional): The number of steps between model checkpoints.
        streamlit_mode (bool, optional): If True, integrates with Streamlit to display
            a progress bar.

    Returns:
        dict: The output of the training process, including metrics and model state.

    Raises:
        Exception: If an error occurs during training, the exception is raised after
            ensuring Weights & Biases run is finished.
    """
    wandb.init(project=project_name, entity=entity_name, name=run_name)
    if streamlit_mode:
        st.markdown(
            f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
        )
    dataset = load_dataset(dataset_repo)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokenized_datasets = dataset.map(
        lambda examples: tokenizer(examples[prompt_column_name], truncation=True),
        batched=True,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="binary-classifier",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            weight_decay=weight_decay,
            eval_strategy="epoch",
            save_strategy="steps",
            save_steps=save_steps,
            load_best_model_at_end=True,
            push_to_hub=False,
            report_to="wandb",
            logging_strategy="steps",
            logging_steps=1,
        ),
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
    )
    try:
        training_output = trainer.train()
    except Exception as e:
        wandb.finish()
        raise e
    wandb.finish()
    return training_output