XAI Navigator: Find the method that makes your AI model more interpretable

from | 1 December 2023 | Tech Deep Dive

Introduction

As artificial intelligence (AI) continues to advance and play an increasingly significant role in our lives, the need for explainable AI (XAI) has become more important than ever. XAI is a rapidly growing field of research that seeks to create AI systems that are transparent, interpretable, and understandable by humans. The importance of XAI arises from the fact that traditional AI models often function as "black boxes," making it difficult to understand how they arrive at their decisions or recommendations. As AI systems become more prevalent and critical in high-stakes domains such as healthcare, finance, and autonomous driving, the ability to understand and explain their decisions becomes paramount. But even non-critical AI systems with low stakes can benefit from XAI by improving the collaboration between humans and machines. XAI can enable humans to better understand how AI systems arrive at their decisions, making it easier to identify and correct errors or to learn from the algorithm.

If you want to learn more about the benefits of XAI, check out our blog article

In recent years, the field of explainable AI (XAI) has experienced rapid growth and innovation, with new methods and techniques emerging to tackle the challenge of creating transparent and interpretable AI models. These methods range from visualisation and explanation techniques to new model architectures and training algorithms.

This article aims to help you solve the interpretability problem of your AI model by presenting an XAI Navigator. By answering a series of questions, you can identify the most appropriate method for your specific needs. To facilitate your implementation of these methods, we provide links to code and resources for further reading.

In section 3, we demonstrate how our XAI Navigator can be applied to a real-world use case. Drawing on our extensive experience in over 1,500 customer projects related to data and AI, we designed this framework to be applicable to most interpretability challenges. If your specific problem does not fit within the framework, we are more than happy to collaborate with you to develop a custom solution.

The XAI Navigator

Keep your concrete use case in mind and start at question 1. After you have answered a few questions, we will propose a solution with methods that make your AI algorithms more explainable.

Start Here: Do you want to analyse an existing model?

a) Yes (go to 2.1)
b) No, I want to build an interpretable model from scratch (go to 2.2)

2.1 Existing model

If you already have a well-performing model, you don't need to change it. Instead, you can analyse its decision making post hoc.

Do you want to understand the model in general, or do you want explanations for specific decisions?

a) The model in general (go to 3.1)
b) Specific decisions (go to 3.2)

2.2 Building an interpretable model

If you are building a new AI model from scratch, you have the advantage that you can take interpretability into account in its design. The choice of model is a trade-off between performance and interpretability. Simple models such as linear regression are highly interpretable but limited in their predictive performance. Complex models such as deep neural networks perform well if enough training data is available, but they lack interpretability. Once you have trained your model, you can further improve its interpretability by applying additional post-hoc analysis methods (go to 2.1).

There is a variety of models that we sort here by their interpretability:

Linear Regression / Logistic Regression: Linear models are fully transparent with a prediction made by a simple formula. The influence of features is clear and can be immediately explained by the model coefficients. Linear models are limited because they are unable to model non-linear relations.

scikit-learn module
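As a minimal sketch of how the coefficients serve as the explanation, the following example fits a scikit-learn linear regression on synthetic data. The data, the feature names and the true weights are assumptions made up for this illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic data (assumption): the target depends on two features with known weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + 1.0

model = LinearRegression().fit(X, y)

# The learned coefficients are the explanation: one weight per feature plus an intercept.
for name, coef in zip(["feature_a", "feature_b"], model.coef_):
    print(f"{name}: {coef:+.2f}")
print(f"intercept: {model.intercept_:+.2f}")
```

Each coefficient states how much the prediction changes when the corresponding feature increases by one unit while the other feature stays fixed.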

RuleFit: Simple rules are created to represent different data populations. For example, the rule temperature > 30 °C and precipitation < 1 mm would represent hot, dry days. A linear model assigns a score to each rule, and for a given input the scores of all matching rules are added up to form the model prediction. Knowing how the model handles specific situations is a very natural form of interpretability. When training a RuleFit model, you need to balance two risks: too many rules make the model hard to oversee, while too few rules lead to poor performance.

Python code
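The prediction step described above can be sketched in a few lines of plain Python. This is not the RuleFit training algorithm: the rules, their scores, the intercept and the "ice cream sales" setting are hypothetical values chosen for illustration, whereas a real RuleFit model would learn them from data.

```python
# Hypothetical rules with hand-picked scores (a trained RuleFit model would learn these).
RULES = [
    ("temperature > 30 and precipitation < 1",
     lambda d: d["temperature"] > 30 and d["precipitation"] < 1, 120.0),
    ("is_weekend", lambda d: d["is_weekend"], 80.0),
    ("precipitation >= 5", lambda d: d["precipitation"] >= 5, -60.0),
]
INTERCEPT = 200.0  # assumed baseline prediction, e.g. ice cream sales per day


def predict_with_explanation(day):
    """Return the prediction and the list of (rule, score) pairs that matched."""
    fired = [(name, score) for name, condition, score in RULES if condition(day)]
    prediction = INTERCEPT + sum(score for _, score in fired)
    return prediction, fired


hot_weekend = {"temperature": 33, "precipitation": 0, "is_weekend": True}
prediction, fired_rules = predict_with_explanation(hot_weekend)
print(prediction, fired_rules)
```

The list of matching rules is the explanation: it shows which situations the model recognised in the input and how much each one contributed.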

Decision Trees: A decision tree consists of nodes and branches. The nodes represent decisions or choices, and the branches represent the outcomes of those decisions. The top node of the tree, called the root node, represents the starting point of the decision-making process. The leaves of the tree represent the final outcomes or predictions. The simple decision rules are easy to understand, even for lay persons. However, a tree that grows too big becomes hard to follow. You can therefore restrict the tree to a maximum size or apply pruning techniques to simplify a complex tree. Small decision trees have very limited model capacity and are only suited to simple input-output relations.

scikit-learn module
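To illustrate how a size restriction keeps a tree readable, the sketch below trains a depth-limited tree with scikit-learn and prints its rules as text. The Iris dataset and the depth limit of 2 are arbitrary choices for this example.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# max_depth restricts the tree size so that the rule set stays small enough to read.
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text renders the learned decision rules as an indented text tree.
rules = export_text(tree, feature_names=list(iris.feature_names))
print(rules)
```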

K-Nearest Neighbours: If you prefer examples to explain decisions, K-Nearest Neighbours might offer the right type of interpretability for you. The prediction for an input is made by finding the most similar inputs in a dataset and predicting the most frequent label among them. This works on tabular data but can also be extended to complex data types like text or images. In a variant named DkNN, you take a neural network and measure similarity as a distance in the latent space of a hidden layer close to the output layer.

scikit-learn module
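The example-based explanation can be sketched with scikit-learn as follows. The six training points and the query are made-up toy values; the point is that the neighbours returned by `kneighbors` are the explanation for the prediction.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy training data (assumption): two well-separated groups with labels 0 and 1.
X_train = np.array([[1.0, 1.0], [1.2, 0.9], [0.8, 1.1],
                    [5.0, 5.0], [5.2, 4.8], [4.9, 5.1]])
y_train = np.array([0, 0, 0, 1, 1, 1])

knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)

query = np.array([[1.1, 1.0]])
prediction = knn.predict(query)[0]

# The nearest training examples justify the prediction.
distances, indices = knn.kneighbors(query)
for dist, idx in zip(distances[0], indices[0]):
    print(f"neighbour {X_train[idx]} with label {y_train[idx]} at distance {dist:.2f}")
```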

Naïve Bayes: By assuming that the input features are independent of each other given the class, Naïve Bayes makes classifications based on clear formulas that decompose into one term per input feature. This makes it easy to see how each feature influences the prediction. The independence assumption can be problematic, but in practice Naïve Bayes works well in many use cases.

scikit-learn module

Generalised Additive Model (GAM): A GAM overcomes the inability of linear models to capture non-linear relations. It linearly combines non-linear spline functions, which are more flexible. Each spline function can be plotted to see which relationship between an input feature and the output it captures. Of course, a GAM is more interpretable with a lower number of splines.

Read more

Python package

Symbolic Regression: In some fields a concrete formula is the best model because it reflects known relationships such as physical or economic laws. Symbolic regression can be used to find such formulas. The complexity of a formula can be balanced against its predictive power to find a simple formula with high accuracy.

Read more

Python package

Neural Circuit Policies: Neural networks are usually too complex to be interpretable. Neural circuit policies address this problem because only very few neurons are responsible for the decision making, so analysing them is much easier than analysing a normal neural network. At the same time, they can perform well enough for tasks like autonomous driving.

Read more

Python code

XGBoost: An XGBoost model consists of too many decision trees to be interpretable by itself. At the same time, it is one of the best-performing choices for modelling tabular data and comes with integrated explainability methods. TreeSHAP is a variant of SHAP (see 6.3) that is optimised for tree-based models like XGBoost and can show which input features are influential.

XGBoost documentation

Example code for applying TreeSHAP on XGBoost

3.1 Global Analysis

A global analysis aims to explain the model in general. It can be an important check of whether the model behaves as intended. Especially in high-stakes domains, it can be important to detect unwanted biases or defective decision processes.

What do you want to understand about the model?

a) Influence of the input features (go to 4.1)
b) Internal decision rules of the model (go to 4.2)
c) The data on which the model was trained (go to 4.3)
d) Which high-level patterns are important in the decision-making process (go to 4.4)

3.2 Local Analysis

A local analysis is a valuable tool for individuals working with AI models, as it provides transparency into the model's decision-making process for specific inputs. This insight can be used to learn from the model, assess the plausibility of its decisions, and integrate the model's predictions with human expert knowledge to make better-informed decisions. Overall, local analysis can enhance collaboration between humans and AI systems, leading to more effective and trustworthy outcomes.

What do you want to learn about specific decisions?

a) Influence of the input features (go to 4.5)
b) Internal decision rules of the model that apply to this decision (go to 4.6)
c) Why no different prediction was made (go to 4.7)

4.1 Global Input Feature Influence

By identifying which features have the most significant impact on an AI model's predictions, we can gain a clearer understanding of its decision-making process. An additional benefit is that we can determine the value of different features for the model's predictions, which can be used to improve its performance. By optimising the most influential features, we can enhance the model's accuracy and reliability, ultimately resulting in more effective and trustworthy predictions.

How much do you want to know about the influence of features?

a) A score that shows how important they are for the model is sufficient (go to 5.1)
b) I want to understand how changing a feature affects the model prediction (go to 5.2)

4.2 Global Surrogate Models

A global surrogate model is a simplified and interpretable model that approximates the behaviour of the original black-box model, making it easier to understand. However, the usefulness of the surrogate model depends on how accurately it mimics the original model, which can be challenging for complex models. If the surrogate reproduces the original predictions poorly, the interpretability it provides is limited. Local surrogate models can be an alternative (go to 4.6).

See 2.2 for surrogate models that you could use.

Read more about surrogate models
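A minimal sketch of the surrogate idea, assuming a random forest as the black box and a shallow decision tree as the surrogate (both are arbitrary choices for this example, as is the synthetic dataset). The key step is that the surrogate is trained on the predictions of the black box, not on the true labels, and that its agreement with the black box is measured afterwards.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=5, n_informative=3, random_state=0)

# The "black box" whose behaviour we want to understand.
black_box = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
black_box_pred = black_box.predict(X)

# The surrogate learns to imitate the black box, so its targets are the black-box predictions.
surrogate = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, black_box_pred)

# Share of inputs on which surrogate and black box agree.
fidelity = np.mean(surrogate.predict(X) == black_box_pred)
print(f"surrogate agrees with the black box on {fidelity:.1%} of the inputs")
```

If the agreement is low, the rules of the surrogate say little about the black box and should not be used as an explanation.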

4.3 Prototypes and Criticism

Training data is a crucial component of machine learning models, as it directly affects what the model learns and how it makes predictions. Prototypes are a useful tool for gaining insight into the training data by identifying a few representative data points that capture the essence of different subgroups within the dataset. By analysing these prototypes, users can quickly gain an overview of the entire dataset without having to examine every single data point.

Criticisms, on the other hand, are data points that are representative of an unusual or rare subgroup within the dataset. These criticisms can help users understand for which cases the model might perform poorly, because it did not get enough similar training data. For example, in a dataset of tourist images from New York City, prototypes might be typical images of the Empire State Building and the Brooklyn Bridge, while criticisms might be more unique images that represent minorities in the dataset, such as a street performer or a lesser-known landmark.

Prototypes and criticisms can also be combined with local analysis (go to 3.2), where you explain the most representative inputs in depth to understand the model in general.

A recommended algorithm for finding prototypes and criticisms is ProtoDash:

Read more about ProtoDash

ProtoDash is implemented in this XAI library

4.4 Analysing High-Level Patterns

Understanding AI models that work on data like images or sound is typically harder than analysing AI models for tabular data. When we analyse an autonomous driving model, we are not interested in how single pixels are processed. As humans, we would rather understand how objects like bicycles, which consist of many pixels in an image, are treated by the model. This can help us to uncover harmful decision processes. We can't give a general recommendation here, since a solution would need to be customised to your specific use case. However, here are some useful methods that you could consider:

DeepDream generates fascinating images that are worth checking out from an artistic perspective. At the same time, the tool can be used to see which visual patterns specific neurons of a neural network react to most strongly.

Read more about DeepDream

Network Dissection can check whether certain neurons of your neural network correspond to one category from a labelled set of more than a thousand categories.

Read more about Network Dissection

Activation Atlases are very useful to get an overview of patterns in images that a neural network is using and how they are related.

Read more about Activation Atlases (the visualisations in this article are really beautiful)

If you need a custom solution for your problem, we are happy to discuss it with you in a workshop. With experience in various explainable AI projects and state-of-the-art research, we can offer tailored XAI solutions that fit your specific needs.

4.5 Local Input Feature Influence

Knowing how input features influenced specific predictions makes the decisions more transparent.

How much do you want to know about the influence of features?

a) A score that shows how important they are for this decision is sufficient (go to 5.3)
b) I want to understand how changing a feature affects this decision (go to 5.4)

4.6 LIME

LIME (Local Interpretable Model-Agnostic Explanations) involves training an interpretable model that mimics the behaviour of the AI model. This model learns how the AI model changes its predictions when there are slight changes in the input. By analysing this interpretable model, we can gain a better understanding of how the AI model arrived at its decision for a particular input.

Read more about LIME

Python library for LIME
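The core idea can be sketched from scratch with NumPy. This is a simplified illustration, not the implementation of the LIME library: the black-box function, the Gaussian perturbations, the kernel width and the function name `local_linear_explanation` are all assumptions made for this example.

```python
import numpy as np


def black_box(X):
    # Hypothetical non-linear model, used only for this illustration.
    return X[:, 0] ** 2 + 3.0 * X[:, 1]


def local_linear_explanation(predict, x, n_samples=2000, scale=0.1, seed=0):
    """Fit a proximity-weighted linear model around x and return (slopes, intercept)."""
    rng = np.random.default_rng(seed)
    # 1. Sample slightly changed copies of the input.
    perturbed = x + rng.normal(scale=scale, size=(n_samples, x.size))
    # 2. Ask the black box for its predictions on them.
    targets = predict(perturbed)
    # 3. Weight each sample by its proximity to the original input.
    dist2 = np.sum((perturbed - x) ** 2, axis=1)
    weights = np.exp(-dist2 / (2 * scale ** 2))
    # 4. Solve the weighted least-squares problem for a local linear model.
    design = np.hstack([perturbed - x, np.ones((n_samples, 1))])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * sw[:, None], targets * sw, rcond=None)
    return coef[:-1], coef[-1]


x0 = np.array([2.0, 1.0])
slopes, intercept = local_linear_explanation(black_box, x0)
print("local slopes:", slopes)
```

Around `x0`, the slopes of the interpretable model approximate how strongly each feature moves the black-box prediction, even though the black box itself is not linear.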

4.7 Counterfactuals

When people don't trust a decision made by an AI, a natural reaction is to ask why the AI did not make the expected decision. Counterfactuals show how the input needs to be changed so that the AI model makes the desired decision. Knowing how to change a decision can be a satisfying explanation for a single decision, and it also helps to better understand the limits of a model.

What is the data type of your input?

a) Tabular data like numbers and categories (go to 5.5)
b) Image, sound or other data types that require a neural network (go to 5.6)

5.1 Global Feature Importance

Feature importance methods have a very simple interpretation. A high score for a feature tells us that it is very influential for the model. If the score is close to zero, the feature is irrelevant and can be removed in future model iterations. Knowing what the model is looking at is a good check to spot unwanted biases or discriminatory patterns. It can also be a useful cross-check of whether the model looks at the same features as a domain expert would.

What is the data type of your input?

a) Tabular data like numbers and categories (go to 6.1)
b) Image, sound or other data types that require a neural network (go to 6.2)

5.2 ALE plot

ALE (Accumulated Local Effects) plots provide a graphical representation of the relationship between a specific feature and the output of a model. This visualisation is simple to interpret, as it shows how the average prediction changes with varying feature values. For a more detailed understanding of how a feature impacts the predictions for individual inputs, ICE (Individual Conditional Expectation) plots can provide more granular information (go to 5.4).

Read more about ALE plots

Python repository for ALE plots
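To make the computation behind an ALE curve concrete, here is a simplified first-order ALE sketch in NumPy. It is not the code of the linked repository: the function name `ale_1d`, the quantile binning, the synthetic correlated data and the toy model are assumptions for this illustration.

```python
import numpy as np


def ale_1d(predict, X, feature, n_bins=10):
    """Return (upper bin edges, centred accumulated local effects) for one feature."""
    x = X[:, feature]
    edges = np.unique(np.quantile(x, np.linspace(0, 1, n_bins + 1)))
    n_intervals = len(edges) - 1
    # Assign each row to an interval; the maximum value belongs to the last interval.
    bin_idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_intervals - 1)
    local_effects = np.zeros(n_intervals)
    counts = np.zeros(n_intervals)
    for k in range(n_intervals):
        rows = X[bin_idx == k]
        if len(rows) == 0:
            continue
        lower, upper = rows.copy(), rows.copy()
        lower[:, feature] = edges[k]
        upper[:, feature] = edges[k + 1]
        # Local effect: average prediction change when moving across the interval.
        local_effects[k] = np.mean(predict(upper) - predict(lower))
        counts[k] = len(rows)
    ale = np.cumsum(local_effects)
    ale -= np.sum(ale * counts) / np.sum(counts)  # centre on the count-weighted mean
    return edges[1:], ale


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
X[:, 1] = 0.8 * X[:, 0] + 0.2 * X[:, 1]  # make the two features correlated


def model(X):
    # Hypothetical model: linear in feature 0, quadratic in feature 1.
    return 2.0 * X[:, 0] + X[:, 1] ** 2


grid, ale = ale_1d(model, X, feature=0)
print(np.round(ale, 2))
```

Because only rows that actually fall into an interval are moved across it, the curve isolates the effect of feature 0 (a slope of 2 in this toy model) even though the two features are correlated.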

5.3 Local Feature Importance

Feature importance methods are very popular because of their simple interpretation. A high score for a feature tells us that it has influenced the model decision a lot. If the score is close to zero, the feature was irrelevant.

What is the data type of your input?

a) Tabular data like numbers and categories (go to 6.3)
b) Image, sound or other data types that require a neural network (go to 6.4)

5.4 ICE plots

ICE (Individual Conditional Expectation) plots are a useful tool for understanding how a specific feature impacts the output of a model on an individual level. Each line in an ICE plot shows how the model prediction for one input would change when the feature value changes. ICE plots can be combined with ALE plots (go to 5.2).

Read more about ICE plots

ICE plots are integrated into scikit-learn
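The curves behind an ICE plot can be computed by hand in a few lines, which shows what each line means. The sketch below uses NumPy only; the function name `ice_curves` and the toy model with an interaction between two features are assumptions for this illustration. For plotting fitted scikit-learn models, `PartialDependenceDisplay.from_estimator` with `kind="individual"` offers this out of the box.

```python
import numpy as np


def ice_curves(predict, X, feature, grid):
    """Return an array of shape (n_rows, n_grid): one prediction curve per input row."""
    curves = np.empty((X.shape[0], len(grid)))
    for g, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value  # set the feature to the grid value for every row
        curves[:, g] = predict(X_mod)
    return curves


def model(X):
    # Hypothetical model with an interaction: the effect of feature 0 depends on feature 1.
    return X[:, 0] * X[:, 1]


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
grid = np.linspace(-2.0, 2.0, 9)

curves = ice_curves(model, X, feature=0, grid=grid)
average_curve = curves.mean(axis=0)  # averaging the ICE curves gives the partial dependence
print(np.round(curves, 2))
```

In this toy model every row has a different slope, which an averaged curve would hide. That heterogeneity is exactly what ICE plots make visible.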

5.5 Counterfactuals for Tabular Data

There are many algorithms to generate counterfactuals on tabular data. A very detailed overview and discussion can be found in this paper. A solid choice is DiCE, which has a very good Python library.
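To show what a counterfactual on tabular data looks like, here is a deliberately naive sketch that changes one feature at a time and keeps the smallest change that flips the decision. It is not the DiCE algorithm: the approval rule, the feature meanings and the candidate grid are hypothetical.

```python
import numpy as np


def predict_approved(x):
    # Hypothetical credit rule: x[0] is income and x[1] is debt (both in thousand euros).
    return x[0] - 2.0 * x[1] > 25.0


def single_feature_counterfactual(predict, x, feature, candidates):
    """Return the candidate value closest to x[feature] that flips the prediction, or None."""
    original = predict(x)
    best = None
    for value in candidates:
        trial = x.copy()
        trial[feature] = value
        if predict(trial) != original:
            if best is None or abs(value - x[feature]) < abs(best - x[feature]):
                best = value
    return best


applicant = np.array([30.0, 10.0])        # rejected: 30 - 2 * 10 = 10, which is not above 25
candidates = np.arange(0.0, 101.0, 5.0)   # values to try for each feature

income_needed = single_feature_counterfactual(predict_approved, applicant, 0, candidates)
debt_needed = single_feature_counterfactual(predict_approved, applicant, 1, candidates)
print(f"approved if income were {income_needed} or if debt were {debt_needed}")
```

Dedicated libraries additionally search over several features at once and try to keep the counterfactuals realistic and diverse.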

5.6 Projected Gradient Descent (PGD)

PGD transforms an image step by step into a similar one that is assigned a different, desired prediction class. This shows how close the image is to being predicted as the other class and what is missing for the prediction to change. It also helps to understand what the AI model expects an image of a class to look like. Auto-PGD (2020) and the Auto-Frank-Wolfe method (2022) are improved versions of PGD that show impressive results. There is a Python repository for Auto-PGD. If you have access to the training code of your neural network, it is also possible to integrate one of these methods there.
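The step-and-project loop of targeted PGD can be sketched with NumPy on a tiny model. A logistic regression with hand-picked weights stands in for the neural network here, so that the gradient can be written down analytically; the weights, the step size and the allowed change `epsilon` are assumptions for this illustration. With a real network, the gradient would come from automatic differentiation.

```python
import numpy as np

# Hypothetical linear classifier standing in for a neural network.
w = np.array([1.5, -2.0, 0.5])
b = 0.1


def prob_class1(x):
    return 1.0 / (1.0 + np.exp(-(x @ w + b)))


def targeted_pgd(x, epsilon=0.5, step=0.05, n_steps=50):
    """Move x towards class 1 while staying within an L-infinity ball of radius epsilon."""
    x_adv = x.copy()
    for _ in range(n_steps):
        p = prob_class1(x_adv)
        grad = (p - 1.0) * w                                # gradient of the loss -log(p) w.r.t. the input
        x_adv = x_adv - step * np.sign(grad)                # signed gradient step that lowers the loss
        x_adv = np.clip(x_adv, x - epsilon, x + epsilon)    # projection back into the allowed region
    return x_adv


x_start = np.array([-0.2, 0.3, 0.0])
x_adv = targeted_pgd(x_start)

print(f"probability of class 1 before: {prob_class1(x_start):.2f}, after: {prob_class1(x_adv):.2f}")
print("change per feature:", np.round(x_adv - x_start, 2))
```

The difference between the changed and the original input is the explanation: it shows what had to change for the model to predict the desired class. For images, an additional clipping step would keep the pixel values in their valid range.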

6.1 Global Feature Importance Scores for Tabular Data

Permutation Feature Importance compares the model's normal mean error with its mean error on data in which the information of one feature has been removed by randomly permuting the feature values. If the error increases, the feature is important for the model prediction; if the error does not change, the feature might be irrelevant for the model. The method is implemented for scikit-learn models.

Paper about Permutation Feature Importance
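A minimal sketch with scikit-learn's `permutation_importance`, using synthetic data in which the third feature is irrelevant by construction (the data and the linear model are assumptions for this example). Note that scikit-learn reports the drop in the model score, which is R² by default for regressors, instead of the increase in error.

```python
import numpy as np
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = 5.0 * X[:, 0] + 0.5 * X[:, 1]  # the third feature has no influence on the target

model = LinearRegression().fit(X, y)

# Each feature is shuffled n_repeats times; the average score drop is its importance.
result = permutation_importance(model, X, y, n_repeats=10, random_state=0)

for name, importance in zip(["feature_0", "feature_1", "feature_2"], result.importances_mean):
    print(f"{name}: {importance:.4f}")
```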

ANOVA measures how much of the variance in the model prediction each feature is responsible for. It can also be applied to interactions between features. An additional benefit is that it tests whether a certain feature is significant for the model. ANOVA does not scale as well as Permutation Feature Importance to a large number of features, which makes it more suitable for smaller models.

Read more

scikit-learn function

SHAP is a method to explain the feature importance for individual decisions, but it can be extended to global feature importance scores by running it on the complete dataset and adding the absolute values up. (See 6.3 for more information)

6.2 Global Feature Importance for Images

The influence of pixel areas on the model decision can't be explained in general since images in a dataset can vary a lot. Instead, it is possible to analyse a set of representative images locally. (Go to 6.4)

6.3 Local Feature Importance Scores for Tabular Data

SHAP is extremely popular because its explanations give more information than other feature importance scores. For an input, the generated Shapley value of a feature tells you how much the concrete feature value increased or decreased the prediction compared to the average prediction. For an algorithm that estimates the price of a used car, SHAP could tell you that the prediction dropped by €7,600 because your car is 12 years old while the average car in the dataset is only five years old, or that the price goes up by €5,200 because you never had an accident. One drawback of SHAP is its long runtime, which can be problematic if you want an immediate analysis of a prediction. Optimised hardware or a specialised variant for tree-based algorithms can reduce the runtime and mitigate this problem.

Read more

SHAP has a very good Python package
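To illustrate what a Shapley value is, the sketch below computes exact Shapley values by enumerating all feature subsets in plain Python. It does not use the SHAP package, and it makes simplifying assumptions: a hypothetical used-car price model, and a single "average car" as the baseline instead of a background dataset. The enumeration also shows where the long runtime comes from, since the number of subsets doubles with every additional feature.

```python
from itertools import combinations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict at x, relative to a single baseline input."""
    n = len(x)

    def value(subset):
        # Features in the subset take the values of x, all others the baseline values.
        return predict([x[i] if i in subset else baseline[i] for i in range(n)])

    phi = []
    for j in range(n):
        others = [i for i in range(n) if i != j]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                total += weight * (value(set(subset) | {j}) - value(set(subset)))
        phi.append(total)
    return phi


def price(features):
    # Hypothetical price model in euros: age in years, number of accidents, mileage in 1,000 km.
    age, accidents, km_thousands = features
    return 25000 - 1000 * age - 4000 * accidents - 20 * km_thousands


car = [12, 0, 100]
average_car = [5, 0.5, 80]

phi = shapley_values(price, car, average_car)
for name, contribution in zip(["age", "accidents", "mileage"], phi):
    print(f"{name}: {contribution:+.0f} euros")
```

The contributions add up to the difference between the prediction for this car and the prediction for the average car, which is the property that makes these values easy to communicate.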

Causal Explanation (CXPlain) can be a faster alternative to SHAP. The initial implementation can be more work because a separate model needs to be trained once to estimate the feature importance. Once this model is trained, it has a very fast runtime, which makes it possible to get an explanation for a prediction within seconds. Additionally, confidence intervals can be used to evaluate the trustworthiness of the generated explanations. There is a Python repository, but it can be necessary to customise the feature importance estimation model to obtain more accurate predictions.

Paper

6.4 Saliency maps

Feature importance is also applicable to images. A saliency map is a heatmap that shows which pixels or image areas were most important for the model decision. On medical images, a saliency map can highlight the areas that are important for the diagnosis, so a physician knows immediately where to look and can spot a bad diagnosis by an AI if unreasonable spots influence the decision.

This library implements the following methods, and more, for TensorFlow

GradCAM is a fast and therefore popular method. The saliency map it produces highlights coarse areas, so details like single pixels are not evaluated individually.

PyTorch implementation

SmoothGrad works on a pixel level and is more robust against noise, so its explanations are more trustworthy than the ones generated by GradCAM. The disadvantage is that its runtime is much longer than that of GradCAM.

Read more about GradCAM and SmoothGrad
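The averaging idea behind SmoothGrad can be sketched with NumPy on a toy model. A 4x4 "image" and a hand-written differentiable function stand in for a neural network so that the gradient is available in closed form; the weights, the noise level and the number of samples are assumptions for this illustration. In practice, the gradient would be obtained from a deep learning framework.

```python
import numpy as np

# Hypothetical model: only two pixels of a 4x4 image influence the output.
weights = np.zeros((4, 4))
weights[1, 2] = 2.0
weights[3, 0] = -1.0


def model_gradient(image):
    """Gradient of tanh(sum(weights * image)) with respect to the image."""
    z = np.sum(weights * image)
    return (1.0 - np.tanh(z) ** 2) * weights


def smoothgrad(image, n_samples=50, noise_std=0.1, seed=0):
    """Average the gradients over noisy copies of the image and return their magnitude."""
    rng = np.random.default_rng(seed)
    grads = [
        model_gradient(image + rng.normal(scale=noise_std, size=image.shape))
        for _ in range(n_samples)
    ]
    return np.abs(np.mean(grads, axis=0))


image = np.full((4, 4), 0.2)
saliency = smoothgrad(image)
print(np.round(saliency, 2))
```

The need to evaluate the gradient once per noisy copy explains the longer runtime: here it is 50 gradient computations instead of one.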

Layer-Wise Relevance Propagation (LRP) customises the backpropagation process, which is the core of GradCAM and SmoothGrad. This can improve the results for your specific model, but it also requires additional, complex fine-tuning work.

A demo of LRP

Read more about LRP

PyTorch Implementation

Meaningful Perturbation is a method to generate high quality saliency maps but at the cost of a long runtime.

Paper

Video about Meaningful Perturbation

PyTorch implementation

AtMan is a saliency method for transformer-based model architectures and optimised for large models, where other methods would require too much memory or computing resources.

Paper

All saliency map methods for images could also be adapted for other data types. For example, it is possible to mark important words in text data.

Example Use Case: Lead Scoring

A sales team has deployed a lead scoring model to pinpoint prospective customers with a greater probability of converting into sales. To improve their communication with the leads, the team wants to understand why the model has singled out particular leads as high-value prospects. They use the XAI Navigator to achieve this objective:

Question 1: Do you want to analyse an existing model?

Answer: a) The model already exists.

Question 2.1: Do you want to understand the model in general, or do you want explanations for specific decisions?

Answer: b) We want individual information about each lead, so we want to explain specific decisions.

Question 3.2: What do you want to learn about specific decisions?

Answer: a) Knowing the main factors that increase the scores for a lead is valuable, because they can be addressed in the initial communication. That is the reason why we want to understand the influence of input features.

Question 4.5: How much do you want to know about the influence of features?

Answer: a) Knowing what the most important input features are is sufficient for targeting them. We don't need to understand exactly how they would influence the model, because that is intuitively clear. A score for the feature importance is sufficient.

Question 5.3: What is the data type of your input?

Answer: a) The model is based on tabular data sources.

Solution: 6.3 SHAP

The model runs infrequently, and it would be acceptable for the XAI method to run for a few minutes before returning the explanations. Without a hard constraint on the runtime, SHAP is a reasonable explainability method. The three features that increased the score the most can be displayed. With this information, the sales team knows why this lead might be interested in the offered products. It could even be worthwhile to display the three features that lowered the score the most, because they point to potential problems.
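The path taken in this use case can also be traced programmatically. The sketch below encodes the questions of the XAI Navigator as a plain Python dictionary, with keys and answer letters taken from the sections above; the dictionary layout and the `navigate` helper are an illustrative encoding, not part of any library.

```python
# Each question maps the answer letters to the section that follows.
NAVIGATOR = {
    "1": {"a": "2.1", "b": "2.2"},
    "2.1": {"a": "3.1", "b": "3.2"},
    "3.1": {"a": "4.1", "b": "4.2", "c": "4.3", "d": "4.4"},
    "3.2": {"a": "4.5", "b": "4.6", "c": "4.7"},
    "4.1": {"a": "5.1", "b": "5.2"},
    "4.5": {"a": "5.3", "b": "5.4"},
    "4.7": {"a": "5.5", "b": "5.6"},
    "5.1": {"a": "6.1", "b": "6.2"},
    "5.3": {"a": "6.3", "b": "6.4"},
}


def navigate(answers, start="1"):
    """Follow a list of answer letters through the navigator and return the visited sections."""
    path = [start]
    for answer in answers:
        path.append(NAVIGATOR[path[-1]][answer])
    return path


# Answers given by the sales team in the lead scoring example.
lead_scoring_path = navigate(["a", "b", "a", "a", "a"])
print(" -> ".join(lead_scoring_path))
```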

Conclusion

In the world of AI, a plethora of explainability methods exists for model interpretation. Navigating this vast landscape in search of the right approach can be a daunting task, as the most popular algorithms may not necessarily address the unique nuances of your specific problem. To make an informed choice, it's crucial to keep your specific use case at the forefront. We hope that our XAI Navigator can assist you in maintaining a clear perspective on the myriad options available while tailoring your focus to your specific needs.

Author

Daniel Kaulbach

Daniel Kaulbach has been a data scientist at [at] since 2021, specialising in deep learning and XAI. He also heads the "KI Wissen" research project at [at], which aims to increase safety in autonomous driving using XAI. In his spare time, Daniel enjoys playing basketball.
