aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-01-09 00:34:40 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-24 21:01:20 +0000
commit74342e522ec61e85fde64fe801da9e750b3e2d86 (patch)
tree473a02dcbccb5dcf7aee009682454aa2b914bb64 /verif/generator/tosa_error_if.py
parent1f75232dab1b50162ebc420e6e076edeb8a58341 (diff)
downloadreference_model-74342e522ec61e85fde64fe801da9e750b3e2d86.tar.gz
Add conformance testing for shape operators
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Ie80570146601c470a3be7c04a9d6e1016a7c547c
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7f719ee..5874123 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math
@@ -595,6 +595,10 @@ class TosaErrorValidator:
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
+ elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
+ if output_dtype != DType.SHAPE:
+ error_result = True
+
else:
if output_dtype != input_dtype:
error_result = True
@@ -1109,7 +1113,13 @@ class TosaErrorValidator:
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- if len(input1_shape) == len(input2_shape) == len(input3_shape):
+ op = kwargs["op"]
+ if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
+ output_shape = kwargs["result_tensors"][0].shape
+ if input1_shape != output_shape:
+ error_result = True
+
+ elif len(input1_shape) == len(input2_shape) == len(input3_shape):
calculated_shape = TosaErrorValidator.calculateBroadcastShape(
input3_shape,
TosaErrorValidator.calculateBroadcastShape(