aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-29 15:32:19 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-08-20 18:07:06 +0100
commitacb550f4410ae861e53cae27a9feb4b11d45769f (patch)
treeae2f4ec558c2cdf1afa020b80a09d7ab4be5ef6d /verif
parent68e7aee65bda5ac03fa7def753b7dc7462554793 (diff)
downloadreference_model-acb550f4410ae861e53cae27a9feb4b11d45769f.tar.gz
Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()
- Adding return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() on test generation script - Update test regression script to interface with reference model change Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
Diffstat (limited to 'verif')
-rw-r--r--verif/tosa_ref_run.py39
-rw-r--r--verif/tosa_serializer.py16
-rw-r--r--verif/tosa_test_gen.py105
-rw-r--r--verif/tosa_test_runner.py4
4 files changed, 113 insertions, 51 deletions
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py
index 098f39b..499513b 100644
--- a/verif/tosa_ref_run.py
+++ b/verif/tosa_ref_run.py
@@ -1,5 +1,3 @@
-import os
-
# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +16,17 @@ import os
import json
import shlex
import subprocess
+from enum import Enum, IntEnum, unique
from tosa_test_runner import TosaTestRunner, run_sh_command
+@unique
+class TosaReturnCode(IntEnum):
+ VALID = 0
+ UNPREDICTABLE = 1
+ ERROR = 2
+
+
class TosaRefRunner(TosaTestRunner):
def __init__(self, args, runnerArgs, testDir):
super().__init__(args, runnerArgs, testDir)
@@ -41,18 +47,29 @@ class TosaRefRunner(TosaTestRunner):
if args.ref_intermediates:
ref_cmd.extend(["-Ddump_intermediates=1"])
- expectedFailure = self.testDesc["expected_failure"]
+ expectedReturnCode = self.testDesc["expected_return_code"]
try:
- run_sh_command(self.args, ref_cmd)
- if expectedFailure:
- result = TosaTestRunner.Result.UNEXPECTED_PASS
+ rc = run_sh_command(self.args, ref_cmd)
+ if rc == TosaReturnCode.VALID:
+ if expectedReturnCode == TosaReturnCode.VALID:
+ result = TosaTestRunner.Result.EXPECTED_PASS
+ else:
+ result = TosaTestRunner.Result.UNEXPECTED_PASS
+ elif rc == TosaReturnCode.ERROR:
+ if expectedReturnCode == TosaReturnCode.ERROR:
+ result = TosaTestRunner.Result.EXPECTED_FAILURE
+ else:
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+ elif rc == TosaReturnCode.UNPREDICTABLE:
+ if expectedReturnCode == TosaReturnCode.UNPREDICTABLE:
+ result = TosaTestRunner.Result.EXPECTED_FAILURE
+ else:
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
else:
- result = TosaTestRunner.Result.EXPECTED_PASS
+ raise Exception("Return code unknown.")
+
except Exception as e:
- if expectedFailure:
- result = TosaTestRunner.Result.EXPECTED_FAILURE
- else:
- result = TosaTestRunner.Result.UNEXPECTED_FAILURE
+ raise Exception("Runtime Error when running: {}".format(" ".join(ref_cmd)))
return result
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index b4daaad..35dd9a2 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -31,6 +31,7 @@ from tosa import (
ResizeMode,
Version,
)
+from tosa_ref_run import TosaReturnCode
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -57,6 +58,7 @@ DTypeNames = [
ByteMask = np.uint64(0xFF)
+
def dtype_str_to_val(name):
for i in range(len(DTypeNames)):
@@ -428,10 +430,12 @@ class TosaSerializerTensor:
u8_data.extend([b0, b1, b2, b3, b4, b5])
elif self.dtype == DType.FLOAT:
for val in self.data:
- b = struct.pack('!f', val)
+ b = struct.pack("!f", val)
u8_data.extend([b[3], b[2], b[1], b[0]])
else:
- raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))
+ raise Exception(
+ "unsupported data type {}".format(DTypeNames[self.dtype])
+ )
fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
TosaTensor.TosaTensorStart(builder)
@@ -586,7 +590,7 @@ class TosaSerializer:
self.currResultIdx = 0
# Is this an illegal test that is expected to fail?
- self.expectedFailure = False
+ self.expectedReturnCode = TosaReturnCode.VALID
self.expectedFailureDesc = ""
def __str__(self):
@@ -665,9 +669,9 @@ class TosaSerializer:
op, inputs, outputs, attributes, quant_info
)
- def setExpectedFailure(self, desc="", val=True):
+ def setExpectedReturnCode(self, val, desc=""):
- self.expectedFailure = val
+ self.expectedReturnCode = val
self.expectedFailureDesc = desc
def serialize(self):
@@ -719,7 +723,7 @@ class TosaSerializer:
test_desc["ifm_file"] = ifm_file
test_desc["ofm_name"] = ofm_name
test_desc["ofm_file"] = ofm_file
- test_desc["expected_failure"] = self.expectedFailure
+ test_desc["expected_return_code"] = self.expectedReturnCode
if self.expectedFailureDesc:
test_desc["expected_failure_desc"] = self.expectedFailureDesc
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index a3c6b05..efc819c 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -32,6 +32,7 @@ import math
import itertools
from enum import IntEnum, Enum, unique
+from tosa_ref_run import TosaReturnCode
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -65,8 +66,9 @@ class TosaQuantGen:
@staticmethod
def qgUnary(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
@@ -86,8 +88,9 @@ class TosaQuantGen:
@staticmethod
def qgMatmul(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
@@ -304,13 +307,11 @@ class TosaTensorGen:
assert rank == 2
input_shape = testGen.makeShape(rank)
- filter_oc = (
- testGen.rng.integers(
- low=testGen.args.tensor_shape_range[0],
- high=testGen.args.tensor_shape_range[1],
- size=1,
- )[0]
- )
+ filter_oc = testGen.rng.integers(
+ low=testGen.args.tensor_shape_range[0],
+ high=testGen.args.tensor_shape_range[1],
+ size=1,
+ )[0]
filter_shape = np.asarray([filter_oc, input_shape[1]])
bias_shape = np.asarray([filter_oc])
@@ -734,7 +735,10 @@ class TosaArgGen:
random_permutations = testGen.rng.permutation(permutations)
# Create list of required amount of permutations
- arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
+ arg_list = [
+ ("perm{}".format(p), [random_permutations[p].tolist()])
+ for p in range(limit)
+ ]
return arg_list
@staticmethod
@@ -1154,7 +1158,7 @@ class TosaTestGen:
def build_table(self, op, a):
# Constant size depending on type, random values
if a.dtype == DType.INT16:
- table_dtype = DType.INT16
+ table_dtype = DType.INT16
table_arr = self.getRandTensor([513], table_dtype)
else:
assert a.dtype == DType.INT8
@@ -1497,7 +1501,7 @@ class TosaTestGen:
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
in_type_width = in_type_width + 1
- elif val.dtype == DType.UINT8:
+ elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
in_type_width = in_type_width + 1
else:
@@ -1536,7 +1540,9 @@ class TosaTestGen:
scale_arr[i], scale32
)
if shift_arr[i] < 2 or shift_arr[i] > 62:
- self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
+ self.ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value"
+ )
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
@@ -1710,14 +1716,21 @@ class TosaTestGen:
# Filter out the rank?
if rankFilter is not None and r not in rankFilter:
continue
- if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
+ if (
+ rankFilter is None
+ and shapeFilter[0] is None
+ and r not in default_test_rank_range
+ ):
continue
for t in op["types"]:
# Filter tests based on dtype?
if dtypeFilter is not None:
- if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
+ if not (
+ t in dtypeFilter
+ or (isinstance(t, list) and t[0] in dtypeFilter)
+ ):
continue
# Create the placeholder and const tensors
@@ -2660,7 +2673,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
@@ -2700,7 +2715,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
@@ -2725,7 +2742,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
return ser.addOutput(ofm_shape, ifm.dtype)
@@ -2889,39 +2908,59 @@ class OutputShaper:
if input_dtype == DType.FLOAT:
if stride_fp[0] <= 0 or stride_fp[1] <= 0:
- ser.setExpectedFailure(True, "Negative or zero stride")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Negative or zero stride"
+ )
else:
if stride[0] <= 0 or stride[1] <= 0:
- ser.setExpectedFailure(True, "Negative or zero stride")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Negative or zero stride"
+ )
if mode == ResizeMode.BILINEAR:
if input_dtype == DType.INT8:
if output_dtype != DType.INT32:
- ser.setExpectedFailure(True, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.INT16:
if output_dtype != DType.INT48:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid input data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid input data type"
+ )
elif mode == ResizeMode.NEAREST:
if input_dtype == DType.INT8:
if output_dtype != DType.INT8:
- ser.setExpectedFailure(True, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.INT16:
if output_dtype != DType.INT16:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid input data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid input data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid resize mode")
+ ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode")
return ser.addOutput(output_dims, output_dtype)
@@ -2941,6 +2980,8 @@ class OutputShaper:
raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
if output_shape[1] <= 0 or output_shape[2] <= 0:
- ser.setExpectedFailure(True, "Negative output shape")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Negative output shape"
+ )
return ser.addOutput(output_shape, out_dtype)
diff --git a/verif/tosa_test_runner.py b/verif/tosa_test_runner.py
index 82d447e..e8f921d 100644
--- a/verif/tosa_test_runner.py
+++ b/verif/tosa_test_runner.py
@@ -42,8 +42,8 @@ def run_sh_command(args, full_cmd, capture_output=False):
return (rc.stdout, rc.stderr)
else:
rc = subprocess.run(full_cmd)
- if rc.returncode != 0:
- raise Exception("Error running command: {}".format(" ".join(full_cmd_esc)))
+
+ return rc.returncode
class TosaTestRunner: