aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-06 14:54:18 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-10 20:01:04 +0000
commit5728713fca4f6e2dff60dad3689e471545e563d2 (patch)
tree848421100f82a33ff57ee3205c369ad75737f7d3 /verif/generator/tosa_error_if.py
parentc1e25f5755997e65ac1a360ec1e875db06040d8d (diff)
downloadreference_model-5728713fca4f6e2dff60dad3689e471545e563d2.tar.gz
Add FFT2d to the reference model
Includes: * FFT2d reference implementation * Basic TOSA tests Change-Id: Ie79fcb713542345d550ec013646810c1e890e388 Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py64
1 files changed, 62 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 93f975d..ee227b3 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -79,6 +79,8 @@ class ErrorIf(object):
CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
+ FFTInputShapeMismatch = "FFTInputShapeMismatch"
+ FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
class TosaErrorIfArgGen:
@@ -562,7 +564,7 @@ class TosaErrorValidator:
):
error_result = True
- elif op["op"] == Op.RFFT2D:
+ elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
if not all([ty == input_dtype for ty in output_dtype]):
error_result = True
@@ -686,7 +688,7 @@ class TosaErrorValidator:
op = kwargs["op"]
output_list = kwargs["output_list"]
expected_length = 1
- if op["op"] == Op.RFFT2D:
+ if op["op"] in [Op.FFT2D, Op.RFFT2D]:
expected_length = 2
if len(output_list) != expected_length:
@@ -2446,6 +2448,64 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evFFTInputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTInputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Mismatch between real and imaginary input shapes"
+
+ if check:
+ input1 = kwargs["input1"]
+ input2 = kwargs["input2"]
+
+ if input1.shape != input2.shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evFFTOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between provided and expected output kernel (H, W) shape"
+ )
+
+ if check:
+ op = kwargs["op"]
+ input_shape = kwargs["input_shape"]
+
+ if len(input_shape) == 3:
+ output_shapes = kwargs["output_shape"]
+
+ # Ignoring batch size (N) from input shape
+ expected_shape = input_shape[1:]
+ if op["op"] == Op.RFFT2D:
+ expected_shape[1] = expected_shape[1] // 2 + 1
+
+ # Ignoring batch size (N) from output shapes
+ output_shape_0 = output_shapes[0][1:]
+ output_shape_1 = output_shapes[1][1:]
+ # Ensure sure the kernel sizes (H, W) of both outputs match the expected
+ if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod