aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
blob: 694837878c400419330b38734e15112dd7bed566 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""TOSA result checker script."""
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
from enum import IntEnum
from enum import unique
from pathlib import Path

import numpy as np
from checker.color_print import LogColors
from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
from schemavalidation.schemavalidation import TestDescSchemaValidator


@unique
class TestResult(IntEnum):
    """Outcome codes reported by the result checker.

    The integer value is also returned from ``main`` as the process exit
    status, so PASS has to stay 0 for the command line to report success.
    """

    PASS = 0
    MISSING_FILE = 1
    INCORRECT_FORMAT = 2
    MISMATCH = 3
    INTERNAL_ERROR = 4


# Human readable description for each TestResult member
_RESULT_DESCRIPTIONS = {
    TestResult.PASS: "",
    TestResult.MISSING_FILE: "Missing file",
    TestResult.INCORRECT_FORMAT: "Incorrect format",
    TestResult.MISMATCH: "Mismatch",
    TestResult.INTERNAL_ERROR: "Internal error",
}
# List form indexable by the integer result value (built in enum order so the
# two definitions cannot drift apart)
TestResultErrorStr = [_RESULT_DESCRIPTIONS[outcome] for outcome in TestResult]
##################################

# Absolute tolerance used by the legacy float32/float16 comparison
DEFAULT_FP_TOLERANCE = 1e-3
# Module-wide switch read by _print_result; changed via set_print_result
result_printing = True


def set_print_result(enabled):
    """Turn the coloured result messages of this module on or off.

    Args:
        enabled: truthy to print result messages, falsy to stay silent.
    """
    globals()["result_printing"] = enabled


def _print_result(color, msg):
    """Print msg in the given colour unless result printing is disabled."""
    if not result_printing:
        return
    print_color(color, msg)


def compliance_check(
    imp_result_path,
    ref_result_path,
    bnd_result_path,
    test_name,
    compliance_config,
    ofm_name,
    verify_lib_path,
):
    """Check results for compliance using the TOSA verifier library.

    Note: despite their names, ``test_check`` in this file passes the data
    returned by ``np.load`` (and ``None`` when there is no bounds result) as
    the three ``*_result_path`` arguments; they are forwarded unchanged to
    ``VerifierLibrary.verify_data``.

    Args:
        imp_result_path: Implementation result data.
        ref_result_path: Reference model result data.
        bnd_result_path: Bounds result data, may be None.
        test_name: Name used in printed messages.
        compliance_config: The ``meta.compliance`` section of the test desc.
        ofm_name: Name of the output tensor being checked.
        verify_lib_path: Path to the verifier library; None is reported as an
            internal error.

    Returns:
        ``(TestResult, 0.0, message)`` - message is empty on success.
    """
    vlib = None
    error = None
    if verify_lib_path is None:
        error = "Please supply --verify-lib-path"
    else:
        try:
            vlib = VerifierLibrary(verify_lib_path)
        except VerifierError as e:
            error = str(e)

    if error is not None:
        _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
        msg = f"Could not load verifier library: {error}"
        return (TestResult.INTERNAL_ERROR, 0.0, msg)

    success = vlib.verify_data(
        ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
    )
    if success:
        _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
        return (TestResult.PASS, 0.0, "")

    _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
    return (TestResult.MISMATCH, 0.0, "Non-compliance results found")


def test_check(
    ref_result_path,
    imp_result_path,
    test_name=None,
    quantize_tolerance=0,
    float_tolerance=DEFAULT_FP_TOLERANCE,
    misc_checks=None,
    test_desc=None,
    bnd_result_path=None,
    ofm_name=None,
    verify_lib_path=None,
):
    """Check if the result is the same as the expected reference.

    The implementation, reference and (optional) bounds results are loaded
    with ``np.load``. When ``test_desc`` carries ``meta.compliance`` data the
    check is delegated to ``compliance_check``; otherwise the legacy
    element-wise comparison below is used.

    Args:
        ref_result_path: Path (str or pathlib.Path) of the reference result.
        imp_result_path: Path (str or pathlib.Path) of the implementation
            result.
        test_name: Name used in printed messages, defaults to "test".
        quantize_tolerance: Maximum absolute difference allowed for integer
            results.
        float_tolerance: ``atol`` passed to ``np.allclose`` for float16 and
            float32 results.
        misc_checks: Optional list of extra checks; "bf16" requires every
            value to be a valid bfloat16. None means no extra checks.
        test_desc: Optional parsed test description, validated against the
            test description schema.
        bnd_result_path: Optional path of the bounds result.
        ofm_name: Output tensor to check in compliance mode; defaults to the
            single ``ofm_name`` listed in ``test_desc``.
        verify_lib_path: Path to the verifier library (compliance mode only).

    Returns:
        A tuple ``(TestResult, tolerance, message)``; message is empty on
        success and describes the failure otherwise.
    """
    if misc_checks is None:
        misc_checks = []
    # Set the default name first so every message below can use it
    if test_name is None:
        test_name = "test"

    if test_desc:
        # New compliance method - first get test details
        try:
            TestDescSchemaValidator().validate_config(test_desc)
        except Exception as e:
            _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
            msg = f"Incorrect test format: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    paths = [imp_result_path, ref_result_path, bnd_result_path]
    names = ["Implementation", "Reference", "Bounds"]
    arrays = [None, None, None]

    # Check the files exist and are in the right format
    for idx, (name, path) in enumerate(zip(names, paths)):
        if path is None and name == "Bounds":
            # Bounds can be None - skip it
            continue
        # A missing required path is reported like a missing file
        if path is None or not Path(path).is_file():
            _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
            msg = f"Missing {name} file: {str(path)}"
            return (TestResult.MISSING_FILE, 0.0, msg)
        try:
            arrays[idx] = np.load(path)
        except Exception as e:
            _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
            msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
        # Switch to using the verifier library for full compliance
        if ofm_name is None:
            ofm_name = test_desc["ofm_name"][0]
            if len(test_desc["ofm_name"]) > 1:
                _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
                msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
                return (TestResult.MISSING_FILE, 0.0, msg)

        compliance_json = test_desc["meta"]["compliance"]

        return compliance_check(
            *arrays,
            test_name,
            compliance_json,
            ofm_name,
            verify_lib_path,
        )

    # Else continue with original checking method
    test_result, reference_result, _ = arrays

    # Type comparison
    if test_result.dtype != reference_result.dtype:
        _print_result(LogColors.RED, f"Results TYPE MISMATCH {test_name}")
        msg = "Mismatch results type: Expected {}, got {}".format(
            reference_result.dtype, test_result.dtype
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Size comparison
    # Size = 1 tensors can be equivalently represented as having rank 0 or rank
    # >= 0, allow that special case
    test_result = np.squeeze(test_result)
    reference_result = np.squeeze(reference_result)
    difference = None

    if np.shape(test_result) != np.shape(reference_result):
        _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")
        msg = "Shapes mismatch: Reference {} vs {}".format(
            np.shape(reference_result), np.shape(test_result)
        )
        return (TestResult.MISMATCH, 0.0, msg)

    # Perform miscellaneous checks
    if "bf16" in misc_checks:
        # Ensure floats are valid bfloat16 values
        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
        ref_res_is_bf16 = all(
            float32_is_valid_bfloat16(f) for f in reference_result.flat
        )
        if not (test_res_is_bf16 and ref_res_is_bf16):
            _print_result(LogColors.RED, f"Results INCORRECT FORMAT {test_name}")
            msg = (
                "All output values must be valid bfloat16. "
                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
            )
            return (TestResult.INCORRECT_FORMAT, 0.0, msg)

    # for quantized test, allow +-(quantize_tolerance) error
    if reference_result.dtype in (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
    ):
        # Widen before subtracting: unsigned or narrow integer arithmetic
        # wraps around (e.g. uint8 1 - 2 == 255) and would report bogus
        # differences
        difference = reference_result.astype(np.int64) - test_result.astype(
            np.int64
        )
        max_difference = (
            int(np.amax(np.absolute(difference))) if difference.size else 0
        )
        if max_difference <= quantize_tolerance:
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")

        # Report the tolerance that would have been needed to pass
        tolerance = max_difference
        if tolerance > 10:
            msg = "Integer result does not match and is greater than 10 difference"
        else:
            msg = "Integer result does not match but is within {} difference".format(
                tolerance
            )
        # Fall-through to below to add failure values

    elif reference_result.dtype == bool:
        # All boolean values must match exactly (dtypes already checked equal)
        if np.array_equal(reference_result, test_result):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, 0.0, "")
        msg = "Boolean result does not match"
        tolerance = 0.0
        difference = None
        # Fall-through to below to add failure values

    # TODO: update for fp16 tolerance
    elif reference_result.dtype in (np.float32, np.float16):
        tolerance = float_tolerance
        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
            _print_result(LogColors.GREEN, f"Results PASS {test_name}")
            return (TestResult.PASS, tolerance, "")
        msg = "Float result does not match within tolerance of {}".format(tolerance)
        difference = reference_result - test_result
        # Fall-through to below to add failure values
    else:
        _print_result(LogColors.RED, f"Results UNSUPPORTED TYPE {test_name}")
        msg = "Unsupported results type: {}".format(reference_result.dtype)
        return (TestResult.MISMATCH, 0.0, msg)

    # Fall-through for mismatch failure to add values to msg
    _print_result(LogColors.RED, f"Results MISCOMPARE {test_name}")

    # Abbreviate large tensors without changing numpy's global print settings
    with np.printoptions(threshold=128, edgeitems=2):
        if difference is not None:
            tolerance_needed = np.amax(np.absolute(difference))
            msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)

        msg = "{}\n>> reference_result: {}\n{}".format(
            msg, reference_result.shape, reference_result
        )
        msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)

        if difference is not None:
            msg = "{}\n!! difference_result: \n{}".format(msg, difference)

    return (TestResult.MISMATCH, tolerance, msg)


def main(argv=None):
    """Check that the supplied reference and result files have the same contents.

    Args:
        argv: Optional list of command line arguments; when None argparse
            reads them from the process command line.

    Returns:
        The TestResult of the check (used as the process exit status).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "ref_result_path",
        type=Path,
        help="path to the reference model result file to check",
    )
    parser.add_argument(
        "imp_result_path",
        type=Path,
        help="path to the implementation result file to check",
    )
    parser.add_argument(
        "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
    )
    parser.add_argument(
        "--test-path", type=Path, help="path to the test that produced the results"
    )
    # Deprecate the incorrectly formatted option by hiding it
    parser.add_argument("--test_path", type=Path, help=argparse.SUPPRESS)
    parser.add_argument(
        "--bnd-result-path",
        type=Path,
        help="path to the reference model bounds result file for the dot product compliance check",
    )
    parser.add_argument(
        "--ofm-name",
        type=str,
        help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
    )
    parser.add_argument(
        "--verify-lib-path",
        type=Path,
        help="path to TOSA verify library",
    )
    args = parser.parse_args(argv)

    test_desc = None
    test_name = None
    if args.test_path:
        # Get details from the test path
        test_desc_path = args.test_path / "desc.json"
        if not args.test_path.is_dir() or not test_desc_path.is_file():
            print(f"Invalid test directory {str(args.test_path)}")
            return TestResult.MISSING_FILE

        try:
            # JSON is UTF-8; do not depend on the platform's locale encoding
            with test_desc_path.open("r", encoding="utf-8") as fd:
                test_desc = json.load(fd)
        except Exception as e:
            print(f"Invalid test description file {str(test_desc_path)}: {e}")
            return TestResult.INCORRECT_FORMAT
        test_name = args.test_path.name

    result, _, msg = test_check(
        args.ref_result_path,
        args.imp_result_path,
        float_tolerance=args.fp_tolerance,
        test_name=test_name,
        test_desc=test_desc,
        bnd_result_path=args.bnd_result_path,
        ofm_name=args.ofm_name,
        verify_lib_path=args.verify_lib_path,
    )
    if result != TestResult.PASS:
        print(msg)

    return result


if __name__ == "__main__":
    # sys.exit is always available; the bare exit() builtin is only injected
    # by the site module and may be missing (e.g. python -S)
    import sys

    sys.exit(main())