openavmkit.shap_analysis

get_full_model_shaps

get_full_model_shaps(model, X_train, X_test, X_sales, X_univ, verbose=False)

Calculates SHAP values for all subsets (train, test, sales, universe) of one model run

Parameters:

model : XGBoostModel | LightGBMModel | CatBoostModel
    A trained prediction model. Required.
X_train : DataFrame
    2D array of independent variables' values from the training set. Required.
X_test : DataFrame
    2D array of independent variables' values from the testing set. Required.
X_sales : DataFrame
    2D array of independent variables' values from the sales set. Required.
X_univ : DataFrame
    2D array of independent variables' values from the universe set. Required.
verbose : bool
    Whether to print verbose information. Defaults to False.

Returns:

dict
    A dict containing shap.Explanation objects keyed to "train", "test", "sales", and "univ".

Source code in openavmkit/shap_analysis.py
def get_full_model_shaps(
    model: XGBoostModel | LightGBMModel | CatBoostModel,
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    X_sales: pd.DataFrame,
    X_univ: pd.DataFrame,
    verbose: bool = False
):
    """
    Calculates SHAP values for all subsets (train, test, sales, universe) of one model run

    Parameters
    ----------
    model: XGBoostModel | LightGBMModel | CatBoostModel
        A trained prediction model
    X_train: pd.DataFrame
        2D array of independent variables' values from the training set
    X_test: pd.DataFrame
        2D array of independent variables' values from the testing set
    X_sales: pd.DataFrame
        2D array of independent variables' values from the sales set
    X_univ: pd.DataFrame
        2D array of independent variables' values from the universe set
    verbose: bool
        Whether to print verbose information. Defaults to False.

    Returns
    -------
    dict
        A dict containing shap.Explanation objects keyed to "train", "test", "sales", and "univ"

    """

    tree_explainer: shap.TreeExplainer

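    # Approximate SHAP is faster to compute; LightGBM's TreeExplainer does not
    # support it, so it is disabled for that backend below.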
    approximate = True
    cat_data = model.cat_data

    model_type = ""
    if isinstance(model, XGBoostModel):
        model_type = "xgboost"
        tree_explainer = _xgboost_shap(model, X_train)
    elif isinstance(model, LightGBMModel):
        model_type = "lightgbm"
        approximate = False # approx. not supported for LightGBM
        tree_explainer = _lightgbm_shap(model, X_train)
    elif isinstance(model, CatBoostModel):
        model_type = "catboost"
        tree_explainer = _catboost_shap(model, X_train)
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    if verbose:
        print("Generating SHAPs...")

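    # Explain each subset with the same fitted TreeExplainer.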
    shap_sales = _shap_explain(model_type, tree_explainer, X_sales, cat_data=cat_data, approximate=approximate, verbose=verbose, label="sales")
    shap_train = _shap_explain(model_type, tree_explainer, X_train, cat_data=cat_data, approximate=approximate, verbose=verbose, label="train")
    shap_test  = _shap_explain(model_type, tree_explainer, X_test,  cat_data=cat_data, approximate=approximate, verbose=verbose, label="test")
    shap_univ  = _shap_explain(model_type, tree_explainer, X_univ,  cat_data=cat_data, approximate=approximate, verbose=verbose, label="universe")

    return {
        "train": shap_train,
        "test":  shap_test,
        "sales": shap_sales,
        "univ":  shap_univ,
    }
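
A minimal usage sketch (hypothetical variable names; assumes a trained model and the four feature DataFrames from the same model run):

import numpy as np

# Assumption: `model` and the X_* frames come from one model run in your pipeline.
shaps = get_full_model_shaps(model, X_train, X_test, X_sales, X_univ, verbose=True)

# Each entry is a shap.Explanation; e.g., rank features by mean |SHAP| on the test set:
expl = shaps["test"]
mean_abs = np.abs(expl.values).mean(axis=0)
for name, val in sorted(zip(expl.feature_names, mean_abs), key=lambda t: -t[1]):
    print(f"{name}: {val:.4f}")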

make_shap_table

make_shap_table(expl, list_keys, list_vars, list_keys_sale=None, include_pred=True)

Convert a SHAP explanation into a DataFrame breaking down the full contribution to value

Parameters:

expl : Explanation
    Output of your _xgboost_shap (values: (n, m), base_values: scalar or (n,)). Required.
list_keys : list[str]
    Primary keys in the same row order as X_to_explain. Required.
list_vars : list[str]
    Feature names in the order used for training (your canonical order). Required.
list_keys_sale : list[str] | None
    Optional. Transaction keys in the same row order as X_to_explain. Default is None.
include_pred : bool
    Optional. Add a column that reconstructs the model output on the explained scale:
    base_value + sum(shap_values across features). Default is True.

Returns:

DataFrame
    A DataFrame with the key column(s) first, then base_value, then one SHAP
    contribution column per feature in list_vars order, and, when include_pred
    is True, a final contribution_sum column.
Source code in openavmkit/shap_analysis.py
def make_shap_table(
    expl: shap.Explanation,
    list_keys: list[str],
    list_vars: list[str],
    list_keys_sale: list[str] | None = None,
    include_pred: bool = True
) -> pd.DataFrame:
    """
    Convert a SHAP explanation into a DataFrame breaking down the full contribution to value

    Parameters
    ----------
    expl : shap.Explanation
        Output of your _xgboost_shap (values: (n,m), base_values: scalar or (n,)).
    list_keys : list[str]
        Primary keys in the same row order as X_to_explain
    list_vars : list[str]
        Feature names in the order used for training (your canonical order).
    list_keys_sale : list[str] | None
        Optional. Transaction keys in the same row order as X_to_explain. Default is None.
    include_pred : bool
        Optional. Add a column that reconstructs the model output on the explained scale:
        base_value + sum(shap_values across features). Default is True.

    Returns
    -------
    pd.DataFrame
        Keys, base_value, one SHAP contribution column per feature (in list_vars
        order), and optionally a contribution_sum column.
    """
    # 1) Validate / normalize SHAP values shape (expect regression/binary: (n, m))
    vals = expl.values
    if isinstance(vals, list):
        raise ValueError("Got a list of SHAP arrays (likely multiclass). Handle per-class tables separately.")
    vals = np.asarray(vals)
    if vals.ndim != 2:
        raise ValueError(f"Expected 2D SHAP values (n_samples, n_features), got shape {vals.shape}.")

    n, m = vals.shape

    # 2) Base values: scalar or per-row
    base = expl.base_values
    if np.isscalar(base):
        base_arr = np.full((n,), float(base))
    else:
        base = np.asarray(base)
        if base.ndim == 0:
            base_arr = np.full((n,), float(base))
        elif base.ndim == 1 and base.shape[0] == n:
            base_arr = base.astype(float)
        else:
            raise ValueError(f"Unexpected base_values shape {base.shape}. For multiclass, build per-class tables.")

    # 3) Build feature DF in the *training* column order
    # expl.feature_names comes from X_to_explain; align to canonical list_vars
    if expl.feature_names is None:
        # assume expl.values columns already match list_vars
        feature_cols = list_vars
    else:
        # ensure all requested vars exist
        existing = list(expl.feature_names)
        missing = [c for c in list_vars if c not in existing]
        if missing:
            raise ValueError(f"These list_vars are missing from explanation features: {missing}")
        feature_cols = list_vars  # enforce this order

    df_features = pd.DataFrame(vals, columns=expl.feature_names)
    df_features = df_features[feature_cols]  # reorder

    # 4) Keys up front (robust expansion)
    if len(list_keys) != n:
        raise ValueError(f"list_keys length {len(list_keys)} != number of rows {n}")
    if list_keys_sale is not None and len(list_keys_sale) != n:
        raise ValueError(f"list_keys_sale length {len(list_keys_sale)} != number of rows {n}")

    if list_keys_sale is not None:
        df_keys = pd.DataFrame({"key": list_keys, "key_sale": list_keys_sale})
    else:
        df_keys = pd.DataFrame({"key": list_keys})

    # 5) Base value column (between keys and features)
    df_base = pd.DataFrame({"base_value": base_arr})

    # 6) Optional reconstructed prediction on the explained scale
    # (raw margin for classifiers unless you used model_output="probability")
    if include_pred:
        pred = base_arr + df_features.sum(axis=1).to_numpy()
        df_pred = pd.DataFrame({"contribution_sum": pred})
        df = pd.concat([df_keys, df_base, df_features, df_pred], axis=1)
    else:
        df = pd.concat([df_keys, df_base, df_features], axis=1)

    return df
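
A minimal sketch of building a contribution table (the key and feature lists below are illustrative names; assumes the shaps dict from get_full_model_shaps above):

# Assumption: `shaps` comes from get_full_model_shaps, X_test is the frame that
# was explained, and df_test["key"] is an illustrative primary-key column.
list_keys = df_test["key"].tolist()
list_vars = list(X_test.columns)   # canonical training order

df_shap = make_shap_table(shaps["test"], list_keys=list_keys, list_vars=list_vars)

# base_value plus the per-feature contributions reconstructs the model output
# on the explained scale (raw margin for classifiers).
print(df_shap[["key", "base_value", "contribution_sum"]].head())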

plot_full_beeswarm

plot_full_beeswarm(explanation, title='SHAP Beeswarm', save_path=None, save_kwargs=None, wrap_width=20)

Plot a full SHAP beeswarm for a tree-based model with wrapped feature names.

This function wraps long feature names, auto-scales figure size to the number of features, and renders a beeswarm plot with rotated, smaller y-axis labels.

Parameters:

explanation : Explanation
    SHAP Explanation object with values, base_values, data, and feature_names. Required.
title : str
    Title of the plot. Defaults to "SHAP Beeswarm".
wrap_width : int
    Maximum character width for feature name wrapping. Defaults to 20.
save_path : str
    If provided, save the figure to this path (format inferred from extension),
    e.g. 'beeswarm.png', 'beeswarm.pdf', 'figs/beeswarm.svg'. Default is None.
save_kwargs : dict
    Extra kwargs passed to fig.savefig (e.g., {'dpi': 300, 'bbox_inches': 'tight',
    'transparent': True}). Default is None.
Source code in openavmkit/shap_analysis.py
def plot_full_beeswarm(
    explanation: shap.Explanation, 
    title: str = "SHAP Beeswarm", 
    save_path: str | None = None,
    save_kwargs: dict | None = None,
    wrap_width: int = 20
) -> None:
    """
    Plot a full SHAP beeswarm for a tree-based model with wrapped feature names.

    This function wraps long feature names, auto-scales figure size to the number of
    features, and renders a beeswarm plot with rotated, smaller y-axis labels.

    Parameters
    ----------
    explanation : shap.Explanation
        SHAP Explanation object with `values`, `base_values`, `data`, and `feature_names`.
    title : str, optional
        Title of the plot. Defaults to "SHAP Beeswarm".
    wrap_width : int, optional
        Maximum character width for feature name wrapping. Defaults to 20.
    save_path : str, optional
        If provided, save the figure to this path (format inferred from extension).
        e.g., 'beeswarm.png', 'beeswarm.pdf', 'figs/beeswarm.svg'.
    save_kwargs : dict, optional
        Extra kwargs passed to `fig.savefig` (e.g., {'dpi': 300, 'bbox_inches': 'tight',
        'transparent': True}).
    """
    if save_kwargs is None:
        save_kwargs = {"dpi": 300, "bbox_inches": "tight"}

    # Wrap feature names
    wrapped_names = [
        "\n".join(textwrap.wrap(fn, width=wrap_width))
        for fn in explanation.feature_names
    ]
    expl_wrapped = shap.Explanation(
        values=explanation.values,
        base_values=explanation.base_values,
        data=explanation.data,
        feature_names=wrapped_names,
    )

    # Determine figure size based on # features
    n_feats = len(wrapped_names)
    width = max(12, 0.3 * n_feats)
    height = max(6, 0.3 * n_feats)
    fig, ax = plt.subplots(figsize=(width, height), constrained_layout=True)

    # Draw the beeswarm (max_display defaults to all features here)
    shap.plots.beeswarm(expl_wrapped, max_display=n_feats, show=False)

    # Title + tweak y-labels
    ax.set_title(title)
    plt.setp(ax.get_yticklabels(), rotation=0, ha="right", fontsize=8)

    # Save if requested
    if save_path is not None:
        fig.savefig(save_path, **save_kwargs)

    plt.show()
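
A minimal usage sketch (assumes the shaps dict from get_full_model_shaps above; the save path is illustrative):

# Assumption: `shaps` comes from get_full_model_shaps. The default save_kwargs
# ({'dpi': 300, 'bbox_inches': 'tight'}) apply when none are passed.
plot_full_beeswarm(
    shaps["univ"],
    title="SHAP Beeswarm (universe set)",
    save_path="figs/beeswarm_univ.png",  # illustrative output path
)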