aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 744dc38..4bf8070 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1059,6 +1059,26 @@ class TBuilder:
a, self.block_shape, self.padding, name=self.result_name
)
+ class DynamicSpaceToBatch:
+ def __init__(self, block_shape, padding, dynamic_input_shape, name):
+ self.result_name = name
+
+ dynamic_input_shape_with_batch = list(dynamic_input_shape)
+ dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:]
+ dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch)
+
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape_no_batch),
+ tf.keras.layers.Lambda(
+ lambda x: tf.space_to_batch(x, block_shape, padding, name=None)
+ ),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class BatchToSpace:
def __init__(self, block_shape, cropping, name):
self.block_shape = block_shape