Skip to content

openavmkit.utilities.assertions

dfs_are_equal

dfs_are_equal(a, b, primary_key=None, allow_weak=False)

Test whether two DataFrames are equal

Parameters:

Name Type Description Default
a DataFrame

A DataFrame

required
b DataFrame

Another DataFrame

required
primary_key str

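The name of the primary key column, present in both DataFrames; both are sorted by it and, if their indices differ, re-indexed on it before comparison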

None
allow_weak bool

Whether to ignore trivial differences (such as nominally different types for columns with otherwise identical values)

False

Returns:

Type Description
bool

Whether the two DataFrames are equal or not

Source code in openavmkit/utilities/assertions.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def dfs_are_equal(a: pd.DataFrame, b: pd.DataFrame, primary_key=None, allow_weak=False):
    """Test whether two DataFrames are equal

    Both inputs are copied, so the caller's DataFrames are never modified.
    Diagnostics describing the first mismatch found are printed to stdout.

    Parameters
    ----------
    a : pd.DataFrame
        A DataFrame
    b : pd.DataFrame
        Another DataFrame
    primary_key : str
        The name of the primary key column, present in both DataFrames. When
        given, both DataFrames are sorted by it, and it is used as the index
        if the original indices differ or a column fails to match on them.
    allow_weak : bool
        Whether to ignore trivial differences (such as nominally different types for columns with otherwise identical values)

    Returns
    -------
    bool
        Whether the two DataFrames are equal or not
    """
    a = a.copy()
    b = b.copy()

    if primary_key is not None:
        a = a.sort_values(by=primary_key)
        b = b.sort_values(by=primary_key)

    # sort column names so they're in the same order:
    a = a.reindex(sorted(a.columns), axis=1)
    b = b.reindex(sorted(b.columns), axis=1)

    # ensure that the two dataframes contain the same information:
    if not a.columns.equals(b.columns):
        print(f"Columns do not match:\nA={a.columns}\nB={b.columns}")
        return False

    # Tracks whether both frames are already indexed on the primary key, so that
    # set_index is never called a second time (the key column no longer exists
    # after the first call, and a repeat call would raise KeyError).
    indexed_on_key = False

    a_sorted_index = a.index.sort_values()
    b_sorted_index = b.index.sort_values()

    if not a_sorted_index.equals(b_sorted_index):
        if primary_key is None:
            print("Indices do not match")
            print(a_sorted_index)
            print("VS")
            print(b_sorted_index)
            return False

        a_not_in_b = a[~a[primary_key].isin(b[primary_key])][primary_key].values
        b_not_in_a = b[~b[primary_key].isin(a[primary_key])][primary_key].values
        if len(a_not_in_b) > 0:
            print(f"{len(a_not_in_b)} keys in A not in B: {a_not_in_b}")
            return False
        if len(b_not_in_a) > 0:
            print(f"{len(b_not_in_a)} keys in B not in A: {b_not_in_a}")
            print(f"len(a) = {len(a)} VS len(b) = {len(b)}")
            return False

        # we're going to reindex on primary key and proceed with comparisons on that basis
        a = a.set_index(primary_key)
        b = b.set_index(primary_key)
        indexed_on_key = True

    # NOTE: a.columns is evaluated once here; re-indexing inside the loop does not
    # disturb the iteration, and the primary key column is skipped explicitly.
    for col in a.columns:
        if col == primary_key:
            # skip the primary key column:
            continue

        if series_are_equal(a[col], b[col]):
            continue

        # try again using primary key as the index (only if not done already):
        if primary_key is not None and not indexed_on_key:
            a = a.set_index(primary_key)
            b = b.set_index(primary_key)
            indexed_on_key = True
            if series_are_equal(a[col], b[col]):
                continue

        weak_fail = False

        # option_context restores the display option even if printing raises
        with pd.option_context("display.max_columns", None):
            mismatch = ~a[col].eq(b[col])
            bad_rows_a = a[mismatch]
            bad_rows_b = b[mismatch]

            if len(bad_rows_a) == 0 and len(bad_rows_b) == 0:
                # element-wise equality holds, so the failure is "weak"
                # (e.g. nominally different dtypes)
                weak_fail = True
                if not allow_weak:
                    print(
                        f"Column '{col}' does not match even though rows are naively equal, look:"
                    )
                    print(a[col])
                    print(b[col])
            else:
                print(f"Column '{col}' does not match, look:")
                # print rows that are not equal:
                print(bad_rows_a[col])
                print(bad_rows_b[col])

        if weak_fail and allow_weak:
            continue

        print(f"Column '{col}' does not match for some reason.")
        return False

    return True

dicts_are_equal

dicts_are_equal(a, b)

Test whether two dictionaries are equal

Parameters:

Name Type Description Default
a dict

A dictionary

required
b dict

Another dictionary

required

Returns:

Type Description
bool

Whether the two dictionaries are equal or not

Source code in openavmkit/utilities/assertions.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def dicts_are_equal(a: dict, b: dict):
    """Test whether two dictionaries are equal

    The dictionaries match when they have the same number of entries, every
    key of ``a`` is also a key of ``b``, and each pair of corresponding values
    is accepted by ``objects_are_equal``.

    Parameters
    ----------
    a : dict
        A dictionary
    b : dict
        Another dictionary

    Returns
    -------
    bool
        Whether the two dictionaries are equal or not
    """
    # different sizes can never hold the same information
    if len(a) != len(b):
        return False

    # all() stops at the first missing key or mismatching value
    return all(
        key in b and objects_are_equal(value, b[key])
        for key, value in a.items()
    )

lists_are_equal

lists_are_equal(a, b)

Test whether two lists are equal

Parameters:

Name Type Description Default
a list

A list

required
b list

Another list

required

Returns:

Type Description
bool

Whether the two lists are equal or not

Source code in openavmkit/utilities/assertions.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def lists_are_equal(a: list, b: list):
    """Test whether two lists are equal

    The lists match when they have the same length and every pair of elements
    at the same position is accepted by ``objects_are_equal``. On a mismatch
    both lists are printed to stdout for debugging.

    Parameters
    ----------
    a : list
        A list
    b : list
        Another list

    Returns
    -------
    bool
        Whether the two lists are equal or not
    """
    # Every pair must match. all() stops at the first mismatch, so an early
    # difference can no longer be masked by a later matching pair.
    matched = len(a) == len(b) and all(
        objects_are_equal(entry_a, entry_b) for entry_a, entry_b in zip(a, b)
    )
    if not matched:
        # print both lists for debugging:
        print(a)
        print(b)
        return False
    return True

objects_are_equal

objects_are_equal(a, b, epsilon=1e-06)

Test whether two objects are equal

Checks strings, dicts, lists, ints, floats, and objects

Parameters:

Name Type Description Default
a Any

A value of any type

required
b Any

Another value of any type

required
epsilon float

If the values are both floats, the maximum allowed tolerance

1e-06

Returns:

Type Description
bool

Whether the two objects are equal or not

Source code in openavmkit/utilities/assertions.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def objects_are_equal(a, b, epsilon: float = 1e-6):
    """Test whether two objects are equal

    Checks strings, dicts, lists, ints, floats, and objects:

    - two strings are compared exactly;
    - two dicts / two lists are compared recursively via ``dicts_are_equal`` /
      ``lists_are_equal``;
    - ``None`` only equals ``None``;
    - two numbers are compared within ``epsilon``, and two NaNs count as equal;
    - anything else must have the same type and compare equal with ``==``.

    Parameters
    ----------
    a : Any
        A value of any type
    b : Any
        Another value of any type
    epsilon : float
        If the values are both numbers, the maximum allowed tolerance

    Returns
    -------
    bool
        Whether the two objects are equal or not
    """
    # Local import keeps this block self-contained; `numbers` is standard library.
    import numbers

    if isinstance(a, str) and isinstance(b, str):
        return a == b

    if isinstance(a, dict) and isinstance(b, dict):
        return dicts_are_equal(a, b)

    if isinstance(a, list) and isinstance(b, list):
        return lists_are_equal(a, b)

    if a is None and b is None:
        return True
    if a is None or b is None:
        return False

    # numbers.Number covers Python ints/floats/bools and numpy numeric scalars.
    # Unlike np.isreal it never returns an array, so tuples, arrays and other
    # non-numeric objects fall through to the plain equality check below
    # instead of raising on an ambiguous truth value.
    if isinstance(a, numbers.Number) and isinstance(b, numbers.Number):
        float_types = (float, np.floating)
        a_is_nan = isinstance(a, float_types) and bool(np.isnan(a))
        b_is_nan = isinstance(b, float_types) and bool(np.isnan(b))

        if a_is_nan and b_is_nan:
            return True
        if a_is_nan or b_is_nan:
            return False

        # compare numbers with epsilon:
        return bool(abs(a - b) < epsilon)

    # ensure types are the same:
    if type(a) is not type(b):
        return False
    return a == b

series_are_equal

series_are_equal(a, b)

Test whether two series are equal

Parameters:

Name Type Description Default
a Series

A series

required
b Series

Another series

required

Returns:

Type Description
bool

Whether the two series are equal or not

Source code in openavmkit/utilities/assertions.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def series_are_equal(a: pd.Series, b: pd.Series):
    """Test whether two series are equal

    Float series are compared with a tolerance of 1e-6 and must have their
    missing values in the same places; integer series must match exactly.
    Other series are compared with ``Series.equals`` first and, failing that,
    by comparing their missing-value masks and non-missing values separately
    (datetimes directly, everything else after casting to float).
    Diagnostics describing a mismatch are printed to stdout. The inputs are
    never modified.

    Parameters
    ----------
    a : pd.Series
        A series
    b : pd.Series
        Another series

    Returns
    -------
    bool
        Whether the two series are equal or not
    """
    # Two empty series hold the same (no) information. Without this guard the
    # numeric branches below take max() of an empty series, which is NaN, and
    # report a mismatch.
    if len(a) == 0 and len(b) == 0:
        return True

    # deal with 32-bit vs 64-bit type nonsense:
    a_type_name = str(a.dtype).lower()
    b_type_name = str(b.dtype).lower()

    if "float" in a_type_name and "float" in b_type_name:
        a_nan = a.isna()
        b_nan = b.isna()

        if a_nan.sum() != b_nan.sum():
            print(
                f"Number of NaN values do not match: a={a_nan.sum()} b={b_nan.sum()}"
            )
            return False

        # Equal counts are not enough: [NaN, 0.0] and [0.0, NaN] would pass the
        # zero-filled comparison below. Compare the masks on the shared labels.
        a_nan_aligned, b_nan_aligned = a_nan.align(b_nan, join="inner")
        if not a_nan_aligned.equals(b_nan_aligned):
            print("NaN values are in different places")
            return False

        # compare floats with epsilon:
        delta = a.fillna(0).subtract(b.fillna(0)).abs().max()
        result = bool(delta < 1e-6)
        if not result:
            print(f"Comparing floats with epsilon:\n{delta}")
        return result

    if "int" in a_type_name and "int" in b_type_name:
        # compare integers directly:
        delta = a.subtract(b).abs().max()
        result = bool(delta == 0)
        if not result:
            print(f"Comparing integers directly:\n{delta}")
        return result

    # ensure that the two series contain the same information:
    if a.equals(b):
        return True

    # Work on copies so the normalisation below never leaks into the caller's data.
    a = a.copy()
    b = b.copy()

    # Check for "NONE" values in either one and replace with NaN:
    a.loc[a.isna()] = np.nan
    b.loc[b.isna()] = np.nan

    # check which values are NaN:
    a_is_nan = a.isna() | a.eq(None)
    b_is_nan = b.isna() | b.eq(None)

    # mask out the NaN values and see if those sections are equal:
    a_masked = a[~a_is_nan]
    b_masked = b[~b_is_nan]

    if a_masked.equals(b_masked):
        # the masked values are equal, so the two series are equal except for the NaN values
        # check we have an equal number of NaN values:
        if a_is_nan.sum() != b_is_nan.sum():
            print(
                f"Number of NaN values do not match: a={a_is_nan.sum()} b={b_is_nan.sum()}"
            )
            return False
        # now check if the NaN values are in the same places:
        if a_is_nan.equals(b_is_nan):
            return True
        print("NaN values are in different places")
        return False

    # if both are datetimes:
    if pd.api.types.is_datetime64_any_dtype(
        a_masked
    ) and pd.api.types.is_datetime64_any_dtype(b_masked):
        # compare datetimes directly:
        delta = a_masked.subtract(b_masked).abs().max()
        result = bool(delta == pd.Timedelta(0))
        if not result:
            print(f"Comparing datetimes directly:\n{delta}")
        return result

    # attempt to cast both as floats and compare:
    try:
        delta = a_masked.astype(float).subtract(b_masked.astype(float)).abs().max()
    except (ValueError, TypeError):
        # TypeError covers values float() cannot accept at all
        print("Masked values are not equal and cannot be cast to float")
        return False

    if delta < 1e-6:
        return True
    print(f"Masked values are not equal, max delta = {delta}")
    return False