Tree Model Interpretation using SHAP¶
There are many model interpretation techniques, each with advantages and disadvantages that make it suitable for different situations. SHAP (SHapley Additive exPlanations) is a game-theoretic approach to explaining the output of any machine learning model, and it is well suited for exploring feature importances.
Pros:
- Solid mathematical theory behind the model explanations.
- Very wide applicability and ease of use: explanations at both the single-sample and the global level, and a number of plots that can be easily computed and understood.
- Feature interactions are taken into account by the method.
- High computation speed, especially for tree-based models.
Cons:
- Documentation is often lacking.
- A different API is needed when you use sklearn models, e.g. RandomForestClassifier.
- Slow computation for some explainers, e.g. KernelExplainer.
Let's assume we want to analyse the following model:
%%capture
!pip install probatus
import warnings
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from probatus.interpret import ShapModelInterpreter
warnings.filterwarnings("ignore")
feature_names = ["f1", "f2", "f3", "f4"]
# Prepare two samples
X, y = make_classification(n_samples=1000, n_features=4, random_state=0)
X = pd.DataFrame(X, columns=feature_names)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Prepare and fit the model. For imbalanced data, remember to set class_weight="balanced" or an equivalent.
model = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
model = model.fit(X_train, y_train)
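For context, the same fitted model can also be explained with the shap library directly, which is what Probatus builds on. A minimal sketch (it assumes the shap package is installed; the exact return shape of shap_values depends on the shap version):
import numpy as np
import shap
# Build a TreeExplainer for the fitted RandomForestClassifier
explainer = shap.TreeExplainer(model)
# SHAP values for the test set; for binary sklearn classifiers this is
# typically a list of two arrays (one per class) or a 3D array, depending on the shap version
shap_values = explainer.shap_values(X_test)
# Keep the contributions towards the positive class
shap_values_pos = shap_values[1] if isinstance(shap_values, list) else shap_values[..., 1]
# Mean absolute SHAP value per feature - a simple importance measure
print(np.abs(shap_values_pos).mean(axis=0))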
ShapModelInterpreter¶
The ShapModelInterpreter class in Probatus is a convenience wrapper that makes it easy to interpret ML models. Currently it supports only tree-based and linear models.
Feature importance¶
Firstly, let's compute the report presenting various properties of the model:
- mean_abs_shap_value_test - SHAP feature importance computed on the test set. It is an unbiased measurement of the model's feature importance on unseen data.
- mean_abs_shap_value_train - SHAP feature importance computed on the train set. It is a biased measurement, because the model was trained on this data. However, a significant difference between this metric and mean_abs_shap_value_test might indicate a shift in the data distribution, a shift in the target distribution, or overfitting of the model.
- mean_shap_value_test - This metric presents how strongly the values of a given feature in the test set push the prediction towards one class or the other. A positive value indicates that the feature increases the probability of the positive class, and a negative value indicates that it decreases it. In a balanced setting it is typically around 0, while for an imbalanced setting its value is relative to the majority class. It is crucial to compare it with mean_shap_value_train - if the two differ significantly, there is possibly a shift in the data or target distribution in the test set.
- mean_shap_value_train - This metric presents how strongly the values of a given feature in the train set push the prediction towards one class or the other, similarly to mean_shap_value_test.
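Conceptually, these aggregates are simple summaries of the per-sample SHAP values. A rough sketch of how they could be computed from a raw SHAP values array (assuming shap_values_test is an (n_samples, n_features) array of contributions towards the positive class, e.g. shap_values_pos from the sketch earlier; this is not the exact Probatus implementation):
import numpy as np
import pandas as pd
# Aggregate per feature: magnitude of the effect vs. average push direction
mean_abs_shap_value_test = np.abs(shap_values_test).mean(axis=0)
mean_shap_value_test = shap_values_test.mean(axis=0)
pd.DataFrame(
    {
        "mean_abs_shap_value_test": mean_abs_shap_value_test,
        "mean_shap_value_test": mean_shap_value_test,
    },
    index=feature_names,
)
In practice, Probatus computes these metrics for you: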
shap_interpreter = ShapModelInterpreter(model)
feature_importance = shap_interpreter.fit_compute(X_train, X_test, y_train, y_test, approximate=False)
feature_importance
| | mean_abs_shap_value_test | mean_abs_shap_value_train | mean_shap_value_test | mean_shap_value_train |
|---|---|---|---|---|
| f1 | 0.315121 | 0.314689 | 0.005170 | 0.009490 |
| f4 | 0.087408 | 0.090953 | 0.004380 | 0.002525 |
| f3 | 0.045545 | 0.045040 | -0.007564 | 0.000847 |
| f2 | 0.006017 | 0.006600 | 0.000335 | 0.000709 |
Run the following command to plot the SHAP feature importance.
ax = shap_interpreter.plot("importance")
The AUC on the train and test sets is shown in each plot, to indicate whether the model overfits. If the Test AUC is significantly lower than the Train AUC, the model might be overfitting, and the interpretation of the model might be misleading; in such cases we recommend retraining the model with more regularization.
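If you prefer to check this numerically rather than reading it off the plots, a quick sketch with scikit-learn's roc_auc_score:
from sklearn.metrics import roc_auc_score
# Compare train and test AUC; a large gap suggests overfitting
train_auc = roc_auc_score(y_train, model.predict_proba(X_train)[:, 1])
test_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
print(f"Train AUC: {train_auc:.3f}, Test AUC: {test_auc:.3f}")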
Summary plot¶
The summary plot gives you more insight into how different feature values affect the predictions, and it is a crucial plot to make for every model. Each dot represents a sample in the data, and its position on the X-axis shows how strongly (and in which direction) that feature value affected the prediction, while the colour of the dot encodes the value of the feature. For each model, try to analyse this plot with a Subject Matter Expert, in order to make sure that the relations the model has learned make sense.
ax = shap_interpreter.plot("summary")
Dependence Plot¶
This plot allows you to understand how the model reacts to different feature values. You can plot it for each feature in your model, or at least for the top 10 features. It can provide further insight into how the model uses each feature, and it helps detect anomalies as well as the effect of outliers on the model.
In addition, the bottom plot presents the feature's distribution histogram and the target rate for different buckets of that feature's values. This allows you to further analyse how the feature correlates with the target variable.
ax = shap_interpreter.plot("dependence", target_columns=["f1"])
Sample explanation¶
In order to explain predictions for specific samples from your test set, you can use a sample plot. For a given sample, the plot presents the force and direction of the prediction shift that each feature value causes.
ax = shap_interpreter.plot("sample", samples_index=[521, 78])
Detecting Data or Target Distribution Shift¶
Let's assume that there is a shift between the train and test data:
X_test["f1"] = X_test["f1"] - 5
X_test["f4"] = X_test["f4"] + 5
Now, we can look into how it affects the results:
shap_interpreter = ShapModelInterpreter(model)
feature_importance = shap_interpreter.fit_compute(X_train, X_test, y_train, y_test, approximate=False)
feature_importance
| | mean_abs_shap_value_test | mean_abs_shap_value_train | mean_shap_value_test | mean_shap_value_train |
|---|---|---|---|---|
| f3 | 0.05316 | 0.044163 | -0.002152 | -0.003084 |
| f2 | 0.00940 | 0.006514 | 0.000168 | 0.000425 |
| f1 | 0.00000 | 0.316807 | 0.000000 | -0.055661 |
| f4 | 0.00000 | 0.087430 | 0.000000 | -0.010959 |
For features f1 and f4, the shift is indicated by the differences between train and test in the mean absolute SHAP values and the mean SHAP values.
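Such differences can also be flagged programmatically from the report. A small sketch, assuming feature_importance is the DataFrame shown above:
# Absolute gap between train and test SHAP importance per feature;
# large gaps hint at a distribution shift or overfitting
importance_gap = (
    feature_importance["mean_abs_shap_value_train"]
    - feature_importance["mean_abs_shap_value_test"]
).abs().sort_values(ascending=False)
print(importance_gap)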
We can visualize the summary plot for these two sets:
ax_test = shap_interpreter.plot("summary", target_set="test")
ax_train = shap_interpreter.plot("summary", target_set="train")
Tips for using the interpreter¶
Before using the ShapModelInterpreter consider the following tips:
- Make sure you do not underfit or overfit the model. Underfitting will cause only the most important relations in the data to be visible, while overfitting will present relationships that do not generalize.
- Perform a feature selection process before fitting the final model. This way, it will be easier to interpret the explanation. Moreover, highly-correlated features will affect the explanation less.
- Preferably use a model that handles NaNs, e.g. LightGBM, or impute them beforehand. When imputing, also extract a MissingIndicator to get insights into when NaNs are meaningful for the model (see the sketch after this list).
- For categorical features, either use a model that handles them, e.g. LightGBM, or apply one-hot encoding. Keep in mind that with one-hot encoding the importance of a categorical feature might be spread over multiple encoded features.
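As a sketch of the imputation tip above (it uses a hypothetical frame X_raw with missing values, not the example data in this notebook, and assumes a recent scikit-learn for get_feature_names_out):
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
# Hypothetical frame with missing values
X_raw = pd.DataFrame({"f1": [1.0, np.nan, 3.0], "f2": [np.nan, 0.5, 0.7]})
# add_indicator=True appends binary missing-indicator columns,
# so the model can still learn when a NaN itself is informative
imputer = SimpleImputer(strategy="median", add_indicator=True)
X_imputed = pd.DataFrame(
    imputer.fit_transform(X_raw),
    columns=imputer.get_feature_names_out(),
)
print(X_imputed)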