Zero-Shot Classification with Embeddings
In this notebook we classify the sentiment of reviews using embeddings and no labeled data. The dataset is created in the Get_embeddings_from_dataset notebook.
We'll define positive sentiment to be 4- and 5-star reviews, and negative sentiment to be 1- and 2-star reviews. 3-star reviews are considered neutral and we won't use them for this example.
We will perform zero-shot classification by embedding descriptions of each class and then comparing new samples to those class embeddings.
Zero-Shot Classification
To perform zero-shot classification, we want to predict labels for our samples without any training. To do this, we embed a short description of each label, such as positive and negative, and then compare the embedding of each sample with the embeddings of the label descriptions using cosine similarity (equivalently, cosine distance, which is 1 minus the similarity).
The label whose embedding is most similar to the sample is the predicted label. We can also define a prediction score as the cosine similarity to the positive label minus the cosine similarity to the negative label, so a score above zero favors the positive label. Sweeping a threshold over this score traces a precision-recall curve, from which we can choose a different tradeoff between precision and recall.
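The scoring rule can be illustrated with a minimal, language-neutral sketch in Python. It is not the notebook's own code: the function names and the toy three-dimensional vectors are assumptions made up for illustration, and real embeddings have far more dimensions.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def prediction_score(sample, positive, negative):
    """Similarity to the positive label minus similarity to the negative label."""
    return cosine_similarity(sample, positive) - cosine_similarity(sample, negative)

def predict(sample, positive, negative, threshold=0.0):
    """Predict 'positive' when the score exceeds the threshold (hypothetical helper)."""
    score = prediction_score(sample, positive, negative)
    return "positive" if score > threshold else "negative"

# Toy vectors standing in for real embeddings (illustrative values only).
positive_label = [1.0, 0.0, 0.0]
negative_label = [0.0, 1.0, 0.0]
review = [0.9, 0.1, 0.0]

score = prediction_score(review, positive_label, negative_label)
predicted = predict(review, positive_label, negative_label)
```

Raising `threshold` above zero makes the classifier more reluctant to answer "positive", which generally trades recall for precision.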
The code defines two public records, Label and LabelledItem. The Label record represents a label with its associated text and embedding. The LabelledItem record represents an item with its associated product ID, summary, text, score, label, predicted label, and probability.
The PredictLabels method predicts labels for a given set of data. It takes three parameters: positiveLabel and negativeLabel, which are strings naming the positive and negative sentiment labels, and data, which is an enumerable collection of DataRow objects representing the data to be classified.
Inside the method, a list of Label objects is created. Then, the method calculates the average embedding for each label. It does this by filtering the data based on the Score property, then aggregating the Embedding property of each item. This is done separately for positive and negative scores.
After calculating the average embeddings, the method creates new Label objects with the calculated embeddings and adds them to the labels list.
Finally, the method creates a list of LabelledItem objects by iterating over the data. For each item in data, it calculates the similarity score with each label in the labels list, selects the label with the highest score, and creates a new LabelledItem with this information. The list of LabelledItem objects is then returned.
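The flow described above can be sketched in Python as follows. This is a rough analogue under stated assumptions, not the notebook's implementation: the fields of `DataRow`, the rule that 4 and 5 stars count as positive and 1 and 2 stars as negative, the exclusion of 3-star rows, and the use of the best similarity as the `probability` value are all assumptions for illustration.

```python
import math
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class Label:
    text: str
    embedding: List[float]

@dataclass(frozen=True)
class DataRow:  # hypothetical shape of an input row
    product_id: str
    summary: str
    text: str
    score: int
    embedding: List[float]

@dataclass(frozen=True)
class LabelledItem:
    product_id: str
    summary: str
    text: str
    score: int
    label: str
    predicted_label: str
    probability: float

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def mean_embedding(vectors):
    """Element-wise average of a non-empty list of equal-length vectors."""
    if not vectors:
        raise ValueError("need at least one embedding per class")
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def predict_labels(positive_label, negative_label, data):
    rows = [r for r in data if r.score != 3]  # assumption: neutral reviews are skipped
    labels = [
        Label(positive_label, mean_embedding([r.embedding for r in rows if r.score > 3])),
        Label(negative_label, mean_embedding([r.embedding for r in rows if r.score < 3])),
    ]
    items = []
    for r in rows:
        # Pick the label whose embedding is most similar to this row's embedding.
        best = max(labels, key=lambda lab: cosine_similarity(r.embedding, lab.embedding))
        similarity = cosine_similarity(r.embedding, best.embedding)
        actual = positive_label if r.score > 3 else negative_label
        # Assumption: the best similarity is stored in the probability field.
        items.append(LabelledItem(r.product_id, r.summary, r.text, r.score,
                                  actual, best.text, similarity))
    return items

# Toy rows with two-dimensional stand-in embeddings (illustrative values only).
rows = [
    DataRow("A", "great", "loved it", 5, [1.0, 0.1]),
    DataRow("B", "good", "works well", 4, [0.9, 0.2]),
    DataRow("C", "awful", "broke at once", 1, [0.1, 1.0]),
    DataRow("D", "poor", "not worth it", 2, [0.2, 0.8]),
    DataRow("E", "okay", "average", 3, [0.5, 0.5]),
]
items = predict_labels("positive", "negative", rows)
```

Each returned item carries both the actual and the predicted label, which is what a later evaluation step needs.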
First, an instance of MLContext is created. MLContext is the main entry point for working with ML.NET, providing methods and properties for loading data, creating machine learning models, and more.
Next, a dataView is created by loading data from an enumerable collection of predictions. The LoadFromEnumerable method loads the data after the predictions collection has been projected into a new anonymous type with three properties: Label, PredictedLabel, and Probability. The Label and PredictedLabel properties are set to 1f if the corresponding label is "positive", and 0f otherwise. The Probability property is simply the Probability property of the prediction.
After the data is loaded, the Evaluate method of the BinaryClassification catalog is called on the context object. This method computes various metrics that can be used to evaluate the performance of a binary classification model. The dataView is passed as the first argument, and the names of the label and score columns are specified as "Label" and "PredictedLabel", respectively.
Finally, the Display method is called on the metric object to print the evaluation metrics to the console.
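To make the evaluation step concrete, here is a small Python sketch of the standard formulas behind typical binary classification metrics such as accuracy, precision, recall, and F1. It does not reproduce ML.NET's internals; the helper names and the sample values are assumptions for illustration.

```python
def encode(label):
    """Map a label string to 1.0 for "positive" and 0.0 otherwise."""
    return 1.0 if label == "positive" else 0.0

def binary_metrics(labels, predictions):
    """Compute accuracy, precision, recall, and F1 from 0/1 labels and predictions."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def metrics_at_threshold(labels, scores, threshold):
    """Turn continuous scores into 0/1 predictions at a threshold, then evaluate."""
    predictions = [1 if s > threshold else 0 for s in scores]
    return binary_metrics(labels, predictions)

# Illustrative values only: actual labels and prediction scores for five reviews.
labels = [1, 1, 0, 0, 1]
scores = [0.30, -0.05, -0.20, 0.10, 0.25]

at_zero = metrics_at_threshold(labels, scores, 0.0)
at_higher = metrics_at_threshold(labels, scores, 0.2)
```

In this toy data, raising the threshold from 0.0 to 0.2 removes the one false positive, so precision rises while recall stays the same. Evaluating at many thresholds yields the points of a precision-recall curve.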