Skip to content

openavmkit.shap_analysis

compute_shap

compute_shap(smr, plot=False, title='')

Compute SHAP values for a given model and dataset.

Parameters:

Name Type Description Default
smr SingleModelResults

The SingleModelResults object containing the fitted model and data splits.

required
plot bool

If True, generate and display a SHAP summary plot. Defaults to False.

False
title str

Title to use for the SHAP plot if plot is True. Defaults to an empty string.

''

Returns:

Type Description
ndarray

SHAP values array for the training dataset (`smr.ds.X_train`).

Source code in openavmkit/shap_analysis.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def compute_shap(smr: SingleModelResults, plot: bool = False, title: str = ""):
    """
    Compute SHAP values for a fitted tree-based model.

    SHAP values are computed from the training features (`smr.ds.X_train`),
    which are passed to `_compute_shap` twice.
    NOTE(review): presumably once as the background set and once as the
    sample to explain -- confirm against `_compute_shap`.

    Parameters
    ----------
    smr : SingleModelResults
        The SingleModelResults object containing the fitted model and data splits.
        Only `smr.type`, `smr.model` and `smr.ds.X_train` are read here.
    plot : bool, optional
        If True, generate and display a SHAP beeswarm plot. Defaults to False.
    title : str, optional
        Title to use for the SHAP plot if `plot` is True. Defaults to an empty string.

    Returns
    -------
    object or None
        Whatever `_compute_shap` returns for the training features, or None when
        `smr.type` is not one of "xgboost", "catboost" or "lightgbm".
        NOTE(review): the result is handed to `plot_full_beeswarm`, which is
        annotated to take a `shap.Explanation` -- confirm the exact type.
    """

    if smr.type not in ("xgboost", "catboost", "lightgbm"):
        # SHAP is not supported for this model type
        return None

    X_train = smr.ds.X_train

    shaps = _compute_shap(smr.model, X_train, X_train)

    if plot:
        plot_full_beeswarm(shaps, title=title)

    # Hand the computed values back to the caller; previously they were
    # discarded and the function always returned None.
    return shaps

plot_full_beeswarm

plot_full_beeswarm(explanation, title='SHAP Beeswarm', wrap_width=20)

Plot a full SHAP beeswarm for a tree-based model with wrapped feature names.

This function wraps long feature names, auto-scales figure size to the number of features, and renders a beeswarm plot with rotated, smaller y-axis labels.

Parameters:

Name Type Description Default
explanation Explanation

SHAP Explanation object with values, base_values, data, and feature_names.

required
title str

Title of the plot. Defaults to "SHAP Beeswarm".

'SHAP Beeswarm'
wrap_width int

Maximum character width for feature name wrapping. Defaults to 20.

20
Source code in openavmkit/shap_analysis.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def plot_full_beeswarm(
    explanation: shap.Explanation, title: str = "SHAP Beeswarm", wrap_width: int = 20
) -> None:
    """
    Render a SHAP beeswarm plot that shows every feature, with wrapped labels.

    Feature names are broken into lines of at most `wrap_width` characters, a
    figure is created whose size grows with the number of features (0.3 inches
    per feature, but never smaller than 12 x 6), and the beeswarm is drawn with
    `max_display` set to the feature count. The plot is then titled, the y-axis
    tick labels are set to a small right-aligned unrotated font, and the figure
    is displayed.

    NOTE(review): `shap.plots.beeswarm` is called without `plot_size`; its
    default may resize the current figure and override the size chosen here --
    confirm against the installed shap version.

    Parameters
    ----------
    explanation : shap.Explanation
        SHAP Explanation object; its `values`, `base_values`, `data`, and
        `feature_names` attributes are read.
    title : str, optional
        Title of the plot. Defaults to "SHAP Beeswarm".
    wrap_width : int, optional
        Maximum character width for feature name wrapping. Defaults to 20.
    """

    # Break each feature name into newline-joined chunks of limited width.
    labels = []
    for name in explanation.feature_names:
        pieces = textwrap.wrap(name, width=wrap_width)
        labels.append("\n".join(pieces))

    # Rebuild the explanation so the plot picks up the wrapped labels.
    relabeled = shap.Explanation(
        values=explanation.values,
        base_values=explanation.base_values,
        data=explanation.data,
        feature_names=labels,
    )

    # Scale the canvas with the feature count, subject to a minimum size.
    feature_count = len(labels)
    scaled = 0.3 * feature_count
    figure_size = (max(12, scaled), max(6, scaled))
    _, axes = plt.subplots(figsize=figure_size, constrained_layout=True)

    # show=False defers display so the title and tick styling can be applied first.
    shap.plots.beeswarm(relabeled, max_display=feature_count, show=False)

    axes.set_title(title)
    tick_labels = axes.get_yticklabels()
    plt.setp(tick_labels, rotation=0, ha="right", fontsize=8)

    plt.show()