aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-08 19:45:26 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-04-13 17:32:14 +0100
commit714aa6039a7e3585bf81ac90ce301767c08295af (patch)
treeb50dd8a4a8faa70df8e91397dec08530afa1d770
parentabf8718d3bce6c76fb281d3911f566cd90c44f28 (diff)
downloadreference_model-714aa6039a7e3585bf81ac90ce301767c08295af.tar.gz
Add framework tests for tfl.real and tfl.imag
Change-Id: I665acac9b5171efd0c5a2b68b516609048f6e187 Signed-off-by: Luke Hutton <luke.hutton@arm.com>
-rw-r--r--verif/frameworks/tensor_gen.py16
-rw-r--r--verif/frameworks/test_builder.py14
-rw-r--r--verif/frameworks/test_gen_utils.py2
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py8
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py22
5 files changed, 59 insertions, 3 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index f8d50a8..60e17ce 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -56,6 +56,10 @@ class TGen:
return np.uint32(rng.integers(low=0, high=RAND_INT_MAX, size=shape))
if dtype == tf.bool:
return np.bool_(rng.choice(a=[False, True], size=shape))
+ if dtype == tf.complex64:
+ return TGen.getRand(shape, np.float32, rng) + 1j * TGen.getRand(
+ shape, np.float32, rng
+ )
raise Exception("Unsupported type: {}".format(dtype))
@@ -305,5 +309,13 @@ class TGen:
if len(shape) != 3:
return [], []
- tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
- return tf_placeholders, []
+ return TGen.tgBasic(op, shape, dtype, rng)
+
+ @staticmethod
+ def tgComplexComponents(op, shape, dtype, rng):
+ # Temporarily require up to rank 3 shape, due to
+ # slice maximum rank limitiation.
+ if len(shape) > 3:
+ return [], []
+
+ return TGen.tgBasic(op, shape, dtype, rng)
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 6302865..f872888 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1225,3 +1225,17 @@ class TBuilder:
def eval(self, a):
return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
+
+ class Real:
+ def __init__(self, name):
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.math.real(a, name=self.result_name)
+
+ class Imag:
+ def __init__(self, name):
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.math.imag(a, name=self.result_name)
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
index 2d8e5d6..6a59848 100644
--- a/verif/frameworks/test_gen_utils.py
+++ b/verif/frameworks/test_gen_utils.py
@@ -30,6 +30,8 @@ def get_shape_str(shape, dtype):
shape_name = shape_name + "_qi16"
elif dtype == tf.quint16:
shape_name = shape_name + "_qu16"
+ elif dtype == tf.complex64:
+ shape_name = shape_name + "_c64"
else:
raise Exception("Unsupported type: {}".format(dtype))
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index c55864a..71723ae 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -384,7 +384,13 @@ def run_test(args, test, framework):
while len(list(ifm_np.shape)) < len(test_desc["ifm_shape"][i]):
ifm_np = np.expand_dims(ifm_np, axis=0)
- assert list(ifm_np.shape) == test_desc["ifm_shape"][i]
+ # After legalization, complex tensors are expected to be represented
+ # as a single floating point tensor of shape [?, ..., ?, 2].
+ expected_shape = test_desc["ifm_shape"][i]
+ if test.endswith("c64"):
+ expected_shape.append(2)
+
+ assert list(ifm_np.shape) == expected_shape
reference_runner_ifm_name.append(ifm_tensor_name)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 0741686..fffb842 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -841,6 +841,20 @@ TF_OP_LIST = {
"tflite": TYPE_F,
},
},
+ "real": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.Real, TGen.tgComplexComponents, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.complex64],
+ },
+ },
+ "imag": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.Imag, TGen.tgComplexComponents, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.complex64],
+ },
+ },
}
# Shapes to be tested; default can be overwritten
@@ -1154,6 +1168,14 @@ def run_unit_test(
# 1. Saved out numpy array directly
for idx, (name, val) in enumerate(placeholders):
placeholder_vals.append(tf.convert_to_tensor(val))
+
+ # Complex tensors are expected to be repsesented by a
+ # single floating point tensor of shape [?, ..., ?, 2].
+ if val.dtype == np.complex64:
+ val_shape = val.shape + (2,)
+ val = val.view(np.float32)
+ val = val.reshape(val_shape)
+
np.save(
os.path.join(test_dir, placeholder_npy_filenames[idx]), val, False
)