aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
blob: 1169a9554e3e7af9f75efca4f5ac1a2d3faa7faf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""TOSA result checker script."""
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
from enum import Enum
from enum import IntEnum
from enum import unique
from pathlib import Path

import numpy as np
from generator.tosa_utils import float32_is_valid_bfloat16

##################################
# Module-wide switch consulted by print_color(); toggle it with
# set_print_in_color() rather than assigning it directly.
color_printing: bool = True


@unique
class LogColors(Enum):
    """ANSI terminal escape sequences used to colour status messages.

    ``NONE`` resets all attributes; it is appended after a coloured message
    so the terminal returns to its normal style.
    """

    NONE = "\x1b[0m"  # reset all attributes
    GREEN = "\x1b[32;1m"  # bold green
    RED = "\x1b[31;1m"  # bold red
    YELLOW = "\x1b[33;1m"  # bold yellow
    BOLD_WHITE = "\x1b[1m"  # bold, default colour


def set_print_in_color(enabled):
    """Turn coloured output of print_color() on or off.

    Args:
        enabled: Truthy to wrap messages in ANSI colour codes, falsy to
            print them as plain text.
    """
    # Rebinds the module-level flag, so the declaration is required here.
    global color_printing
    color_printing = enabled


def print_color(color, msg):
    """Print a status message, coloured when colour printing is enabled.

    Args:
        color: A ``LogColors`` member whose escape sequence prefixes ``msg``.
        msg: The message to print.
    """
    # Only reads the module flag, so no ``global`` declaration is needed.
    if color_printing:
        print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
    else:
        print(msg)


@unique
class TestResult(IntEnum):
    """Outcome codes returned by the result checker.

    PASS must stay 0 because main() returns the value as the process exit
    status, and 0 signals success to the shell.
    """

    PASS = 0  # results match the reference
    MISSING_FILE = 1  # reference or result file not found
    INCORRECT_FORMAT = 2  # file could not be loaded / invalid values
    MISMATCH = 3  # type, shape or value mismatch
    INTERNAL_ERROR = 4


# Human-readable text for each TestResult; entries follow the TestResult
# values 0-4 in order (PASS has no error text).
TestResultErrorStr = ["", "Missing file", "Incorrect format", "Mismatch", "Internal error"]
##################################

# Default absolute tolerance (numpy.allclose ``atol``) for float results.
DEFAULT_FP_TOLERANCE = 0.001


def test_check(
    reference_path,
    result_path,
    test_name="test",
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=None,
):
    """Check if the result is the same as the expected reference.

    Both files are loaded with ``numpy.load``. They are compared by dtype,
    then by shape (after ``numpy.squeeze``), then by value:

    * int32/int64 - every element must be within ``quantize_tolerance``.
    * bool - every element must match exactly.
    * float32/float16 - ``numpy.allclose`` with ``atol=float_tolerance``
      (numpy's default ``rtol``), treating NaNs in the same place as equal.

    Any other dtype is reported as a mismatch.

    Args:
        reference_path: Path object (e.g. ``pathlib.Path``) of the expected
            results file.
        result_path: Path object (e.g. ``pathlib.Path``) of the results file
            under test.
        test_name: Name used in the coloured status line.
        quantize_tolerance: Largest absolute difference accepted for integer
            results.
        float_tolerance: Absolute tolerance for floating point results.
        misc_checks: Optional collection of extra check names. ``"bf16"``
            requires every value in both tensors to be a valid bfloat16.

    Returns:
        A ``(TestResult, tolerance, message)`` tuple. ``message`` is empty on
        success; on failure it describes the problem and, for value
        mismatches, includes the tensors.
    """
    # Avoid a shared mutable default argument.
    if misc_checks is None:
        misc_checks = []

    if not reference_path.is_file():
        print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
        msg = "Missing reference file: {}".format(reference_path)
        return (TestResult.MISSING_FILE, 0.0, msg)
    if not result_path.is_file():
        print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
        msg = "Missing result file: {}".format(result_path)
        return (TestResult.MISSING_FILE, 0.0, msg)

    try:
        test_result = np.load(result_path)
    except Exception as e:
        print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            result_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)
    try:
        reference_result = np.load(reference_path)
    except Exception as e:
        print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
        msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
            reference_path, e
        )
        return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size-1 dimensions are removed from both tensors before comparing, so
    # e.g. a single value stored with rank 0 or rank >= 1 compares equal.
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
            msg = (
                "All output values must be valid bfloat16. "
                "reference_result: {}; test_result: {}"
            ).format(ref_res_is_bf16, test_res_is_bf16)
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype in (np.int32, np.int64):
        # Widen to int64 before subtracting. For int32 inputs this stops the
        # subtraction from overflowing. It also stops np.absolute from
        # overflowing: np.absolute(INT32_MIN) is still negative in int32,
        # which would let a real mismatch pass the <= check. (For int64
        # inputs this is a no-op.)
        difference = reference_result.astype(np.int64) - test_result.astype(
            np.int64
        )
        # initial=0 means an empty tensor needs no tolerance at all.
        tolerance = int(np.amax(np.absolute(difference), initial=0))
        if tolerance <= quantize_tolerance:
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")

        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # dtypes were already verified equal above; all values must match.
        if np.array_equal(reference_result, test_result):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        # Fall-through to below to add failure values (no difference tensor)

    # TODO: update for fp16 tolerance
    elif reference_result.dtype in (np.float32, np.float16):
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
    # Abbreviate large tensors in the message. The options are scoped with a
    # context manager so the caller's global numpy print options stay as-is.
    with np.printoptions(threshold=128, edgeitems=2):
        if difference is not None:
            tolerance_needed = np.amax(np.absolute(difference))
            msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

        msg = "{}\n>> reference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
        msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

        if difference is not None:
            msg = "{}\n!! difference_result: \n{}".format(msg, difference)
    return (TestResult.MISMATCH, tolerance, msg)


def main(argv=None):
    """Compare a reference file and a result file from the command line.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use the
            process command line.

    Returns:
        The ``TestResult`` of the comparison. ``TestResult.PASS`` is 0, so the
        value can be used directly as the process exit status.
    """
    parser = argparse.ArgumentParser()
    positional_args = (
        ("reference_path", "the path to the reference file to test"),
        ("result_path", "the path to the result file to test"),
    )
    for arg_name, arg_help in positional_args:
        parser.add_argument(arg_name, type=Path, help=arg_help)
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    options = parser.parse_args(argv)

    outcome, _tolerance, details = test_check(
        options.reference_path,
        options.result_path,
        float_tolerance=options.fp_tolerance,
    )
    # Details are only meaningful (non-empty) on failure.
    if outcome != TestResult.PASS:
        print(details)
    return outcome


if __name__ == "__main__":
    # Use sys.exit rather than the site-provided exit() builtin, which is
    # meant for the interactive shell and is missing when Python runs with
    # -S. main() returns a TestResult, where PASS (0) is success.
    import sys

    sys.exit(main())