aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py64
1 files changed, 62 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 93f975d..ee227b3 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -79,6 +79,8 @@ class ErrorIf(object):
CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
+ FFTInputShapeMismatch = "FFTInputShapeMismatch"
+ FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
class TosaErrorIfArgGen:
@@ -562,7 +564,7 @@ class TosaErrorValidator:
):
error_result = True
- elif op["op"] == Op.RFFT2D:
+ elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
if not all([ty == input_dtype for ty in output_dtype]):
error_result = True
@@ -686,7 +688,7 @@ class TosaErrorValidator:
op = kwargs["op"]
output_list = kwargs["output_list"]
expected_length = 1
- if op["op"] == Op.RFFT2D:
+ if op["op"] in [Op.FFT2D, Op.RFFT2D]:
expected_length = 2
if len(output_list) != expected_length:
@@ -2446,6 +2448,64 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evFFTInputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTInputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Mismatch between real and imaginary input shapes"
+
+ if check:
+ input1 = kwargs["input1"]
+ input2 = kwargs["input2"]
+
+ if input1.shape != input2.shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evFFTOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between provided and expected output kernel (H, W) shape"
+ )
+
+ if check:
+ op = kwargs["op"]
+ input_shape = kwargs["input_shape"]
+
+ if len(input_shape) == 3:
+ output_shapes = kwargs["output_shape"]
+
+ # Ignoring batch size (N) from input shape
+ expected_shape = input_shape[1:]
+ if op["op"] == Op.RFFT2D:
+ expected_shape[1] = expected_shape[1] // 2 + 1
+
+ # Ignoring batch size (N) from output shapes
+ output_shape_0 = output_shapes[0][1:]
+ output_shape_1 = output_shapes[1][1:]
+ # Ensure sure the kernel sizes (H, W) of both outputs match the expected
+ if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod