Skip to content

openavmkit.utilities.plotting

get_nice_random_colors

get_nice_random_colors(n, shuffle=False, seed=1337)

Generate a list of n aesthetically pleasing and perceptually distinct colors for plotting.

Parameters:

Name Type Description Default
n int

Number of colors

required
shuffle bool

Whether to shuffle the color order to make the sequence appear more visually distinct. Default is False.

False
seed int

Random seed for reproducibility.

1337

Returns:

Type Description
list[str]

List of hex color codes

Source code in openavmkit/utilities/plotting.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def get_nice_random_colors(n: int, shuffle=False, seed=1337):
    """Generate a list of n aesthetically pleasing and perceptually distinct colors for
    plotting.

    Hues are produced by repeatedly stepping around the color wheel by the golden
    ratio conjugate, starting from a seeded random hue. Saturation (0.65) and value
    (0.85) are fixed.

    Parameters
    ----------
    n : int
        Number of colors. Values <= 0 yield an empty list.
    shuffle : bool, optional
        Whether to shuffle the color order to make the sequence appear more visually distinct. Default is False.
    seed : int
        Random seed for reproducibility. A private ``random.Random`` instance is
        seeded, so calling this function does not reseed the global ``random``
        module.

    Returns
    -------
    list[str]
        List of ``n`` hex color codes of the form ``"#rrggbb"``.
    """
    golden_ratio_conjugate = 0.61803398875  # For perceptually even hue spacing
    # Fixed saturation and value in a pleasing range (loop-invariant)
    saturation = 0.65  # More muted than full saturation
    value = 0.85  # High value, good for line plots on white background

    # A local RNG yields the same sequence as random.seed(seed) would, without
    # clobbering the caller's global random state.
    rng = random.Random(seed)
    h = rng.random()  # Start with a random hue base

    colors = []
    for _ in range(n):
        # Evenly spaced hue using golden ratio increment
        h = (h + golden_ratio_conjugate) % 1
        # Convert to RGB and then to hex
        r, g, b = colorsys.hsv_to_rgb(h, saturation, value)
        colors.append(f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}")

    if shuffle:
        rng.shuffle(colors)

    return colors

plot_bar

plot_bar(df, data_field, height=1.0, width=1.0, xlabel='', ylabel='', title='', out_file=None, style=None)

Plots a simple bar graph

Parameters:

Name Type Description Default
df DataFrame

Your dataset

required
data_field str

The field you want to graph

required
height float

The height of the bars

1.0
width float

The width of the bars

1.0
xlabel str

Label for the x-axis

''
ylabel str

Label for the y-axis

''
title str

Title for the plot

''
out_file str

If provided, writes the image to this file.

None
style dict

Style dictionary

None
Source code in openavmkit/utilities/plotting.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def plot_bar(
    df: pd.DataFrame,
    data_field: str,
    height=1.0,
    width=1.0,
    xlabel: str = "",
    ylabel: str = "",
    title: str = "",
    out_file: str = None,
    style: dict = None,
):
    """
    Draw a simple bar chart of one column of a DataFrame.

    The column is sorted in ascending order and missing values are dropped
    before plotting. The figure is optionally saved and then shown.

    Parameters
    ----------
    df : pd.DataFrame
        Source dataset.
    data_field : str
        Name of the column to plot.
    height : float
        Height given to every bar.
    width : float
        Width given to every bar.
    xlabel : str
        Text for the x-axis label.
    ylabel : str
        Text for the y-axis label.
    title : str
        Text for the plot title.
    out_file : str
        When not None, the image is written to this path before display.
    style : dict
        Style dictionary forwarded to the color helper.
    """
    # NOTE(review): colors are resolved from the full, unsorted frame. If
    # _get_color_by returns per-row colors they may not line up with the sorted,
    # NaN-filtered values below -- confirm against _get_color_by.
    bar_color = _get_color_by(df, style)

    ordered = df.sort_values(by=data_field, ascending=True)
    values = ordered[data_field].dropna()

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # The values are passed as the bars' x positions; each bar has the same
    # constant `height`.
    plt.bar(values, height=height, width=width, color=bar_color)

    if out_file is not None:
        plt.savefig(out_file)
    plt.show()

plot_histogram_df

plot_histogram_df(df, fields, xlabel='', ylabel='', title='', bins=500, x_lim=None, out_file=None)

Plots an overlaid histogram of one or more sets of values

Parameters:

Name Type Description Default
df DataFrame

Your dataset

required
fields list[str]

Field or fields you want to plot

required
xlabel str

Label for the x-axis

''
ylabel str

Label for the y-axis

''
title str

Title for the plot

''
bins int

How many bins

500
x_lim Any

x-axis limiter

None
out_file str

If provided, writes the image to this file.

