diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-11-02 17:16:25 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-11-10 16:35:17 +0000 |
commit | aee62afba99a74f772b97356fd4c18f3fdf37073 (patch) | |
tree | 47c7f0bc619ae44332f58c786b52460b40bc1a97 /verif | |
parent | bfc53031803338d9f0866f88f1d2deffd4928bcc (diff) | |
download | reference_model-aee62afba99a74f772b97356fd4c18f3fdf37073.tar.gz |
Main Compliance testing for FULLY_CONNECTED
Updated shapes to meet MIN_DOT_PRODUCTS.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I82297917c009b3120306f8a9bb965209d109ef8d
Diffstat (limited to 'verif')
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 19 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 24 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 32 |
3 files changed, 55 insertions, 20 deletions
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index faccf75..bdfc281 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -995,6 +995,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "negative_dim_range": "1,10", @@ -1007,13 +1008,19 @@ "--target-dtype", "bf16", "--fp-values-range", - "-2.0,2.0" + "-max,max", + "--tensor-dim-range", + "15,64" ], [ "--target-dtype", "fp32", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "10,15", "--target-shape", - "1,296", + "100,296", "--target-shape", "65540,2" ], @@ -1025,11 +1032,13 @@ "--target-dtype", "bf16", "--fp-values-range", - "-2.0,2.0", + "-max,max", + "--tensor-dim-range", + "35,64", "--target-shape", - "3,16", + "30,16", "--target-shape", - "1,23" + "100,23" ] ] } diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 4014656..1f54851 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1541,9 +1541,7 @@ class TosaArgGen: @staticmethod def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): - assert isinstance(dtypes, list) or isinstance( - dtypes, tuple - ), f"{dtypes} unexpected" + assert isinstance(dtypes, (list, tuple)), f"{dtypes} unexpected" input_dtype = dtypes[0] if error_name == ErrorIf.WrongOutputType: @@ -1554,7 +1552,25 @@ class TosaArgGen: else: accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) - return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] + # Set up compliance info + args_dict = { + "acc_type": accum_dtype, + "ks": int(shapeList[0][1]), # Set KS = IC, from input A (N,IC) + "dot_products": gtu.product((shapeList[0][0], shapeList[1][0])), + "shape": shapeList[0], + } + + arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)] + + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + input_dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) + return arg_list @staticmethod def agMatMul(testGen, opName, shapeList, dtype, error_name=None): diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 3180cf5..d1fe11d 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1086,21 +1086,23 @@ class TosaTestGen: def build_fully_connected( self, op, - ifm, - filter, - bias, - accum_dtype, + inputs, + args_dict, validator_fcns=None, error_name=None, qinfo=None, ): - result_tens = OutputShaper.fullyConnectedOp( + assert len(inputs) == 3 + ifm, filter, bias = inputs + accum_dtype = args_dict["acc_type"] + + result_tensor = OutputShaper.fullyConnectedOp( self.ser, self.rng, ifm, filter, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. input_list = [ifm.name, filter.name, bias.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -1115,10 +1117,10 @@ class TosaTestGen: input_shape=ifm.shape, input_dtype=ifm.dtype, weight_dtype=filter.dtype, - output_shape=result_tens.shape, - output_dtype=result_tens.dtype, + output_shape=result_tensor.shape, + output_dtype=result_tensor.dtype, qinfo=qinfo, - result_tensors=[result_tens], + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1130,7 +1132,12 @@ class TosaTestGen: attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, ifm.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_matmul( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -3077,7 +3084,7 @@ class TosaTestGen: "build_fcn": ( build_fully_connected, TosaTensorGen.tgFullyConnected, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agFullyConnected, ), "qgen": TosaQuantGen.qgConv, @@ -3091,6 +3098,9 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, }, "matmul": { "op": Op.MATMUL, |