diff options
author | Luke Hutton <luke.hutton@arm.com> | 2023-01-10 14:50:31 +0000 |
---|---|---|
committer | Luke Hutton <luke.hutton@arm.com> | 2023-01-24 13:40:17 +0000 |
commit | 261b7b62b959a6c7312d810d9152069fdff69f3e (patch) | |
tree | 2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif | |
parent | c253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff) | |
download | reference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz |
Add RFFT2d to the reference model
Includes:
* RFFT2d reference implementation
* TFLite framework tests
* Basic TOSA tests
* Serialization submodule upgrade with support for FFT/RFFT
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif')
-rw-r--r-- | verif/frameworks/arg_gen.py | 30 | ||||
-rw-r--r-- | verif/frameworks/tensor_gen.py | 9 | ||||
-rw-r--r-- | verif/frameworks/test_builder.py | 8 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_compiler_runner.py | 16 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 10 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 37 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 103 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 130 |
8 files changed, 282 insertions, 61 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index d81c3dd..61a1de0 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -1,5 +1,7 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np @@ -851,3 +853,29 @@ class ArgGen: else: axes.append(["_axis_m{}".format(-i), [i]]) return axes + + def agRFFT2d(op, shape, rng): + args = [] + + # Must be rank 3 input tensor + if len(shape) != 3: + return [] + + # Check rfft2d with enforced fft_length + for fft_length_h in [2, 32]: + for fft_length_w in [2, 8, 16]: + fft_length = [fft_length_h, fft_length_w] + args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]]) + + # Check rfft2d with no fft_length provided (fft_length=None). + # In this case, the height and width of the input should be + # used for the calculation. Therefore, we need to check that + # the input shape is already a power of two. + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + height, width = shape[1:3] + if is_power_of_two(height) and is_power_of_two(width): + args.append(["_fft_length_None", [None]]) + + return args diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 767989e..c534a58 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -274,3 +274,12 @@ class TGen: ) return tf_placeholders, tf_consts + + @staticmethod + def tgRFFT2d(op, shape, dtype, rng): + # Require rank 3 shape + if len(shape) != 3: + return [], [] + + tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))] + return tf_placeholders, [] diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 8870f41..6e7b6a5 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1243,3 +1243,11 @@ class TBuilder: def eval(self, a): return self.dense(a) + + class RFFT2d: + def __init__(self, fft_length, name): + self.fft_length = fft_length + self.result_name = name + + def eval(self, a): + return tf.signal.rfft2d(a, self.fft_length, name=self.result_name) diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py index 3597f2a..c55864a 100755 --- a/verif/frameworks/tosa_verif_framework_compiler_runner.py +++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import glob @@ -483,6 +483,20 @@ def run_test(args, test, framework): except KeyError: assert 0, "fail to load tflite result numpy" + # TOSA has no notion of complex datatypes, it represents complex values using two + # fp32 output tensors representing real and imaginary values. When legalizing + # complex operations from frameworks, these two output tensors are combined into + # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values + # represents the real and imaginary parts of a complex value. This is completed + # by inserting reshape and concatenate TOSA operations during the legalization to + # maintain a one-to-one correspondance with framework outputs, thus simplifying + # legalization. Here tf_result should also match this format before being + # compared to the ref model output. + if tf_result.dtype == np.complex64: + ifm_shape = tf_result.shape + (2,) + tf_result = tf_result.view(np.float32) + tf_result = tf_result.reshape(ifm_shape) + # Generate test descriptor per flatbuffer generation # Input .npy will be shared across different frameworks # Output .npy will be generated in its corresponding flatbuffer diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 5b8856d..36ddda5 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import os @@ -839,6 +839,13 @@ TF_OP_LIST = { ] }, }, + "rfft2d": { + "operands": (1, 0), + "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), + "types": { + "tflite": TYPE_F, + }, + }, } # Shapes to be tested; default can be overwritten @@ -847,6 +854,7 @@ shape_list = [ (64,), (14, 19), (13, 21, 3), + (1, 8, 16), (1, 4, 4, 4), (1, 8, 4, 17), (1, 4, 8, 19), diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 4e15b06..fed91f6 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools import math @@ -417,6 +417,41 @@ class TosaTensorGen: return [ifm_shape, filter_shape, bias_shape] @staticmethod + def tgRFFT2d(testGen, op, rank, error_name=None): + pl, const = op["operands"] + + if error_name != ErrorIf.WrongRank: + assert rank == 3 + assert pl == 1 and const == 0 + + # IFM dimensions are NHW + ifm_shape = testGen.makeShape(rank) + + # Select nearest lower power of two from input height and width + ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) + ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) + + # Generate an invalid kernel that is not a power of two + if error_name == ErrorIf.KernelNotPowerOfTwo: + # We must increment by 2 if current size is 1 + inc_h = 2 if ifm_shape[1] == 1 else 1 + inc_w = 2 if ifm_shape[2] == 1 else 1 + inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] + selected_inc = testGen.rng.choice(inc_choices) + ifm_shape[1] += selected_inc[0] + ifm_shape[2] += selected_inc[1] + + # Constrict the batch size + if testGen.args.max_batch_size: + ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 + + return [ifm_shape] + + @staticmethod def tgFullyConnected(testGen, op, rank, error_name=None): pl, const = op["operands"] diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index c9d35c7..40c5d13 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1,5 +1,7 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import product @@ -76,6 +78,7 @@ class ErrorIf(object): CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" + KernelNotPowerOfTwo = "KernelNotPowerOfTwo" class TosaErrorIfArgGen: @@ -548,6 +551,10 @@ class TosaErrorValidator: ): error_result = True + elif op["op"] == Op.RFFT2D: + if not all([ty == input_dtype for ty in output_dtype]): + error_result = True + elif op["op"] in { Op.CONV2D, Op.CONV3D, @@ -665,9 +672,13 @@ class TosaErrorValidator: error_reason = "Op output list does not match expected output" if check: + op = kwargs["op"] output_list = kwargs["output_list"] - # Note this will be incorrect if an operator returns more than one output - if len(output_list) != 1: + expected_length = 1 + if op["op"] == Op.RFFT2D: + expected_length = 2 + + if len(output_list) != expected_length: error_result = True info_dict = { @@ -711,7 +722,7 @@ class TosaErrorValidator: @staticmethod def evBatchMismatch(check=False, **kwargs): error_name = ErrorIf.BatchMismatch - param_reqs = {"rank": [4, 4], "dtype": None, "shape": None} + param_reqs = {"rank": None, "dtype": None, "shape": None} error_result = False error_reason = "Input batch size not equal to output batch size" @@ -722,12 +733,15 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = ( + output.shape + ) # Note batch is expected to be the first dim + if (len(input_shape) in rank_range) and ( + input_shape[0] != output_shape[0] + ): + error_result = True info_dict = { "error_name": error_name, @@ -751,11 +765,12 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = output.shape # Note this is just (N, OH, OW, C) + if (len(input_shape) in rank_range) and ( + input_shape[3] != output_shape[3] + ): + error_result = True info_dict = { "error_name": error_name, @@ -1044,13 +1059,15 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - if ( - (len(input1_shape) != len(output_shape)) - or (len(input2_shape) != len(output_shape)) - or (len(input3_shape) != len(output_shape)) - ): - error_result = True + + for output in kwargs["result_tensors"]: + output_shape = output.shape + if ( + (len(input1_shape) != len(output_shape)) + or (len(input2_shape) != len(output_shape)) + or (len(input3_shape) != len(output_shape)) + ): + error_result = True info_dict = { "error_name": error_name, @@ -1074,16 +1091,18 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - for i in range( - min(len(input1_shape), len(input2_shape), len(input3_shape)) - ): - if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) - or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) - or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + + for output in kwargs["result_tensors"]: + output_shape = output.shape + for i in range( + min(len(input1_shape), len(input2_shape), len(input3_shape)) ): - error_result = True + if ( + (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) + or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) + or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + ): + error_result = True info_dict = { "error_name": error_name, @@ -2392,6 +2411,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evKernelNotPowerOfTwo(check=False, **kwargs): + error_name = ErrorIf.KernelNotPowerOfTwo + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "kernel height and/or width not a power of two" + + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + if check: + shape = kwargs["input_shape"] + if len(shape) == 3: + valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2]) + error_result = not valid_kernel + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c29763b..fddf942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -255,7 +255,7 @@ class TosaTestGen: input_dtype=a.dtype, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -293,7 +293,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -333,7 +333,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -378,7 +378,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -414,7 +414,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -448,7 +448,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -487,7 +487,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -523,7 +523,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -582,7 +582,7 @@ class TosaTestGen: stride=stride, pad=pad, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -938,7 +938,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -980,7 +980,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1016,7 +1016,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1064,7 +1064,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1122,7 +1122,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1153,7 +1153,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1199,7 +1199,7 @@ class TosaTestGen: input_dtype=a[0].dtype, output_dtype=result_tens.dtype, inputs=a, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1250,7 +1250,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, pad=padding, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1283,7 +1283,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1318,7 +1318,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1356,7 +1356,7 @@ class TosaTestGen: perms=perms, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1391,7 +1391,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, start=start, size=size, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1425,7 +1425,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1474,7 +1474,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1519,7 +1519,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values_in.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1580,7 +1580,7 @@ class TosaTestGen: border=border, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -1628,7 +1628,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=val.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1774,7 +1774,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -2083,6 +2083,38 @@ class TosaTestGen: return acc_out + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): + results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) + + input_names = [val.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape=val.shape, + input_dtype=val.dtype, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + self.ser.addOperator(op["op"], input_names, output_names) + return results + def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): @@ -3897,6 +3929,27 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "rfft2d": { + "op": Op.RFFT2D, + "operands": (1, 0), + "rank": (3, 3), + "build_fcn": ( + build_rfft2d, + TosaTensorGen.tgRFFT2d, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agNone, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + ), + }, } @@ -4717,3 +4770,26 @@ class OutputShaper: out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) + + @staticmethod + def rfft2dOp(serializer, rng, value, error_name=None): + outputs = [] + + input_shape = value.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1] + + output_dtype = value.dtype + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + incorrect_batch = input_shape[0] + rng.integers(1, 10) + output_shape = [incorrect_batch, *input_shape[1:]] + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs |