aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_serializer.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_serializer.py')
-rw-r--r--verif/tosa_serializer.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index b4daaad..35dd9a2 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -31,6 +31,7 @@ from tosa import (
ResizeMode,
Version,
)
+from tosa_ref_run import TosaReturnCode
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -57,6 +58,7 @@ DTypeNames = [
ByteMask = np.uint64(0xFF)
+
def dtype_str_to_val(name):
for i in range(len(DTypeNames)):
@@ -428,10 +430,12 @@ class TosaSerializerTensor:
u8_data.extend([b0, b1, b2, b3, b4, b5])
elif self.dtype == DType.FLOAT:
for val in self.data:
- b = struct.pack('!f', val)
+ b = struct.pack("!f", val)
u8_data.extend([b[3], b[2], b[1], b[0]])
else:
- raise Exception("unsupported data type {}".format(DTypeNames[self.dtype]))
+ raise Exception(
+ "unsupported data type {}".format(DTypeNames[self.dtype])
+ )
fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
TosaTensor.TosaTensorStart(builder)
@@ -586,7 +590,7 @@ class TosaSerializer:
self.currResultIdx = 0
# Is this an illegal test that is expected to fail?
- self.expectedFailure = False
+ self.expectedReturnCode = TosaReturnCode.VALID
self.expectedFailureDesc = ""
def __str__(self):
@@ -665,9 +669,9 @@ class TosaSerializer:
op, inputs, outputs, attributes, quant_info
)
- def setExpectedFailure(self, desc="", val=True):
+ def setExpectedReturnCode(self, val, desc=""):
- self.expectedFailure = val
+ self.expectedReturnCode = val
self.expectedFailureDesc = desc
def serialize(self):
@@ -719,7 +723,7 @@ class TosaSerializer:
test_desc["ifm_file"] = ifm_file
test_desc["ofm_name"] = ofm_name
test_desc["ofm_file"] = ofm_file
- test_desc["expected_failure"] = self.expectedFailure
+ test_desc["expected_return_code"] = self.expectedReturnCode
if self.expectedFailureDesc:
test_desc["expected_failure_desc"] = self.expectedFailureDesc