SHAP and Shapley values

SHAP (SHapley Additive exPlanations) is a Python package; the Shapley values it computes come from game theory:

A group of players with different abilities play a game together. How should the reward be split according to each player's contribution?

In machine learning, the question becomes:

Given a set of features, how do we accurately assess each feature's contribution to the target?

Shapley values are not quite the same as plain feature importance; they provide:

  1. Global model interpretability: in credit risk modeling, for example, after the model is built you want to know which features tend to push a loan application toward rejection and which toward approval.
  2. Local interpretability: suppose a particular customer is rejected and wants to know why. With Shapley values you can explain that single sample without involving any other samples.
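
For reference, the Shapley value of feature $i$ is its marginal contribution averaged over every possible subset $S$ of the remaining features, where $F$ is the full feature set and $f_S$ denotes the model evaluated on subset $S$:

$$\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F|-|S|-1)!}{|F|!}\left[f_{S\cup\{i\}}(x_{S\cup\{i\}}) - f_S(x_S)\right]$$

SHAP approximates these values efficiently for specific model classes (e.g. tree ensembles) instead of enumerating all $2^{|F|}$ subsets.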

Code

This example uses the breast cancer dataset to demonstrate how SHAP is used.

import pandas as pd
import shap
import xgboost as xgb
from sklearn.datasets import load_breast_cancer
import matplotlib.pyplot as plt
import numpy as np
breast_cancer = load_breast_cancer()

X_train = pd.DataFrame(breast_cancer['data'], columns=breast_cancer['feature_names'])
y_train = breast_cancer['target']
X_train.head()

Build a classifier:

model = xgb.XGBClassifier(n_estimators=1000).fit(
    X_train, y_train
)

Create the explainer. TreeExplainer is a dedicated class in SHAP for explaining tree models; other model types can be explained with the model-agnostic KernelExplainer (a minimal sketch follows the code below).

xgb_explainer = shap.TreeExplainer(
    model, X_train, feature_names=X_train.columns.tolist()
)
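
For non-tree models, a minimal KernelExplainer sketch looks like the following (the logistic regression model and the 100-row background sample are illustrative choices, not part of the original example):

from sklearn.linear_model import LogisticRegression

# Hypothetical non-tree model, only to illustrate the model-agnostic KernelExplainer
lr_model = LogisticRegression(max_iter=5000).fit(X_train, y_train)

# KernelExplainer needs a background dataset; a small sample keeps the estimation tractable
background = shap.sample(X_train, 100)
kernel_explainer = shap.KernelExplainer(lr_model.predict_proba, background)
kernel_shap_values = kernel_explainer.shap_values(X_train.iloc[:10])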

SHAP values can also be computed by XGBoost itself, including on GPU-accelerated models, by passing pred_contribs=True to Booster.predict:

%%time

# SHAP values with the XGBoost core model
booster_xgb = model.get_booster()
shap_values_xgb = booster_xgb.predict(xgb.DMatrix(X_train, y_train), pred_contribs=True)

# The last column is the bias term (expected value); drop it so the shape matches X_train
shap_values_xgb = shap_values_xgb[:, :-1]

The summary bar plot shows that worst area has high importance:

shap.summary_plot(
    shap_values_xgb, X_train, feature_names=X_train.columns, plot_type="bar"
)
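
The bar plot ranks features by the mean absolute SHAP value; the same ranking can be reproduced by hand to see the underlying numbers:

# mean(|SHAP value|) per feature - the quantity shown by the bar summary plot
mean_abs_shap = pd.Series(np.abs(shap_values_xgb).mean(axis=0), index=X_train.columns)
mean_abs_shap.sort_values(ascending=False).head(10)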

This is broadly consistent with xgboost's built-in feature importance:

xgb.plot_importance(booster_xgb, height=2);

However, xgboost's feature importance varies considerably across its different importance types:

fig, axes = plt.subplots(1, 3, figsize=(20, 6))

for ax, imp_type in zip(axes.flatten(), ["weight", "gain", "cover"]):
    xgb.plot_importance(
        booster_xgb,
        ax=ax,
        importance_type=imp_type,
        title=f"Importance type - {imp_type}",
    )

plt.show();
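
To compare the numbers behind the three plots directly, Booster.get_score returns the raw importance values for each type:

# Raw importance scores for each type, keyed by feature name
for imp_type in ["weight", "gain", "cover"]:
    scores = booster_xgb.get_score(importance_type=imp_type)
    top5 = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(imp_type, top5)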

Next, look at whether each feature pushes the model's prediction in the positive or negative direction:

shap.summary_plot(shap_values_xgb, X_train, feature_names=X_train.columns);

And the dependence between a feature's value and its SHAP value:

shap.dependence_plot("worst area", shap_values_xgb, X_train, interaction_index=None)

Setting interaction_index="auto" colors the points by the feature that interacts most strongly with worst area (here, worst concavity):

shap.dependence_plot("worst area", shap_values_xgb, X_train, interaction_index="auto");

Next, compute the SHAP interaction values:

%%time

# SHAP interactions with XGB
interactions_xgb = booster_xgb.predict(
    xgb.DMatrix(X_train, y_train), pred_interactions=True
)

Find the k strongest feature interactions:

def get_top_k_interactions(feature_names, shap_interactions, k):
    # Get the mean absolute contribution for each feature interaction
    aggregate_interactions = np.mean(np.abs(shap_interactions[:, :-1, :-1]), axis=0)
    interactions = []
    for i in range(aggregate_interactions.shape[0]):
        for j in range(aggregate_interactions.shape[1]):
            if j < i:
                interactions.append(
                    (
                        feature_names[i] + "-" + feature_names[j],
                        aggregate_interactions[i][j] * 2,  # effect is split across (i, j) and (j, i), so double it
                    )
                )
    # sort by magnitude
    interactions.sort(key=lambda x: x[1], reverse=True)
    interaction_features, interaction_values = map(tuple, zip(*interactions))

    return interaction_features[:k], interaction_values[:k]


top_10_inter_feats, top_10_inter_vals = get_top_k_interactions(
    X_train.columns, interactions_xgb, 10
)
top_10_inter_feats
('worst concave points-worst radius',
 'worst concavity-area error',
 'worst concave points-worst area',
 'worst texture-mean concave points',
 'worst concave points-worst perimeter',
 'worst concave points-area error',
 'worst area-mean concave points',
 'worst concavity-worst area',
 'worst concavity-worst texture',
 'worst area-mean texture')

Plot the top interaction pairs:

def plot_interaction_pairs(pairs, values):
    plt.bar(pairs, values)
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show();

plot_interaction_pairs(top_10_inter_feats, top_10_inter_vals)
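
To inspect the strongest pair in detail, dependence_plot also accepts a (feature, feature) tuple together with the interaction matrix (again with the bias row and column dropped):

# Dependence plot for the strongest interaction pair found above
shap.dependence_plot(
    ("worst concave points", "worst radius"),
    interactions_xgb[:, :-1, :-1],
    X_train,
)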

Recompute the Shapley values, this time with the explainer:

%%time

# Recalculate SHAP values
shap_explainer_values = xgb_explainer(X_train, y_train, check_additivity=False)
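
The result is an Explanation object that bundles the per-sample SHAP values, the base (expected) value, and the feature data; these are exactly what the waterfall and force plots below consume:

# Explanation object: per-sample values, base values, and the underlying feature data
shap_explainer_values.values.shape, shap_explainer_values.base_values.shape, shap_explainer_values.data.shape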

Now explain individual samples. This sample is pushed toward class 0; the waterfall plot shows why:

shap.waterfall_plot(shap_explainer_values[100])

This sample is pushed toward class 1:

shap.waterfall_plot(shap_explainer_values[200])

Another way to present local explanations is the force plot, which shows the contributions stacked against each other:

shap.initjs()  # don't forget to enable JavaScript

shap.force_plot(shap_explainer_values[100])

shap.force_plot(shap_explainer_values[200])
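
If JavaScript rendering is not available (for example in a static HTML export), force_plot can fall back to matplotlib for a single sample:

# Static matplotlib rendering of the same force plot (single sample only)
shap.force_plot(shap_explainer_values[100], matplotlib=True)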
