aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-07 16:09:09 +0000
committerEric Kunze <eric.kunze@arm.com>2024-02-08 21:06:36 +0000
commit01e1c1c7f965ceb07e78a3b1ad063161c0f47b94 (patch)
tree10235a2a231f5366572577462ca4d1902f9a098c
parent59d8f50f5b3399a6255643aad0e5857e30370761 (diff)
downloadreference_model-01e1c1c7f965ceb07e78a3b1ad063161c0f47b94.tar.gz
Improve Avg_Pool2D ErrorIf Testing
* Add test for invalid accumulator dtype Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I506e2047623372670b82db6e9c0010fa89802851
-rw-r--r--verif/generator/tosa_arg_gen.py3
-rw-r--r--verif/generator/tosa_error_if.py44
-rw-r--r--verif/generator/tosa_test_gen.py2
3 files changed, 49 insertions, 0 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4630f35..33e74b5 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2485,6 +2485,9 @@ class TosaArgGen:
# incorrect input data-type
accum_dtypes = [DType.INT32]
+ if error_name == ErrorIf.WrongAccumulatorType:
+ accum_dtypes = list(gtu.usableDTypes(excludes=accum_dtypes))
+
if not test_level8k:
if testGen.args.oversize:
# add some oversize argument values
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 5fd647a..9a88acb 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -84,6 +84,7 @@ class ErrorIf(object):
ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
BroadcastShapesMismatch = "BroadcastShapesMismatch"
+ WrongAccumulatorType = "WrongAccumulatorType"
class TosaErrorIfArgGen:
@@ -2580,6 +2581,49 @@ class TosaErrorValidator:
}
return info_dict
+ def evWrongAccumulatorType(check=False, **kwargs):
+ error_name = ErrorIf.WrongAccumulatorType
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "An unsupported accumulator data type was requested"
+
+ if check:
+ op = kwargs["op"]
+ input_dtype = kwargs["input_dtype"]
+ accum_dtype = kwargs["accum_dtype"]
+ if op["op"] == Op.AVG_POOL2D:
+ if (
+ input_dtype
+ in (
+ DType.INT8,
+ DType.INT16,
+ )
+ and accum_dtype != DType.INT32
+ ):
+ error_result = True
+ elif (
+ input_dtype
+ in (
+ DType.FP32,
+ DType.BF16,
+ )
+ and accum_dtype != DType.FP32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and accum_dtype not in (
+ DType.FP16,
+ DType.FP32,
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index f5eca18..2d471c0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -835,6 +835,7 @@ class TosaTestGen:
input_dtype=input.dtype,
output_shape=result_tensor.shape,
output_dtype=result_tensor.dtype,
+ accum_dtype=accum_dtype,
kernel=kernel,
stride=stride,
pad=pad,
@@ -3218,6 +3219,7 @@ class TosaTestGen:
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),