aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-01-09 00:34:40 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-24 21:01:20 +0000
commit74342e522ec61e85fde64fe801da9e750b3e2d86 (patch)
tree473a02dcbccb5dcf7aee009682454aa2b914bb64
parent1f75232dab1b50162ebc420e6e076edeb8a58341 (diff)
downloadreference_model-74342e522ec61e85fde64fe801da9e750b3e2d86.tar.gz
Add conformance testing for shape operators
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Ie80570146601c470a3be7c04a9d6e1016a7c547c
-rw-r--r--verif/conformance/test_select.py38
-rw-r--r--verif/conformance/tosa_base_profile_ops_info.json175
-rw-r--r--verif/generator/tosa_arg_gen.py79
-rw-r--r--verif/generator/tosa_error_if.py14
-rw-r--r--verif/generator/tosa_test_gen.py164
5 files changed, 441 insertions, 29 deletions
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index cebdf62..55eef58 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
"""Select generated tests."""
import argparse
@@ -437,6 +437,12 @@ class AddOperator(Operator):
name = "add"
+class AddShapeOperator(Operator):
+ """Test selector for the ADD_SHAPE operator."""
+
+ name = "add_shape"
+
+
class ArgmaxOperator(Operator):
"""Test selector for the ARGMAX operator."""
@@ -507,6 +513,12 @@ class ConcatOperator(Operator):
param_names = ["shape", "type", "axis"]
+class ConcatShapeOperator(Operator):
+ """Test selector for the CONCAT_SHAPE operator."""
+
+ name = "concat_shape"
+
+
class CondIfOperator(Operator):
"""Test selector for the COND_IF operator."""
@@ -520,6 +532,12 @@ class ConstOperator(Operator):
name = "const"
+class ConstShapeOperator(Operator):
+ """Test selector for the CONST_SHAPE operator."""
+
+ name = "const_shape"
+
+
class Conv2dOperator(Operator):
"""Test selector for the CONV2D operator."""
@@ -548,6 +566,12 @@ class DimOeprator(Operator):
param_names = ["shape", "type", "axis"]
+class DivShapeOperator(Operator):
+ """Test selector for the DIV_SHAPE operator."""
+
+ name = "div_shape"
+
+
class EqualOperator(Operator):
"""Test selector for the EQUAL operator."""
@@ -696,6 +720,12 @@ class MulOperator(Operator):
param_names = ["shape", "type", "perm", "shift"]
+class MulShapeOperator(Operator):
+ """Test selector for the MUL_SHAPE operator."""
+
+ name = "mul_shape"
+
+
class NegateOperator(Operator):
"""Test selector for the Negate operator."""
@@ -849,6 +879,12 @@ class SubOperator(Operator):
name = "sub"
+class SubShapeOperator(Operator):
+ """Test selector for the SUB_SHAPE operator."""
+
+ name = "sub_shape"
+
+
class TableOperator(Operator):
"""Test selector for the TABLE operator."""
diff --git a/verif/conformance/tosa_base_profile_ops_info.json b/verif/conformance/tosa_base_profile_ops_info.json
index b186b06..ec51324 100644
--- a/verif/conformance/tosa_base_profile_ops_info.json
+++ b/verif/conformance/tosa_base_profile_ops_info.json
@@ -129,6 +129,35 @@
}
}
},
+ "add_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape"
+ ]
+ }
+ }
+ },
"argmax": {
"group": "tensor",
"profile": [
@@ -974,6 +1003,36 @@
}
}
},
+ "concat_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--target-rank",
+ "1",
+ "--target-shape",
+ "1",
+ "--num-const-inputs-concat",
+ "2"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ ]
+ }
+ }
+ },
"cond_if": {
"group": "control_flow",
"profile": [
@@ -1080,6 +1139,35 @@
}
}
},
+ "const_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--target-rank",
+ "1",
+ "--target-shape",
+ "1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ ]
+ }
+ }
+ },
"conv2d": {
"group": "tensor",
"profile": [
@@ -1374,6 +1462,35 @@
}
}
},
+ "div_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape"
+ ]
+ }
+ }
+ },
"equal": {
"group": "comparison",
"profile": [
@@ -2542,6 +2659,35 @@
}
}
},
+ "mul_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape"
+ ]
+ }
+ }
+ },
"negate": {
"group": "ew_unary",
"profile": [
@@ -3502,6 +3648,35 @@
}
}
},
+ "sub_shape": {
+ "group": "shape",
+ "profile": [
+ "tosa-bi",
+ "tosa-mi"
+ ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "shape",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape"
+ ]
+ }
+ }
+ },
"table": {
"group": "ew_binary",
"profile": [
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a655a50..f598377 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -622,6 +622,28 @@ class TosaTensorGen:
return new_shapeList
+ @staticmethod
+ def tgShape(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+ shape = [rank]
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ # Generates an input rank mismatch for operators with more than one input
+ if error_name == ErrorIf.RankMismatch:
+ if rank == 1 and i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+ elif i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
+ return shape_list
+
class TosaTensorValuesGen:
"""Tensor Value generators create the random data for each tensor in each test."""
@@ -891,7 +913,7 @@ class TosaTensorValuesGen:
@staticmethod
def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
- if dtypeList[0] == DType.INT32 and error_name is None:
+ if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
# Make sure the integer operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
op = testGen.TOSA_OP_LIST[opName]
@@ -900,9 +922,10 @@ class TosaTensorValuesGen:
pCount == 2 and cCount == 0
), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
tens_ser_list = []
- add = op["op"] == Op.ADD
- a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
- b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+ add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
+ data_range = testGen.args.tensor_shape_range
+ a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
+ b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
if add:
res_arr = np.add(a_arr, b_arr, dtype=np.int64)
else:
@@ -1138,12 +1161,15 @@ class TosaTensorValuesGen:
tens_ser_list = []
# Make sure multiply result in int32 range
- shift = argsDict["shift"]
+ if dtypeList[0] == DType.SHAPE:
+ shift = 0
+ else:
+ shift = argsDict["shift"]
if dtypeList[0] == DType.INT8:
num_bits = 8
elif dtypeList[0] == DType.INT16:
num_bits = 16
- elif dtypeList[0] == DType.INT32:
+ elif dtypeList[0] in (DType.INT32, DType.SHAPE):
num_bits = 32
elif error_name == ErrorIf.WrongInputType:
num_bits = 8
@@ -1151,8 +1177,12 @@ class TosaTensorValuesGen:
raise Exception("OpMul: invalid input dtype")
for idx, shape in enumerate(shapeList[:]):
- low = -(2 ** (num_bits - 1))
- high = (2 ** (num_bits - 1)) - 1
+ if dtypeList[idx] == DType.SHAPE:
+ low = testGen.args.tensor_shape_range[0]
+ high = testGen.args.tensor_shape_range[1]
+ else:
+ low = -(2 ** (num_bits - 1))
+ high = (2 ** (num_bits - 1)) - 1
a_arr = np.int32(
testGen.rng.integers(low=low, high=high, size=shapeList[0])
@@ -1182,12 +1212,20 @@ class TosaTensorValuesGen:
a_arr = a_arr // 2
b_arr = b_arr // 2
- tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
- )
+ if dtypeList[0] == DType.SHAPE:
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
+ )
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
+ )
+ else:
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
@@ -1199,9 +1237,16 @@ class TosaTensorValuesGen:
if testGen.args.num_const_inputs_concat == 0:
count = len(shapeList)
- shapeList = TosaTensorGen.tgConcatConstInput(
- testGen, shapeList, argsDict["axis"], error_name
- )
+ op = testGen.TOSA_OP_LIST[opName]
+ if op["op"] == Op.CONCAT_SHAPE:
+ # Set the axis to 0
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ testGen, shapeList, 0, error_name
+ )
+ else:
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ testGen, shapeList, argsDict["axis"], error_name
+ )
# Override default pCount/cCount for operator
argsDict["p_count"] = count
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7f719ee..5874123 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, ARM Limited.
+# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import math
@@ -595,6 +595,10 @@ class TosaErrorValidator:
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
+ elif op["op"] in {Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE}:
+ if output_dtype != DType.SHAPE:
+ error_result = True
+
else:
if output_dtype != input_dtype:
error_result = True
@@ -1109,7 +1113,13 @@ class TosaErrorValidator:
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- if len(input1_shape) == len(input2_shape) == len(input3_shape):
+ op = kwargs["op"]
+ if op["op"] in (Op.ADD_SHAPE, Op.SUB_SHAPE, Op.MUL_SHAPE, Op.DIV_SHAPE):
+ output_shape = kwargs["result_tensors"][0].shape
+ if input1_shape != output_shape:
+ error_result = True
+
+ elif len(input1_shape) == len(input2_shape) == len(input3_shape):
calculated_shape = TosaErrorValidator.calculateBroadcastShape(
input3_shape,
TosaErrorValidator.calculateBroadcastShape(
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 159ee83..b9352ac 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -167,9 +167,10 @@ class TosaTestGen:
rng = (-128, 128)
elif dtype == DType.INT16:
rng = (-32768, 32768)
- elif dtype in (DType.INT32, DType.SHAPE):
- # restricting too large value for SHAPE
+ elif dtype == DType.INT32:
rng = (-(1 << 31), (1 << 31))
+ elif dtype == DType.SHAPE:
+ rng = tuple(self.args.tensor_shape_range[0:2])
elif dtype == DType.INT48:
rng = (-(1 << 47), (1 << 47))
else:
@@ -190,7 +191,7 @@ class TosaTestGen:
if dtype == DType.BOOL:
return np.bool_(self.rng.choice(a=[False, True], size=shape))
- elif dtype == DType.INT48:
+ elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.FP16, DType.BF16, DType.FP32):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
@@ -1399,7 +1400,10 @@ class TosaTestGen:
def build_concat(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- axis = args_dict["axis"]
+ if op["op"] == Op.CONCAT_SHAPE:
+ axis = 0
+ else:
+ axis = args_dict["axis"]
if error_name != ErrorIf.WrongInputType:
assert type(axis) == int
@@ -1438,9 +1442,12 @@ class TosaTestGen:
):
return None
- attr = ts.TosaSerializerAttribute()
- attr.AxisAttribute(axis)
-
+ if op["op"] == Op.CONCAT:
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+ else:
+ assert op["op"] == Op.CONCAT_SHAPE
+ attr = None
self.ser.addOperator(op["op"], input_list, output_list, attr)
compliance = self.tensorComplianceMetaData(
@@ -2512,6 +2519,52 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_names, output_names, attr)
return results
+ def build_shape_op(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 2
+ a, b = inputs
+
+ result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tensor.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input1=a,
+ input2=b,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_shape=result_tensor.shape,
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ ):
+ return None
+
+ self.ser.addOperator(
+ op["op"],
+ input_list,
+ output_list,
+ )
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
+
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
@@ -2725,12 +2778,12 @@ class TosaTestGen:
if isinstance(dtype_or_dtypeList, list):
dtypeList = dtype_or_dtypeList
- elif op["op"] == Op.CONCAT:
+ elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
dtypeList = [dtype_or_dtypeList] * len(shapeList)
else:
dtypeList = [dtype_or_dtypeList] * (num_operands)
- if op["op"] != Op.CONCAT:
+ if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
assert (
len(shapeList) == num_operands
), "shapeList length {} must match number of operands {}".format(
@@ -4605,6 +4658,78 @@ class TosaTestGen:
TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
+ # Shape
+ "add_shape": {
+ "op": Op.ADD_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgAddSub,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "sub_shape": {
+ "op": Op.SUB_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgAddSub,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "mul_shape": {
+ "op": Op.MUL_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgMul,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "div_shape": {
+ "op": Op.DIV_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgIntDiv,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "concat_shape": {
+ "op": Op.CONCAT_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_concat,
+ TosaTensorGen.tgConcat,
+ TosaTensorValuesGen.tvgConcat,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (),
+ },
+ "const_shape": {
+ "op": Op.CONST_SHAPE,
+ "operands": (0, 1),
+ "build_fcn": (
+ build_const,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
+ "types": [DType.SHAPE],
+ },
}
@@ -5524,3 +5649,24 @@ class OutputShaper:
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs
+
+ @staticmethod
+ def addShapeOp(ser, rng, a, b, error_name=None):
+ if error_name != ErrorIf.RankMismatch:
+ assert len(a.shape) == len(b.shape)
+ assert a.dtype == b.dtype
+
+ shape = []
+ for i in range(len(a.shape)):
+ shape.append(a.shape[i])
+
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = DType.SHAPE
+ return ser.addOutput(shape, outputDType)