aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-10 14:50:31 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-01-24 13:40:17 +0000
commit261b7b62b959a6c7312d810d9152069fdff69f3e (patch)
tree2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif/generator/tosa_error_if.py
parentc253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff)
downloadreference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz
Add RFFT2d to the reference model
Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py103
1 files changed, 73 insertions, 30 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index c9d35c7..40c5d13 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import math
+
import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
@@ -76,6 +78,7 @@ class ErrorIf(object):
CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
+ KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
class TosaErrorIfArgGen:
@@ -548,6 +551,10 @@ class TosaErrorValidator:
):
error_result = True
+ elif op["op"] == Op.RFFT2D:
+ if not all([ty == input_dtype for ty in output_dtype]):
+ error_result = True
+
elif op["op"] in {
Op.CONV2D,
Op.CONV3D,
@@ -665,9 +672,13 @@ class TosaErrorValidator:
error_reason = "Op output list does not match expected output"
if check:
+ op = kwargs["op"]
output_list = kwargs["output_list"]
- # Note this will be incorrect if an operator returns more than one output
- if len(output_list) != 1:
+ expected_length = 1
+ if op["op"] == Op.RFFT2D:
+ expected_length = 2
+
+ if len(output_list) != expected_length:
error_result = True
info_dict = {
@@ -711,7 +722,7 @@ class TosaErrorValidator:
@staticmethod
def evBatchMismatch(check=False, **kwargs):
error_name = ErrorIf.BatchMismatch
- param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
error_reason = "Input batch size not equal to output batch size"
@@ -722,12 +733,15 @@ class TosaErrorValidator:
if check:
input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
- if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
- error_result = True
+ for output in kwargs["result_tensors"]:
+ output_shape = (
+ output.shape
+ ) # Note batch is expected to be the first dim
+ if (len(input_shape) in rank_range) and (
+ input_shape[0] != output_shape[0]
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -751,11 +765,12 @@ class TosaErrorValidator:
if check:
input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
- if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
- error_result = True
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape # Note this is just (N, OH, OW, C)
+ if (len(input_shape) in rank_range) and (
+ input_shape[3] != output_shape[3]
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1044,13 +1059,15 @@ class TosaErrorValidator:
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- output_shape = kwargs["result_tensor"].shape
- if (
- (len(input1_shape) != len(output_shape))
- or (len(input2_shape) != len(output_shape))
- or (len(input3_shape) != len(output_shape))
- ):
- error_result = True
+
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ if (
+ (len(input1_shape) != len(output_shape))
+ or (len(input2_shape) != len(output_shape))
+ or (len(input3_shape) != len(output_shape))
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1074,16 +1091,18 @@ class TosaErrorValidator:
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- output_shape = kwargs["result_tensor"].shape
- for i in range(
- min(len(input1_shape), len(input2_shape), len(input3_shape))
- ):
- if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
- or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
- or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ for i in range(
+ min(len(input1_shape), len(input2_shape), len(input3_shape))
):
- error_result = True
+ if (
+ (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
+ or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
+ or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -2392,6 +2411,30 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evKernelNotPowerOfTwo(check=False, **kwargs):
+ error_name = ErrorIf.KernelNotPowerOfTwo
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "kernel height and/or width not a power of two"
+
+ def is_power_of_two(x):
+ return math.log(x, 2).is_integer()
+
+ if check:
+ shape = kwargs["input_shape"]
+ if len(shape) == 3:
+ valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
+ error_result = not valid_kernel
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod