aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2024-01-26 16:56:55 +0000
committerJerry Ge <jerry.ge@arm.com>2024-02-22 18:22:41 +0000
commit20ab3df3d3100af68c47825846eee31925ff592d (patch)
treeb032f4e6cbab7edbe5b3a02fadd1621a3e51216f /verif
parentc7bfa58c76e73aac772f714d8ae04cc875715689 (diff)
downloadreference_model-20ab3df3d3100af68c47825846eee31925ff592d.tar.gz
Save Int16/UINT16 test outputs to native dtypes
* Int16/UInt16 reference outputs were previously saved to INT32 * Save those in their native dtypes and updated other affected code Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I0c3b7fba096a8cb1ddabef20ad13498b8f46d36f
Diffstat (limited to 'verif')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py6
-rw-r--r--verif/generator/tosa_test_gen.py4
2 files changed, 9 insertions, 1 deletions
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index ce9b253..56daa51 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -695,7 +695,11 @@ def run_test(args, test_path, framework):
tf_result = tf_result.astype(np.int8)
elif tf_result.dtype == np.uint8:
tf_result = tf_result.astype(np.uint8)
- elif tf_result.dtype == np.int16 or tf_result.dtype == np.int64:
+ elif tf_result.dtype == np.int16:
+ tf_result = tf_result.astype(np.int16)
+ elif tf_result.dtype == np.uint16:
+ tf_result = tf_result.astype(np.uint16)
+ elif tf_result.dtype == np.int64:
tf_result = tf_result.astype(np.int32)
# For now, search for the first output from ref_model
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bc931dc..8440853 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -195,6 +195,10 @@ class TosaTestGen:
return np.int8(self.rng.integers(low=low, high=high, size=shape))
elif dtype == DType.UINT8:
return np.uint8(self.rng.integers(low=low, high=high, size=shape))
+ elif dtype == DType.INT16:
+ return np.int16(self.rng.integers(low=low, high=high, size=shape))
+ elif dtype == DType.UINT16:
+ return np.uint16(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (