aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-22 11:53:48 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-28 20:08:57 +0000
commita4e48ca7b032992ca0110900935c08d7cf860cd3 (patch)
treea58c8617390225ecc107721d9b5ff87c2bdb01b0 /verif
parent2226f90d5a6c48a975045bc9e0419113ce764aaf (diff)
downloadreference_model-a4e48ca7b032992ca0110900935c08d7cf860cd3.tar.gz
Update rank limits for SLICE, TILE and TRANSPOSE
Updated to align with corresponding changes to the spec. In addition, some ERROR_IF tests have been updated to match the checks specified by the spec, including: PAD, SLICE, TILE, TRANSPOSE. Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8
Diffstat (limited to 'verif')
-rw-r--r--verif/generator/tosa_arg_gen.py1
-rw-r--r--verif/generator/tosa_error_if.py16
-rw-r--r--verif/generator/tosa_test_gen.py48
-rw-r--r--verif/generator/tosa_utils.py11
4 files changed, 58 insertions, 18 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 75ca634..9209d9c 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -176,6 +176,7 @@ class TosaTensorGen:
for i in range(pl + const):
shape_list.append(shape.copy())
+ # Generates an input rank mismatch for operators with more than one input
if error_name == ErrorIf.RankMismatch:
if rank == 1 and i != 1:
shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index ee227b3..b19d5e9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1067,7 +1067,9 @@ class TosaErrorValidator:
if check:
input1_shape = kwargs["input1"].shape
- input2_shape = kwargs["input2"].shape
+ input2_shape = (
+ kwargs["input2"].shape if "input2" in kwargs else input1_shape
+ )
# In case of SELECT op
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
@@ -1921,11 +1923,13 @@ class TosaErrorValidator:
input_shape = kwargs["input_shape"]
output_shape = kwargs["output_shape"]
size = kwargs["size"]
- rank = len(input_shape)
- if len(size) == rank:
- for index in range(rank):
- if size[index] != output_shape[index]:
- error_result = True
+
+ if len(input_shape) == len(output_shape):
+ rank = len(input_shape)
+ if len(size) == rank:
+ for index in range(rank):
+ if size[index] != output_shape[index]:
+ error_result = True
info_dict = {
"error_name": error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index a768da0..7fef942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_rank_mismatch_shape
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
@@ -1263,6 +1264,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1369,6 +1371,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1404,6 +1407,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1438,6 +1442,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -3657,6 +3662,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"reshape": {
@@ -3699,7 +3706,7 @@ class TosaTestGen:
"slice": {
"op": Op.SLICE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
@@ -3718,11 +3725,13 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
),
},
"tile": {
"op": Op.TILE,
"operands": (1, 0),
+ "rank": (1, 6),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
@@ -3735,12 +3744,14 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
@@ -3755,6 +3766,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
},
# Data nodes
@@ -4539,6 +4553,8 @@ class OutputShaper:
if error_name == ErrorIf.PadOutputShapeMismatch:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4583,7 +4599,7 @@ class OutputShaper:
return ser.addOutput(output_shape, outputDType)
@staticmethod
- def sliceOp(ser, rng, a, start, size, error_name=None):
+ def sliceOp(ser, rng, input, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4595,13 +4611,13 @@ class OutputShaper:
DType.FP16,
DType.BF16,
]
- wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
- outputDType = a.dtype
+ outputDType = input.dtype
+ output_shape = size.copy()
if error_name == ErrorIf.SizeOutputShapeMismatch:
- output_shape = size.copy()
for index in range(len(output_shape)):
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
@@ -4609,8 +4625,10 @@ class OutputShaper:
output_shape[index] = output_shape[index] + rng.choice(
[-2, -1, 1, 2]
)
- else:
- output_shape = size.copy()
+ elif error_name == ErrorIf.InputSizeStartLengthMismatch:
+ output_shape = input.shape.copy()
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@@ -4623,6 +4641,9 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
+ if error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4646,13 +4667,16 @@ class OutputShaper:
assert len(perms) == len(output_shape)
- if error_name == ErrorIf.IndexOutsideBounds:
- for i in range(len(output_shape)):
- output_shape[i] = a.shape[0]
- else:
+ if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
for i in range(len(output_shape)):
output_shape[i] = a.shape[perms[i]]
+ if error_name == ErrorIf.TensorSizeInputOutputMismatch:
+ for i in range(len(output_shape)):
+ output_shape[i] += rng.integers(1, 10)
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 29ae898..8ff62f1 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype):
return rng.choice(a=incorrect_types)
+def get_rank_mismatch_shape(rng, output_shape):
+ """
+ Extends the rank of the provided output_shape by
+ an arbitrary amount but ensures the total element
+ count remains the same.
+ """
+ rank_modifier = rng.choice([1, 2, 3])
+ output_shape += [1] * rank_modifier
+ return output_shape
+
+
def float32_is_valid_bfloat16(f):
"""Return True if float value is valid bfloat16."""
f32_bits = get_float32_bitstring(f)