diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/frameworks/test_builder.py | 22 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 26 |
2 files changed, 45 insertions, 3 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 744dc38..4bf8070 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -1059,6 +1059,26 @@ class TBuilder: a, self.block_shape, self.padding, name=self.result_name ) + class DynamicSpaceToBatch: + def __init__(self, block_shape, padding, dynamic_input_shape, name): + self.result_name = name + + dynamic_input_shape_with_batch = list(dynamic_input_shape) + dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:] + dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch) + + self.model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=dynamic_input_shape_no_batch), + tf.keras.layers.Lambda( + lambda x: tf.space_to_batch(x, block_shape, padding, name=None) + ), + ] + ) + + def eval(self, a): + return self.model(a) + class BatchToSpace: def __init__(self, block_shape, cropping, name): self.block_shape = block_shape diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 538f314..2a7d484 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -784,6 +784,25 @@ TF_OP_LIST = { "build_fcn": (TBuilder.SpaceToBatch, TGen.tgBasic, ArgGen.agSpaceToBatch), "types": TYPE_F, }, + "dynamic_space_to_batch": { + "operands": (1, 0), + "build_fcn": ( + TBuilder.DynamicSpaceToBatch, + TGen.tgBasic, + ArgGen.agSpaceToBatch, + ), + "types": TYPE_F, + "custom_shapes": { + "custom_shape_only": True, + "shape_list": [(13, 21, 3)], + }, + "dynamic_shape_dim": [ + ( + 0, + 1, + ), + ], + }, "batch_to_space": { "operands": (1, 0), "build_fcn": (TBuilder.BatchToSpace, TGen.tgBasic, ArgGen.agBatchToSpace), @@ -1174,9 +1193,12 @@ def run_unit_test( try: dynamic_shape_dim_tuples = op["dynamic_shape_dim"] dim_tuple = dynamic_shape_dim_tuples[idx] - dim = dim_tuple[0] input_shape = list(input_shape) - input_shape[dim] = None + + # Set the dimensions of input that are listed in the builder profile to unknown. + for dim in dim_tuple: + input_shape[dim] = None + # When any dimension size is unknown, mark the placeholder as dynamic type. placeholder_dynamic = True |