aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-01-09 00:34:40 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-24 21:01:20 +0000
commit74342e522ec61e85fde64fe801da9e750b3e2d86 (patch)
tree473a02dcbccb5dcf7aee009682454aa2b914bb64 /verif/generator/tosa_test_gen.py
parent1f75232dab1b50162ebc420e6e076edeb8a58341 (diff)
downloadreference_model-74342e522ec61e85fde64fe801da9e750b3e2d86.tar.gz
Add conformance testing for shape operators
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Ie80570146601c470a3be7c04a9d6e1016a7c547c
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py164
1 files changed, 155 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 159ee83..b9352ac 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -167,9 +167,10 @@ class TosaTestGen:
rng = (-128, 128)
elif dtype == DType.INT16:
rng = (-32768, 32768)
- elif dtype in (DType.INT32, DType.SHAPE):
- # restricting too large value for SHAPE
+ elif dtype == DType.INT32:
rng = (-(1 << 31), (1 << 31))
+ elif dtype == DType.SHAPE:
+ rng = tuple(self.args.tensor_shape_range[0:2])
elif dtype == DType.INT48:
rng = (-(1 << 47), (1 << 47))
else:
@@ -190,7 +191,7 @@ class TosaTestGen:
if dtype == DType.BOOL:
return np.bool_(self.rng.choice(a=[False, True], size=shape))
- elif dtype == DType.INT48:
+ elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.FP16, DType.BF16, DType.FP32):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
@@ -1399,7 +1400,10 @@ class TosaTestGen:
def build_concat(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- axis = args_dict["axis"]
+ if op["op"] == Op.CONCAT_SHAPE:
+ axis = 0
+ else:
+ axis = args_dict["axis"]
if error_name != ErrorIf.WrongInputType:
assert type(axis) == int
@@ -1438,9 +1442,12 @@ class TosaTestGen:
):
return None
- attr = ts.TosaSerializerAttribute()
- attr.AxisAttribute(axis)
-
+ if op["op"] == Op.CONCAT:
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+ else:
+ assert op["op"] == Op.CONCAT_SHAPE
+ attr = None
self.ser.addOperator(op["op"], input_list, output_list, attr)
compliance = self.tensorComplianceMetaData(
@@ -2512,6 +2519,52 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_names, output_names, attr)
return results
+ def build_shape_op(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 2
+ a, b = inputs
+
+ result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tensor.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input1=a,
+ input2=b,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_shape=result_tensor.shape,
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ ):
+ return None
+
+ self.ser.addOperator(
+ op["op"],
+ input_list,
+ output_list,
+ )
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
+
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
@@ -2725,12 +2778,12 @@ class TosaTestGen:
if isinstance(dtype_or_dtypeList, list):
dtypeList = dtype_or_dtypeList
- elif op["op"] == Op.CONCAT:
+ elif op["op"] in (Op.CONCAT, Op.CONCAT_SHAPE):
dtypeList = [dtype_or_dtypeList] * len(shapeList)
else:
dtypeList = [dtype_or_dtypeList] * (num_operands)
- if op["op"] != Op.CONCAT:
+ if op["op"] not in (Op.CONCAT, Op.CONCAT_SHAPE):
assert (
len(shapeList) == num_operands
), "shapeList length {} must match number of operands {}".format(
@@ -4605,6 +4658,78 @@ class TosaTestGen:
TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
+ # Shape
+ "add_shape": {
+ "op": Op.ADD_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgAddSub,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "sub_shape": {
+ "op": Op.SUB_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgAddSub,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "mul_shape": {
+ "op": Op.MUL_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgMul,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "div_shape": {
+ "op": Op.DIV_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_shape_op,
+ TosaTensorGen.tgShape,
+ TosaTensorValuesGen.tvgIntDiv,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (TosaErrorValidator.evDimensionMismatch,),
+ },
+ "concat_shape": {
+ "op": Op.CONCAT_SHAPE,
+ "operands": (2, 0),
+ "build_fcn": (
+ build_concat,
+ TosaTensorGen.tgConcat,
+ TosaTensorValuesGen.tvgConcat,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.SHAPE],
+ "error_if_validators": (),
+ },
+ "const_shape": {
+ "op": Op.CONST_SHAPE,
+ "operands": (0, 1),
+ "build_fcn": (
+ build_const,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
+ "types": [DType.SHAPE],
+ },
}
@@ -5524,3 +5649,24 @@ class OutputShaper:
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))
return outputs
+
+ @staticmethod
+ def addShapeOp(ser, rng, a, b, error_name=None):
+ if error_name != ErrorIf.RankMismatch:
+ assert len(a.shape) == len(b.shape)
+ assert a.dtype == b.dtype
+
+ shape = []
+ for i in range(len(a.shape)):
+ shape.append(a.shape[i])
+
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = gtu.get_wrong_output_type(a.dtype)
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = DType.SHAPE
+ return ser.addOutput(shape, outputDType)