aboutsummaryrefslogtreecommitdiff
path: root/verif/generator
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-04-21 22:49:57 +0000
committerJerry Ge <jerry.ge@arm.com>2023-05-17 22:46:57 +0000
commit264f7faa59709ffa8117541f5d55c99c5dba967d (patch)
treeae767b3e4375ab87d4323f18b63239a84ac857db /verif/generator
parent7e5968166a5105da30bc11c9241f271cb3dc1da9 (diff)
downloadreference_model-264f7faa59709ffa8117541f5d55c99c5dba967d.tar.gz
Add support for one dimension of size -1 in ReshapeOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ef7607f4266296a1204c5cccdb5be36f345b5ba
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py46
-rw-r--r--verif/generator/tosa_error_if.py58
-rw-r--r--verif/generator/tosa_test_gen.py2
3 files changed, 104 insertions, 2 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 2bbc349..9386ec2 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1878,17 +1878,27 @@ class TosaArgGen:
escape_counter = 0
while found:
newShape = []
+ new_shape_inferred = []
# Generate newShape ensuring it isn't a duplicate
remainingElements = totalElements
shuffledFactors = testGen.rng.permutation(factors)
+ inferred_dim = testGen.rng.integers(1, newRank + 1)
for i in range(1, newRank):
# pick rank-1 factors
newShape.append(shuffledFactors[0])
remainingElements = remainingElements // shuffledFactors[0]
+ if i == inferred_dim:
+ new_shape_inferred.append(-1)
+ else:
+ new_shape_inferred.append(shuffledFactors[0])
shuffledFactors = testGen.rng.permutation(
TosaArgGen.getFactors(remainingElements)
)
newShape.append(remainingElements)
+ if inferred_dim == newRank:
+ new_shape_inferred.append(-1)
+ else:
+ new_shape_inferred.append(remainingElements)
# Check for duplicates
found = False
@@ -1902,7 +1912,41 @@ class TosaArgGen:
break
if not found:
- arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
+ if error_name in [
+ ErrorIf.ReshapeOutputSizeNonInteger,
+ ErrorIf.ReshapeOutputSizeMultiInference,
+ ]:
+ if newRank < 2:
+ # Need at least two dimensions
+ continue
+ # NOTE: Change inferred_dim starting offset from 1 to 0
+ inferred_dim -= 1
+ extra_dim = inferred_dim + testGen.rng.integers(1, newRank)
+ extra_dim = extra_dim % newRank
+ assert extra_dim != inferred_dim
+ if error_name == ErrorIf.ReshapeOutputSizeNonInteger:
+ elements = 1
+ for i, dim_value in enumerate(new_shape_inferred):
+ if i != inferred_dim and i != extra_dim:
+ elements *= dim_value
+ dim_value = new_shape_inferred[extra_dim]
+ while totalElements % (elements * dim_value) == 0:
+ dim_value += 1
+ new_shape_inferred[extra_dim] = dim_value
+ else:
+ assert error_name == ErrorIf.ReshapeOutputSizeMultiInference
+ new_shape_inferred[extra_dim] = -1
+ else:
+ arg_list.append(
+ ("perm{}_rank{}_outdefined".format(p, newRank), [newShape])
+ )
+ if error_name != ErrorIf.TensorSizeInputOutputMismatch:
+ arg_list.append(
+ (
+ "perm{}_rank{}_outinferred".format(p, newRank),
+ [new_shape_inferred],
+ )
+ )
return arg_list
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 8c40371..a0a9203 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -81,6 +81,8 @@ class ErrorIf(object):
KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
FFTInputShapeMismatch = "FFTInputShapeMismatch"
FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
+ ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
+ ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
class TosaErrorIfArgGen:
@@ -1822,13 +1824,17 @@ class TosaErrorValidator:
param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
error_reason = "Input tensor size does not match output tensor size"
+ op = kwargs["op"]
if check:
input_shape = kwargs["input_shape"]
output_shape = kwargs["output_shape"]
+ shape_inferencing = False
+ if -1 in output_shape and op["op"] == Op.RESHAPE:
+ shape_inferencing = True
input_size = np.prod(input_shape)
output_size = np.prod(output_shape)
- if input_size != output_size:
+ if input_size != output_size and not shape_inferencing:
error_result = True
info_dict = {
@@ -2510,6 +2516,56 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evReshapeOutputSizeMultiInference(check=False, **kwargs):
+ error_name = ErrorIf.ReshapeOutputSizeMultiInference
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Reshape output tensor contains more than one inferred dimension"
+
+ if check:
+ output_shape = kwargs["output_shape"]
+ inferences = 0
+ for dim in output_shape:
+ if dim == -1:
+ inferences += 1
+ if inferences > 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evReshapeOutputSizeNonInteger(check=False, **kwargs):
+ error_name = ErrorIf.ReshapeOutputSizeNonInteger
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Reshape inferred output tensor dimension is non-integer"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ input_size = np.prod(input_shape)
+ output_size = 1
+ for dim in output_shape:
+ if dim != -1:
+ output_size *= dim
+ if -1 in output_shape and input_size % output_size != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c8c22c2..7691fdd 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3693,6 +3693,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evReshapeOutputSizeMultiInference,
+ TosaErrorValidator.evReshapeOutputSizeNonInteger,
),
},
"reverse": {