aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tosa_verif_framework_compiler_runner.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_compiler_runner.py')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 3597f2a..c55864a 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import glob
@@ -483,6 +483,20 @@ def run_test(args, test, framework):
except KeyError:
assert 0, "fail to load tflite result numpy"
+ # TOSA has no notion of complex datatypes, it represents complex values using two
+ # fp32 output tensors representing real and imaginary values. When legalizing
+ # complex operations from frameworks, these two output tensors are combined into
+ # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
+ # represents the real and imaginary parts of a complex value. This is completed
+ # by inserting reshape and concatenate TOSA operations during the legalization to
+ # maintain a one-to-one correspondance with framework outputs, thus simplifying
+ # legalization. Here tf_result should also match this format before being
+ # compared to the ref model output.
+ if tf_result.dtype == np.complex64:
+ ifm_shape = tf_result.shape + (2,)
+ tf_result = tf_result.view(np.float32)
+ tf_result = tf_result.reshape(ifm_shape)
+
# Generate test descriptor per flatbuffer generation
# Input .npy will be shared across different frameworks
# Output .npy will be generated in its corresponding flatbuffer