Classification Heads
Any 🤗 SetFit model consists of two parts: a SentenceTransformer embedding body and a classification head.
This guide will show you:
- The built-in logistic regression classification head
- The built-in differentiable classification head
- The requirements for a custom classification head
Logistic Regression classification head
When a new SetFit model is initialized, a scikit-learn logistic regression head is chosen by default. This has been shown to be highly effective when applied on top of a finetuned sentence transformer body, and it remains the recommended classification head. Initializing a new SetFit model with a Logistic Regression head is simple:
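A minimal sketch of that initialization, wrapped in a hypothetical helper. The BAAI/bge-small-en-v1.5 checkpoint is an assumption, since the text does not name one; any Sentence Transformer checkpoint should work the same way. Printing the model_head attribute of the returned model shows which head was chosen.

```python
def load_default_model(model_id: str = "BAAI/bge-small-en-v1.5"):
    """Load a SetFit model with the default scikit-learn head (hypothetical helper)."""
    # Imported lazily so this snippet can be loaded without setfit installed.
    from setfit import SetFitModel

    # No head-related arguments are passed, so the logistic regression head is used.
    return SetFitModel.from_pretrained(model_id)


# Usage (downloads the checkpoint on first call):
# model = load_default_model()
# print(model.model_head)
```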
LogisticRegression()
To initialize the Logistic Regression head (or any other head) with additional parameters, use the head_params argument on SetFitModel.from_pretrained():
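As a sketch, the call might look as follows. The helper name and the checkpoint are assumptions made for illustration. The two parameters are ordinary scikit-learn LogisticRegression options.

```python
# Keyword arguments intended for the head's constructor.
HEAD_PARAMS = {"solver": "liblinear", "max_iter": 300}


def load_model_with_head_params(model_id: str = "BAAI/bge-small-en-v1.5", head_params=HEAD_PARAMS):
    """Load a SetFit model whose head receives extra parameters (hypothetical helper)."""
    # Imported lazily so this snippet can be loaded without setfit installed.
    from setfit import SetFitModel

    return SetFitModel.from_pretrained(model_id, head_params=head_params)


# Usage (downloads the checkpoint on first call):
# model = load_model_with_head_params()
# print(model.model_head)
```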
LogisticRegression(max_iter=300, solver='liblinear')
Differentiable classification head
SetFit also provides SetFitHead, a classification head implemented purely in torch. It uses a linear layer to map the embeddings to the classes. It can be used by setting the use_differentiable_head argument on SetFitModel.from_pretrained() to True:
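A minimal sketch, again assuming the BAAI/bge-small-en-v1.5 checkpoint and a hypothetical helper name:

```python
def load_differentiable_model(model_id: str = "BAAI/bge-small-en-v1.5"):
    """Load a SetFit model with the torch-based SetFitHead (hypothetical helper)."""
    # Imported lazily so this snippet can be loaded without setfit installed.
    from setfit import SetFitModel

    # use_differentiable_head=True swaps the scikit-learn head for SetFitHead.
    return SetFitModel.from_pretrained(model_id, use_differentiable_head=True)


# Usage (downloads the checkpoint on first call):
# model = load_differentiable_model()
# print(model.model_head)
```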
SetFitHead({'in_features': 384, 'out_features': 2, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})
By default, this will assume binary classification. To change that, also set the out_features via head_params to the number of classes that you are using.
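For a five-class problem, the call could be sketched like this. The helper, its num_classes parameter, and the checkpoint are assumptions made for illustration.

```python
def load_multiclass_differentiable_model(num_classes: int = 5, model_id: str = "BAAI/bge-small-en-v1.5"):
    """Load a SetFit model with a SetFitHead sized for num_classes (hypothetical helper)."""
    # Imported lazily so this snippet can be loaded without setfit installed.
    from setfit import SetFitModel

    return SetFitModel.from_pretrained(
        model_id,
        use_differentiable_head=True,
        # out_features sets the width of the head's output layer.
        head_params={"out_features": num_classes},
    )


# Usage (downloads the checkpoint on first call):
# model = load_multiclass_differentiable_model(num_classes=5)
# print(model.model_head)
```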
SetFitHead({'in_features': 384, 'out_features': 5, 'temperature': 1.0, 'bias': True, 'device': 'cuda'})
Unlike the default Logistic Regression head, the differentiable classification head only supports integer labels in the following range: [0, num_classes).
Training with a differentiable classification head
Using the SetFitHead unlocks some new TrainingArguments that are not used with a sklearn-based head. Note that training with SetFit consists of two phases behind the scenes: finetuning embeddings and training a classification head. As a result, some of the training arguments can be tuples, where the two values are used for each of the two phases, respectively. In many of these cases, the second value is only used if the classification head is differentiable. For example:
- batch_size (Union[int, Tuple[int, int]], defaults to (16, 2)) - The second value in the tuple determines the batch size when training the differentiable SetFitHead.
- num_epochs (Union[int, Tuple[int, int]], defaults to (1, 16)) - The second value in the tuple determines the number of epochs when training the differentiable SetFitHead. In practice, num_epochs is usually larger for training the classification head. There are two reasons for this:
  - This training phase does not train with contrastive pairs, so unlike when finetuning the embedding model, you only get one training sample per labeled training text.
  - This training phase involves training a classifier from scratch, not finetuning an already capable model. We need more training steps for this.
- end_to_end (bool, defaults to False) - If True, train the entire model end-to-end during the classifier training phase. Otherwise, freeze the Sentence Transformer body and only train the head.
- body_learning_rate (Union[float, Tuple[float, float]], defaults to (2e-5, 1e-5)) - The second value in the tuple determines the learning rate of the Sentence Transformer body during the classifier training phase. This is only relevant if end_to_end is True, as otherwise the Sentence Transformer body is frozen when training the classifier.
- head_learning_rate (float, defaults to 1e-2) - This value determines the learning rate of the differentiable head during the classifier training phase. It is only used if the differentiable head is used.
- l2_weight (float, optional) - Optional l2 weight for both the model body and head, passed to the AdamW optimizer in the classifier training phase only if a differentiable head is used.
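To make the tuple convention concrete, here is a small illustrative helper. It is not part of SetFit. It assumes that a single value applies to both phases, which the Union[int, Tuple[int, int]] types suggest but the text does not state outright.

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T", int, float)


def per_phase(value: Union[T, Tuple[T, T]]) -> Tuple[T, T]:
    """Return (embedding_phase_value, classifier_phase_value).

    Hypothetical helper: a bare number is assumed to apply to both phases.
    """
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError("expected exactly two values, one per training phase")
        return value
    return (value, value)


# A tuple is split across the two phases; a scalar is repeated.
embedding_batch_size, classifier_batch_size = per_phase((16, 2))
embedding_epochs, classifier_epochs = per_phase(3)
```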
For example, a full training script using a differentiable classification head may look something like this:
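The sketch below is one possible version of such a script. The checkpoint (BAAI/bge-small-en-v1.5), the dataset (SetFit/sst2), the sample size, and all hyperparameter values are assumptions chosen for illustration, not recommendations. The Trainer and sample_dataset names come from the setfit package.

```python
# Hypothetical hyperparameters; tuples are (embedding phase, classifier phase).
TRAINING_KWARGS = {
    "batch_size": (32, 16),
    "num_epochs": (3, 8),
    "end_to_end": True,
    "body_learning_rate": (2e-5, 5e-6),
    "head_learning_rate": 2e-3,
    "l2_weight": 0.01,
}


def main() -> None:
    # Imported lazily: requires the setfit and datasets packages plus network access.
    from datasets import load_dataset
    from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

    # Initialize a SetFit model with a differentiable head for two classes.
    model = SetFitModel.from_pretrained(
        "BAAI/bge-small-en-v1.5",
        use_differentiable_head=True,
        head_params={"out_features": 2},
    )

    # Prepare a small few-shot training set and a test set.
    dataset = load_dataset("SetFit/sst2")
    train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=32)
    test_dataset = dataset["test"]

    # Train both phases, then evaluate.
    args = TrainingArguments(**TRAINING_KWARGS)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
    print(trainer.evaluate(test_dataset))
```

Calling main() runs the full script. With end_to_end set to True, the second body_learning_rate value is used for the body during the classifier phase.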
Custom classification head
Alongside the two built-in options, SetFit allows you to specify a custom classification head. There are two forms of supported heads: a custom differentiable head or a custom non-differentiable head. Both must implement a predict method and a predict_proba method; the full requirements for each form are listed below.
Custom differentiable head
A custom differentiable head must follow these requirements:
- Must subclass nn.Module.
- A predict method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
- A predict_proba method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the tensor of size num_classes must sum to 1. Applying torch.argmax(output, dim=-1) should result in the output for predict.
- A get_loss_fn method: (self) -> nn.Module - Returns an initialized loss function, e.g. torch.nn.CrossEntropyLoss().
- A forward method: (self, Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor] - Given the output from the Sentence Transformer body, i.e. a dictionary of 'input_ids', 'token_type_ids', 'attention_mask', 'token_embeddings' and 'sentence_embedding' keys, return a dictionary with a 'logits' key and a torch.Tensor value with shape [batch_size, num_classes].
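The requirements above can be sketched as a small two-layer head. The class name MLPHead and its constructor arguments are hypothetical. The sketch only shows the four listed methods and has not been checked against a specific SetFit release.

```python
from typing import Dict

import torch
from torch import nn


class MLPHead(nn.Module):
    """Hypothetical two-layer differentiable classification head."""

    def __init__(self, in_features: int = 384, hidden: int = 64, num_classes: int = 2) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Read the pooled sentence embedding from the body's output and return logits.
        logits = self.net(features["sentence_embedding"])
        return {"logits": logits}

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Softmax over the class dimension makes every row sum to 1.
        with torch.no_grad():
            return torch.softmax(self.net(embeddings), dim=-1)

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        # The most probable class index, an integer in [0, num_classes).
        return torch.argmax(self.predict_proba(embeddings), dim=-1)

    def get_loss_fn(self) -> nn.Module:
        # CrossEntropyLoss expects raw logits and integer class labels.
        return nn.CrossEntropyLoss()
```

Because predict is defined as the argmax of predict_proba, the two methods agree by construction, as the requirements ask.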
Custom non-differentiable head
A custom non-differentiable head must follow these requirements:
- A predict method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
- A predict_proba method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the array of size num_classes must sum to 1. Applying np.argmax(output, axis=-1) should result in the output for predict.
- A fit method: (self, np.array with shape [num_inputs, embedding_size], List[Any]) -> None - This method must take a numpy array of embeddings and a list of corresponding labels. The labels need not be integers per se.
Many classifiers from sklearn already fit these requirements, such as RandomForestClassifier, MLPClassifier, KNeighborsClassifier, etc.
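When no off-the-shelf classifier fits, a head can also be written by hand. The following is a minimal sketch of a nearest-centroid head that meets the three requirements using only numpy. The class name and the softmax-over-negative-distance scoring are illustrative choices, not something SetFit prescribes.

```python
import numpy as np


class NearestCentroidHead:
    """Hypothetical non-differentiable head: classify by the closest class centroid."""

    def __init__(self) -> None:
        self.classes_ = None
        self.centroids_ = None

    def fit(self, embeddings, labels) -> None:
        embeddings = np.asarray(embeddings, dtype=float)
        labels = np.asarray(labels)
        # Sorted unique labels; class index i corresponds to self.classes_[i].
        self.classes_ = np.unique(labels)
        # One mean embedding per class, shape [num_classes, embedding_size].
        self.centroids_ = np.stack([embeddings[labels == c].mean(axis=0) for c in self.classes_])

    def predict_proba(self, embeddings) -> np.ndarray:
        embeddings = np.asarray(embeddings, dtype=float)
        # Squared Euclidean distance to every centroid, shape [num_inputs, num_classes].
        dists = ((embeddings[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=-1)
        # Softmax over negative distances; subtracting the row max keeps exp() stable.
        scores = -dists
        scores = scores - scores.max(axis=1, keepdims=True)
        exp = np.exp(scores)
        return exp / exp.sum(axis=1, keepdims=True)

    def predict(self, embeddings) -> np.ndarray:
        # Index of the most probable class, an integer in [0, num_classes).
        return np.argmax(self.predict_proba(embeddings), axis=-1)
```

Note that predict returns class indices. With integer labels 0 to num_classes - 1 these coincide with the labels themselves.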
When initializing a SetFit model using your custom (non-)differentiable classification head, it is recommended to use the regular __init__ method:
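A sketch of that initialization, using scikit-learn's RandomForestClassifier as the custom head. The helper name, the checkpoint, and the n_estimators value are assumptions made for illustration.

```python
def build_model_with_custom_head(model_id: str = "BAAI/bge-small-en-v1.5"):
    """Combine a Sentence Transformer body with a custom head (hypothetical helper)."""
    # Imported lazily: requires setfit, sentence-transformers and scikit-learn.
    from sentence_transformers import SentenceTransformer
    from setfit import SetFitModel
    from sklearn.ensemble import RandomForestClassifier

    model_body = SentenceTransformer(model_id)
    model_head = RandomForestClassifier(n_estimators=100)

    # Pass body and head directly to the constructor instead of from_pretrained().
    return SetFitModel(model_body=model_body, model_head=model_head)


# Usage (downloads the checkpoint on first call):
# model = build_model_with_custom_head()
```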
Then, training and inference can commence like normal, e.g.:
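A minimal sketch, assuming a model built as described above and a training dataset with text and label columns. The helper name and the hyperparameter values are illustrative only.

```python
def train_and_predict(model, train_dataset, texts):
    """Train a SetFit model and classify new texts (hypothetical helper)."""
    # Imported lazily so this snippet can be loaded without setfit installed.
    from setfit import Trainer, TrainingArguments

    # Single values are used here instead of per-phase tuples.
    args = TrainingArguments(batch_size=32, num_epochs=3)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()

    # Inference works the same way as with the built-in heads.
    return model.predict(texts)


# Usage:
# preds = train_and_predict(model, train_dataset, ["A wonderful film.", "A dull plot."])
```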