None
Source code in openavmkit/utilities/plotting.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def plot_histogram_df(
    df: pd.DataFrame,
    fields: list[str],
    xlabel: str = "",
    ylabel: str = "",
    title: str = "",
    bins=500,
    x_lim=None,
    out_file: str = None,
):
    """
    Plots an overlaid histogram of one or more sets of values

    Each field becomes one semi-transparent (alpha 0.25) entry passed to
    ``_plot_histogram_mult``, labelled with the field name.

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset
    fields : list[str] or str
        Field or fields you want to plot. A single column name may be given as
        a plain string.
    xlabel : str
        Label for the x-axis
    ylabel : str
        Label for the y-axis
    title : str
        Title for the plot
    bins : int, optional
        How many bins
    x_lim : Any, optional
        x-axis limiter, forwarded unchanged to ``_plot_histogram_mult``.
    out_file : str, optional
        If provided, writes the image to this file.
    """
    # Without this, a bare column name would be iterated character by
    # character and each character looked up as a column.
    if isinstance(fields, str):
        fields = [fields]

    entries = [{"data": df[field], "label": field, "alpha": 0.25} for field in fields]
    _plot_histogram_mult(entries, xlabel, ylabel, title, bins, x_lim, out_file)

plot_scatterplot

plot_scatterplot(df, x, y, xlabel=None, ylabel=None, title=None, out_file=None, style=None, best_fit_line=False, perfect_fit_line=False, metadata_field=None)

Scatterplot with inline mpld3 tooltips showing df[metadata_field].

Parameters:

Name Type Description Default
df DataFrame

Your dataset

required
x str

Variable for the x-axis

required
y str

Variable for the y-axis

required
xlabel str

Label for the x-axis

None
ylabel str

Label for the y-axis

None
title str

Title for the plot

None
out_file str

If provided, writes the image to this file.

None
style dict

Style dictionary

None
best_fit_line bool

If True, draws a best fit line.

False
perfect_fit_line bool

If True, draws a y=x line.

False
metadata_field str

If provided, shows tooltips for this field on hover.

None

Returns:

Type Description
Figure

The plot figure

Source code in openavmkit/utilities/plotting.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def plot_scatterplot(
    df,
    x: str,
    y: str,
    xlabel: str = None,
    ylabel: str = None,
    title: str = None,
    out_file: str = None,
    style: dict = None,
    best_fit_line: bool = False,
    perfect_fit_line: bool = False,
    metadata_field: str = None,
):
    """Scatterplot with inline mpld3 tooltips showing df[metadata_field].

    Parameters
    ----------
    df : pd.DataFrame
        Your dataset
    x : str
        Variable for the x-axis
    y : str
        Variable for the y-axis
    xlabel : str, optional
        Label for the x-axis (defaults to ``x``)
    ylabel : str, optional
        Label for the y-axis (defaults to ``y``)
    title : str, optional
        Title for the plot (defaults to ``"{x} vs {y}"``)
    out_file : str, optional
        If provided, writes the image to this file.
    style : dict, optional
        Style dictionary
    best_fit_line : bool, optional
        If True, draws a best fit line from ``_simple_ols(..., intercept=False)``
        (i.e. through the origin) and shows its r² in the legend.
    perfect_fit_line : bool, optional
        If True, draws a y=x line.
    metadata_field : str, optional
        If provided, shows tooltips for this field on hover. The figure is then
        displayed as interactive mpld3 HTML instead of via ``plt.show()``.

    Returns
    -------
    Figure
        The plot figure (closed with ``plt.close`` before returning).

    """
    # 1) Defaults
    xlabel = xlabel or x
    ylabel = ylabel or y
    title = title or f"{x} vs {y}"

    # 2) New figure & axis, with human-readable tick labels
    fig, ax = plt.subplots()

    ax.ticklabel_format(axis='x', style='plain')
    ax.xaxis.set_major_formatter(_human_fmt(digits=3))
    ax.ticklabel_format(axis='y', style='plain')
    ax.yaxis.set_major_formatter(_human_fmt(digits=3))

    # 3) Color/style helper
    color = _get_color_by(df, style)

    # 4) Scatter (left unlabelled so it does not appear in the legend)
    sc = ax.scatter(df[x], df[y], s=4, c=color)

    # Reference lines are straight, so the x-range endpoints are enough; plotting
    # one vertex per (unsorted) row builds a huge back-and-forth polyline.
    if best_fit_line or perfect_fit_line:
        x_ends = [df[x].min(), df[x].max()]

    # 5) Optional best-fit line (through the origin, since intercept=False)
    if best_fit_line:
        results = _simple_ols(df, x, y, intercept=False)
        slope, r2 = results["slope"], results["r2"]
        best_fit_label = f"Best fit line (r²={r2:.2f})"
        ax.plot(
            x_ends,
            [slope * xv for xv in x_ends],
            color="red",
            alpha=0.5,
            label=best_fit_label,
        )

    if perfect_fit_line:
        perfect_fit_label = "Perfect fit line (y=x)"
        # Add a perfect line (y=x)
        ax.plot(x_ends, x_ends, color="blue", alpha=0.5, label=perfect_fit_label)

    # Only draw a legend when a labelled reference line exists; otherwise the
    # legend would just contain the scatter marker with an empty label.
    if best_fit_line or perfect_fit_line:
        ax.legend()

    # 6) Labels & title
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # 7) Save if requested
    if out_file:
        fig.savefig(out_file)

    # 8) Build tooltip labels from the metadata field
    if metadata_field is not None:
        labels = df[metadata_field].astype(str).tolist()
        tooltip = plugins.PointLabelTooltip(sc, labels=labels)
        plugins.connect(fig, tooltip)

        # 9) Display the interactive HTML
        html = mpld3.fig_to_html(fig)
        display(HTML(html))
    else:
        plt.show()

    # The figure has already been displayed above; close it to release resources.
    plt.close(fig)

    return fig