CAKE transformer interpretability

Classifications are keyphrase explainable! Exploring local interpretability of transformer models with keyphrase extraction

This repository contains the code for the experiments in our paper, along with an example of how to use CAKE in your own pipeline. The "CAKE Interpretability" GitHub repository contains the code for a preliminary version of CAKE. The trained models used in our experiments can be found here

Abstract

Keyphrase extraction is a widely discussed topic in Natural Language Processing, as it offers a concise summary of the main topics in a document. Interpretability is also an important aspect in Machine Learning as it helps prevent socio-ethical issues, such as bias and discrimination against minorities, or mistakes that may have serious consequences. Interpretability has recently gained prominence in the field of Natural Language Processing, where transformers are the dominant architectures. The goal of interpretability is to provide interpretations that pinpoint the elements of an instance contributing the most to its decision. In this work, we use keyphrase extraction to facilitate the interpretability process, producing smaller, more concise interpretations that also consider word interactions, as keyphrases usually consist of multiple words. Additionally, our technique is based on semantic similarity, making it faster and zero-shot ready, which is ideal for online learning scenarios. We evaluated the effectiveness of our method through a series of quantitative and qualitative experiments on the well-known BERT model, comparing it against several state-of-the-art competitors.

Logo

Our logo was generated by DALLE-2 given the prompt: "Optimus prime holding a cake, 3D pixel art"

CAKE is heavily influenced by our latest work on transformer interpretability, Optimus Prime Interpretability, with which it shares code, metrics, and datasets.

Applicability

  • CAKE is applicable to any transformer model that can be loaded in FLAIR
  • It works only for text classification tasks

Requirements

The requirements are listed in the req.txt file.

Datasets

| Name | Acronym |
| --- | --- |
| Hallmarks of Cancer | HoC |
| Movies | MV |
| HateXplain | HX |
| Hummingbird | HB |

Competitors

| Name | Acronym |
| --- | --- |
| Baseline Attention | B |
| Optimus Batch | OB |
| Optimus Label | OL |
| Integrated Gradients | IG |
| Local Interpretable Model-agnostic Explanations | LIME |

Example

# Load your transformer model (e.g. BERT) using the MyModel class
model_path = 'Trained Models/'
model = 'bert_hx'  # name of the trained model inside model_path
model_name = 'bert'
task = 'single_label'
labels = 2
label_names = ['no hate speech', 'hate speech']
cased = False
model = MyModel(model_path, model, model_name, task, labels, cased)  # rebinds `model` to the loaded MyModel instance
tokenizer = model.tokenizer  # Extract your tokenizer

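# Load your dataset (data_path must point to your local dataset folder)
from sklearn.model_selection import train_test_split
hx = Dataset(path=data_path)
x, y, label_names, rationales = hx.load_hatexplain(tokenizer)
indices = list(range(len(x)))  # instance indices, split alongside the texts and labels
train_texts, test_texts, train_labels, test_labels, _, test_indexes = train_test_split(x, y, indices, test_size=.2, random_state=42)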

#Load CAKE class
from cake import CAKE

#Define label descriptions for LE2 variation
descriptions = ['no hate speech label: indicates that the text is considered a normal post and does not contain any instances of hate speech.',
            'hate speech label: refers to any text that contains hate speech content, targeting a particular community or individual based on their race, gender, religion, sexual orientation, or other characteristics. These texts may express prejudice, hostility, or aggression towards a particular group or individual, and are intended to cause harm, violence or provoke a negative response.']


cake = CAKE(model_path = 'Trained Models/bert_hx', tokenizer = tokenizer, label_names = label_names, label_descriptions = descriptions, input_docs = train_texts, input_labels = train_labels, input_docs_test = test_texts)

#Then select a random instance
instance = "This sentence contains hate speech content for ****** people!"
prediction, attention, hidden_states = model.my_predict(instance) #use MyModel instance to make a prediction

#Select a label
lid = 1

# Call CAKE's keyphrase_interpretation2 function with the following parameters:
"""
    text: the instance to explain
    desired_keyphrases: the number of keyphrases to return (5 in this case)
    keyphrase_method: the keyphrase embedding variation described in the paper (KE1 = 1 in this case)
    label_method: the label embedding variation described in the paper (LE1 = 1 in this case)
    width: the window width w used to locate keyphrases in the text: w = 0 requires an exact match, while w > 0 searches the tokens within a window of width w (1 in this case)
    negatives: True to keep negative similarity scores, False to filter them out (False in this case)
    tid: used for the precomputed label embeddings of LE3; not needed here (None)
"""
results = cake.keyphrase_interpretation2(instance, 5, 1, 1, 1, False, None)
print([[i,j] for i,j in zip(results[1],results[2][lid]) if j>0])
# This will print the keyphrases and their influence scores
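
As the abstract notes, CAKE's scores are based on semantic similarity. The sketch below shows that general idea with plain cosine similarity between toy keyphrase vectors and a toy label vector, plus a filter that mirrors the `negatives` flag. It is an illustration only, not CAKE's actual implementation: the function names (`cosine`, `score_keyphrases`), the 3-dimensional vectors, and the ranking step are all assumptions made for this example.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def score_keyphrases(keyphrase_vecs, label_vec, top_k, negatives=False):
    """Rank keyphrases by their similarity to a single label embedding.

    keyphrase_vecs: dict mapping keyphrase -> embedding (hypothetical toy data)
    negatives: keep negative similarity scores if True, drop them if False
    """
    scored = [(kp, cosine(vec, label_vec)) for kp, vec in keyphrase_vecs.items()]
    if not negatives:
        scored = [(kp, s) for kp, s in scored if s > 0]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy 3-dimensional embeddings, invented for illustration only
keyphrases = {
    "hate speech content": [0.9, 0.1, 0.0],
    "sentence": [0.1, 0.8, 0.1],
    "people": [-0.5, 0.2, 0.3],
}
hate_speech_label = [1.0, 0.0, 0.0]

ranked = score_keyphrases(keyphrases, hate_speech_label, top_k=5)
for keyphrase, score in ranked:
    print(f"{keyphrase}: {score:.3f}")
```

With `negatives=False`, the keyphrase whose similarity to the label is negative is dropped before ranking; passing `negatives=True` keeps it in the list.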

Developed by:

| Name | e-mail |
| --- | --- |
| Dimitrios Akrivousis | [email protected] |
| Nikolaos Mylonas | [email protected] |
| Ioannis Mollas | [email protected] |
| Grigorios Tsoumakas | [email protected] |

Funded by

The research work was supported by the Hellenic Foundation for Research and Innovation (H.F.R.I.) under the “First Call for H.F.R.I. Research Projects to support Faculty members and Researchers and the procurement of high-cost research equipment grant” (Project Number: 514).

Additional resources

amulet-logo

Citation

Please cite the paper if you use it in your work or experiments :D :

  • [Conference] :
    • TBA