diff options
author | evacha01 <evan.chandler@arm.com> | 2024-03-08 16:39:24 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-04-16 16:02:16 +0000 |
commit | 4a2051146f498cb9ec35d7213720540c5c3e81e2 (patch) | |
tree | 543000b3ef22bd587c3c7702100742e4b94eb5fb /verif | |
parent | 5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (diff) | |
download | reference_model-4a2051146f498cb9ec35d7213720540c5c3e81e2.tar.gz |
SPECIAL data gen mode for FP16 and FP32
Signed-off-by: evacha01 <evan.chandler@arm.com>
Change-Id: I5a9a1c63345bd83ca04bc6c2a99b0ef3612971ee
Diffstat (limited to 'verif')
-rw-r--r-- | verif/conformance/test_select.py | 6 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 35 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 18 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 2 | ||||
-rw-r--r-- | verif/tests/test_tosa_refmodel.py | 14 |
5 files changed, 51 insertions, 24 deletions
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py index e3a8ffb..e3f1738 100644 --- a/verif/conformance/test_select.py +++ b/verif/conformance/test_select.py @@ -259,9 +259,9 @@ class Operator: negative and "ERRORIF" in str(path) ): # Check for test set paths - match = re.match(r"(.*)_(s[0-9]+|full)", path.name) + match = re.match(r"(.*)_(s[0-9]+|full|fs)", path.name) if match: - if match.group(2) in ["s0", "full"]: + if match.group(2) in ["s0", "full", "fs"]: # Only return the truncated test name # of the first test of a set, and for full tests yield path.with_name(match.group(1)) @@ -317,7 +317,7 @@ class Operator: def _get_extra_test_paths(path): """Expand a path to find extra tests.""" paths = [] - for suffix in ["full"]: + for suffix in ["full", "fs"]: suffix_path = path.with_name(f"{path.name}_{suffix}") if suffix_path.exists(): paths.append(suffix_path) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 8d6c8d7..5957a33 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -264,6 +264,9 @@ class TosaTensorGen: return [[]] * num_shapes shape = testGen.makeShape(rng, rank) + # Do not broadcast for some tests + if error_name is None and rng.randInt(high=100) < 10: + return [shape] * num_shapes shape_list = [] # Choose any one of the inputs to broadcast @@ -785,6 +788,10 @@ class TosaTensorValuesGen: "tensors": {}, } dg_tens_meta = tens_data["tensors"] + + fp_special_info = {} + fp_special_info["start_idx"] = int(rng.randInt()) + for idx, shape in enumerate(shapeList): tens_meta = {} @@ -858,6 +865,8 @@ class TosaTensorValuesGen: rng.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"]) ) tens_meta["full_range_info"] = info + elif dg_type == gtu.DataGenType.FP_SPECIAL: + tens_meta["fp_special_info"] = fp_special_info else: # TODO - other data gen type assert False, "TODO: support other data gen types" @@ -1862,16 +1871,12 @@ class TosaArgGen: for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: gen_args_dict = args_dict.copy() + # Only create one test by default - no sets of tests + num_test_sets = 0 + if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: - num_test_sets = ( - args_dict["num_test_sets"] - if "num_test_sets" in args_dict - else 0 - ) - else: - # Add single test for pseudo random - num_test_sets = 0 + num_test_sets = args_dict.get("num_test_sets", 0) elif dg_type == gtu.DataGenType.DOT_PRODUCT: # Extra tests for each dot product test set @@ -1900,13 +1905,23 @@ class TosaArgGen: f"Skipping {opName}{shape_info} as tensor data size too small for full range of values {tensor_size} < {gtu.DTYPE_ATTRIBUTES[dtype]['fullset']}" ) continue - # Large enough tensor data size for full range, add a single test - num_test_sets = 0 + # Large enough tensor data size for full range, add full test arg_str = f"{arg_str}_full" if arg_str else "full" gen_args_dict["tags"] = args_dict.get("tags", []) + [ "non_finite_fp_data" ] + elif dg_type == gtu.DataGenType.FP_SPECIAL: + shapes_set = {tuple(x) for x in shapeList} + if len(shapes_set) != 1: + logger.info( + f"Changing {opName} input shapes {shapes_set} - broadcasting incompatable with special test" + ) + shapeList = [np.int32(np.broadcast_shapes(*shapeList))] * len( + shapeList + ) + arg_str = f"{arg_str}_fs" if arg_str else "fs" + gen_args_dict["dg_type"] = dg_type if num_test_sets > 0: for s in range(0, num_test_sets): diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 38ab3f4..40788a2 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -268,7 +268,7 @@ class TosaTestGen: if "ksb" in argsDict else int(argsDict["ks"]), } - elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL: + elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: mode = gtu.ComplianceMode.FP_SPECIAL elif "compliance" in op and "ulp" in op["compliance"]: mode = gtu.ComplianceMode.ULP @@ -3352,7 +3352,11 @@ class TosaTestGen: DType.FP32: (gtu.DataGenType.DOT_PRODUCT,), } EW_UNARY_DATAGEN = { - DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE) + DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FULL_RANGE), + } + PR_FS_DATAGEN = { + DType.FP16: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL), + DType.FP32: (gtu.DataGenType.PSEUDO_RANDOM, gtu.DataGenType.FP_SPECIAL), } TOSA_OP_LIST = { @@ -3716,7 +3720,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": PSEUDO_RANDOM_DATAGEN, + "data_gen": PR_FS_DATAGEN, "compliance": {"ulp": 0.5}, }, "arithmetic_right_shift": { @@ -3938,7 +3942,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": PSEUDO_RANDOM_DATAGEN, + "data_gen": PR_FS_DATAGEN, }, "minimum": { "op": Op.MINIMUM, @@ -4330,7 +4334,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": PSEUDO_RANDOM_DATAGEN, + "data_gen": PR_FS_DATAGEN, }, "greater_equal": { "op": Op.GREATER_EQUAL, @@ -4351,7 +4355,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": PSEUDO_RANDOM_DATAGEN, + "data_gen": PR_FS_DATAGEN, }, "greater": { "op": Op.GREATER, @@ -4372,7 +4376,7 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), - "data_gen": PSEUDO_RANDOM_DATAGEN, + "data_gen": PR_FS_DATAGEN, }, # Reduction operators "reduce_all": { diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index a8e321e..478190d 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -55,7 +55,7 @@ class DataGenType(IntEnum): DOT_PRODUCT = 1 BOUNDARY = 2 FULL_RANGE = 3 - SPECIAL = 4 + FP_SPECIAL = 4 FIXED_DATA = 5 diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py index 24ee9e2..bb52a86 100644 --- a/verif/tests/test_tosa_refmodel.py +++ b/verif/tests/test_tosa_refmodel.py @@ -1,5 +1,5 @@ """Tests for tosa_reference_model.""" -# Copyright (c) 2022-2023, ARM Limited. +# Copyright (c) 2022-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import json import re @@ -134,9 +134,10 @@ class BuildTosaTest: # Tests - op_name, ref_model_type, num_expected_tests +# FP Special datagen adds a second expected test to FP16 and FP32 tests for OPs it is added to TEST_PARAMS = [ ("add", "int32", 1), - ("add", "fp32", 1), + ("add", "fp32", 2), ("abs", "int32", 1), ("abs", "fp32", 1), ("abs", "fp16", 1), @@ -223,13 +224,20 @@ def test_refmodel_simple_op(tosaTest): assert const_file.is_file() consts.append(np.load(str(const_file))) + # Check if the data is from FP special datagen which can give invalid results + fp_special_data = test_dir.match("*_fs") + # Perform Numpy operation if op_name == "abs": assert len(tensors) == 1 result = np.abs(tensors[0]) elif op_name == "add": assert len(tensors) == 2 - result = np.add(tensors[0], tensors[1]) + if fp_special_data: + with np.errstate(invalid="ignore"): + result = np.add(tensors[0], tensors[1]) + else: + result = np.add(tensors[0], tensors[1]) elif op_name == "concat": assert len(consts) == 1 # Get axis from test directory name |