From acb550f4410ae861e53cae27a9feb4b11d45769f Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 29 Jun 2021 15:32:19 -0700 Subject: Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF() - Adding return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() on test generation script - Update test regression script to interface with reference model change Signed-off-by: Kevin Cheng Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1 --- verif/tosa_serializer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'verif/tosa_serializer.py') diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py index b4daaad..35dd9a2 100644 --- a/verif/tosa_serializer.py +++ b/verif/tosa_serializer.py @@ -31,6 +31,7 @@ from tosa import ( ResizeMode, Version, ) +from tosa_ref_run import TosaReturnCode # Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH parent_dir = os.path.dirname(os.path.realpath(__file__)) @@ -57,6 +58,7 @@ DTypeNames = [ ByteMask = np.uint64(0xFF) + def dtype_str_to_val(name): for i in range(len(DTypeNames)): @@ -428,10 +430,12 @@ class TosaSerializerTensor: u8_data.extend([b0, b1, b2, b3, b4, b5]) elif self.dtype == DType.FLOAT: for val in self.data: - b = struct.pack('!f', val) + b = struct.pack("!f", val) u8_data.extend([b[3], b[2], b[1], b[0]]) else: - raise Exception("unsupported data type {}".format(DTypeNames[self.dtype])) + raise Exception( + "unsupported data type {}".format(DTypeNames[self.dtype]) + ) fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data) TosaTensor.TosaTensorStart(builder) @@ -586,7 +590,7 @@ class TosaSerializer: self.currResultIdx = 0 # Is this an illegal test that is expected to fail? - self.expectedFailure = False + self.expectedReturnCode = TosaReturnCode.VALID self.expectedFailureDesc = "" def __str__(self): @@ -665,9 +669,9 @@ class TosaSerializer: op, inputs, outputs, attributes, quant_info ) - def setExpectedFailure(self, desc="", val=True): + def setExpectedReturnCode(self, val, desc=""): - self.expectedFailure = val + self.expectedReturnCode = val self.expectedFailureDesc = desc def serialize(self): @@ -719,7 +723,7 @@ class TosaSerializer: test_desc["ifm_file"] = ifm_file test_desc["ofm_name"] = ofm_name test_desc["ofm_file"] = ofm_file - test_desc["expected_failure"] = self.expectedFailure + test_desc["expected_return_code"] = self.expectedReturnCode if self.expectedFailureDesc: test_desc["expected_failure_desc"] = self.expectedFailureDesc -- cgit v1.2.1