diff options
author | Luke Hutton <luke.hutton@arm.com> | 2023-01-10 14:50:31 +0000 |
---|---|---|
committer | Luke Hutton <luke.hutton@arm.com> | 2023-01-24 13:40:17 +0000 |
commit | 261b7b62b959a6c7312d810d9152069fdff69f3e (patch) | |
tree | 2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif/generator | |
parent | c253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff) | |
download | reference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz |
Add RFFT2d to the reference model
Includes:
* RFFT2d reference implementation
* TFLite framework tests
* Basic TOSA tests
* Serialization submodule upgrade with support for FFT/RFFT
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif/generator')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 37 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 103 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 130 |
3 files changed, 212 insertions, 58 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 4e15b06..fed91f6 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools import math @@ -417,6 +417,41 @@ class TosaTensorGen: return [ifm_shape, filter_shape, bias_shape] @staticmethod + def tgRFFT2d(testGen, op, rank, error_name=None): + pl, const = op["operands"] + + if error_name != ErrorIf.WrongRank: + assert rank == 3 + assert pl == 1 and const == 0 + + # IFM dimensions are NHW + ifm_shape = testGen.makeShape(rank) + + # Select nearest lower power of two from input height and width + ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) + ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) + + # Generate an invalid kernel that is not a power of two + if error_name == ErrorIf.KernelNotPowerOfTwo: + # We must increment by 2 if current size is 1 + inc_h = 2 if ifm_shape[1] == 1 else 1 + inc_w = 2 if ifm_shape[2] == 1 else 1 + inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] + selected_inc = testGen.rng.choice(inc_choices) + ifm_shape[1] += selected_inc[0] + ifm_shape[2] += selected_inc[1] + + # Constrict the batch size + if testGen.args.max_batch_size: + ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 + + return [ifm_shape] + + @staticmethod def tgFullyConnected(testGen, op, rank, error_name=None): pl, const = op["operands"] diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index c9d35c7..40c5d13 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1,5 +1,7 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import product @@ -76,6 +78,7 @@ class ErrorIf(object): CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" + KernelNotPowerOfTwo = "KernelNotPowerOfTwo" class TosaErrorIfArgGen: @@ -548,6 +551,10 @@ class TosaErrorValidator: ): error_result = True + elif op["op"] == Op.RFFT2D: + if not all([ty == input_dtype for ty in output_dtype]): + error_result = True + elif op["op"] in { Op.CONV2D, Op.CONV3D, @@ -665,9 +672,13 @@ class TosaErrorValidator: error_reason = "Op output list does not match expected output" if check: + op = kwargs["op"] output_list = kwargs["output_list"] - # Note this will be incorrect if an operator returns more than one output - if len(output_list) != 1: + expected_length = 1 + if op["op"] == Op.RFFT2D: + expected_length = 2 + + if len(output_list) != expected_length: error_result = True info_dict = { @@ -711,7 +722,7 @@ class TosaErrorValidator: @staticmethod def evBatchMismatch(check=False, **kwargs): error_name = ErrorIf.BatchMismatch - param_reqs = {"rank": [4, 4], "dtype": None, "shape": None} + param_reqs = {"rank": None, "dtype": None, "shape": None} error_result = False error_reason = "Input batch size not equal to output batch size" @@ -722,12 +733,15 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = ( + output.shape + ) # Note batch is expected to be the first dim + if (len(input_shape) in rank_range) and ( + input_shape[0] != output_shape[0] + ): + error_result = True info_dict = { "error_name": error_name, @@ -751,11 +765,12 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = output.shape # Note this is just (N, OH, OW, C) + if (len(input_shape) in rank_range) and ( + input_shape[3] != output_shape[3] + ): + error_result = True info_dict = { "error_name": error_name, @@ -1044,13 +1059,15 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - if ( - (len(input1_shape) != len(output_shape)) - or (len(input2_shape) != len(output_shape)) - or (len(input3_shape) != len(output_shape)) - ): - error_result = True + + for output in kwargs["result_tensors"]: + output_shape = output.shape + if ( + (len(input1_shape) != len(output_shape)) + or (len(input2_shape) != len(output_shape)) + or (len(input3_shape) != len(output_shape)) + ): + error_result = True info_dict = { "error_name": error_name, @@ -1074,16 +1091,18 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - for i in range( - min(len(input1_shape), len(input2_shape), len(input3_shape)) - ): - if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) - or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) - or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + + for output in kwargs["result_tensors"]: + output_shape = output.shape + for i in range( + min(len(input1_shape), len(input2_shape), len(input3_shape)) ): - error_result = True + if ( + (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) + or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) + or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + ): + error_result = True info_dict = { "error_name": error_name, @@ -2392,6 +2411,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evKernelNotPowerOfTwo(check=False, **kwargs): + error_name = ErrorIf.KernelNotPowerOfTwo + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "kernel height and/or width not a power of two" + + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + if check: + shape = kwargs["input_shape"] + if len(shape) == 3: + valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2]) + error_result = not valid_kernel + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c29763b..fddf942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -255,7 +255,7 @@ class TosaTestGen: input_dtype=a.dtype, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -293,7 +293,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -333,7 +333,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -378,7 +378,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -414,7 +414,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -448,7 +448,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -487,7 +487,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -523,7 +523,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -582,7 +582,7 @@ class TosaTestGen: stride=stride, pad=pad, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -938,7 +938,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -980,7 +980,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1016,7 +1016,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1064,7 +1064,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1122,7 +1122,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1153,7 +1153,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1199,7 +1199,7 @@ class TosaTestGen: input_dtype=a[0].dtype, output_dtype=result_tens.dtype, inputs=a, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1250,7 +1250,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, pad=padding, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1283,7 +1283,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1318,7 +1318,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1356,7 +1356,7 @@ class TosaTestGen: perms=perms, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1391,7 +1391,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, start=start, size=size, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1425,7 +1425,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1474,7 +1474,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1519,7 +1519,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values_in.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1580,7 +1580,7 @@ class TosaTestGen: border=border, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -1628,7 +1628,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=val.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1774,7 +1774,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -2083,6 +2083,38 @@ class TosaTestGen: return acc_out + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): + results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) + + input_names = [val.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape=val.shape, + input_dtype=val.dtype, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + self.ser.addOperator(op["op"], input_names, output_names) + return results + def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): @@ -3897,6 +3929,27 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "rfft2d": { + "op": Op.RFFT2D, + "operands": (1, 0), + "rank": (3, 3), + "build_fcn": ( + build_rfft2d, + TosaTensorGen.tgRFFT2d, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agNone, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + ), + }, } @@ -4717,3 +4770,26 @@ class OutputShaper: out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) + + @staticmethod + def rfft2dOp(serializer, rng, value, error_name=None): + outputs = [] + + input_shape = value.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1] + + output_dtype = value.dtype + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + incorrect_batch = input_shape[0] + rng.integers(1, 10) + output_shape = [incorrect_batch, *input_shape[1:]] + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs |