diff options
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r-- | verif/frameworks/test_builder.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 744dc38..4bf8070 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import numpy as np import tensorflow as tf @@ -1059,6 +1059,26 @@ class TBuilder: a, self.block_shape, self.padding, name=self.result_name ) + class DynamicSpaceToBatch: + def __init__(self, block_shape, padding, dynamic_input_shape, name): + self.result_name = name + + dynamic_input_shape_with_batch = list(dynamic_input_shape) + dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:] + dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch) + + self.model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=dynamic_input_shape_no_batch), + tf.keras.layers.Lambda( + lambda x: tf.space_to_batch(x, block_shape, padding, name=None) + ), + ] + ) + + def eval(self, a): + return self.model(a) + class BatchToSpace: def __init__(self, block_shape, cropping, name): self.block_shape = block_shape |