aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_builder.py
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-01-23 09:40:37 -0800
committerTatWai Chong <tatwai.chong@arm.com>2024-01-23 14:34:14 -0800
commitbef907a0a1161df3cfc51c6401ffa061e10f430b (patch)
treea184ce7831d50dd8a595641bdabd2062f31e8364 /verif/frameworks/test_builder.py
parent6a46b251062dcd42bc9fa2bc9effad407747f64f (diff)
downloadreference_model-bef907a0a1161df3cfc51c6401ffa061e10f430b.tar.gz
Add dynamic space_to_batch to the framework test
Also fix the dimension mask out logic that only set batch dimension to unknown but others won't. Change-Id: I9e1d2c3bb1d24cba1242103aa2c7609ef0c2c0b3 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'verif/frameworks/test_builder.py')
-rw-r--r--verif/frameworks/test_builder.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 744dc38..4bf8070 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import tensorflow as tf
@@ -1059,6 +1059,26 @@ class TBuilder:
a, self.block_shape, self.padding, name=self.result_name
)
+ class DynamicSpaceToBatch:
+ def __init__(self, block_shape, padding, dynamic_input_shape, name):
+ self.result_name = name
+
+ dynamic_input_shape_with_batch = list(dynamic_input_shape)
+ dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:]
+ dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch)
+
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape_no_batch),
+ tf.keras.layers.Lambda(
+ lambda x: tf.space_to_batch(x, block_shape, padding, name=None)
+ ),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class BatchToSpace:
def __init__(self, block_shape, cropping, name):
self.block_shape = block_shape