aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-12-05 00:53:26 +0000
committerJerry-Ge <jerry.ge@arm.com>2023-12-15 17:33:02 +0000
commit28811d9abfcbfa717cab97b2dfbcd409dda36dc7 (patch)
tree066be20f0b7c07d712c804ce8b98cb8a6f8d19f2
parente1e611dc446fe597509b4b777bb474c059a1c0b6 (diff)
downloadreference_model-28811d9abfcbfa717cab97b2dfbcd409dda36dc7.tar.gz
Add basic framework test cases for dynamic shapes
- Added a basic infrastructure for allowing generate network with dynamic_shapes - Added tests cases for - batch_to_space - depth_to_space, space_to_depth - linear Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ie3f13231a74485df64b852f554cfe65e995f0d03
-rw-r--r--verif/frameworks/test_builder.py77
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py97
2 files changed, 172 insertions, 2 deletions
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 7b20cef..744dc38 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -886,6 +886,19 @@ class TBuilder:
def eval(self, a):
return tf.nn.log_softmax(a, name=self.result_name)
+ class DynamicLinear:
+ def __init__(self, dynamic_input_shape, name):
+ self.result_name = name
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape),
+ tf.keras.layers.Dense(units=5),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class MatMul:
def __init__(self, name):
self.result_name = name
@@ -1064,6 +1077,26 @@ class TBuilder:
transpose_op, self.block_shape, self.cropping, name=self.result_name
)
+ class DynamicBatchToSpace:
+ def __init__(self, block_shape, cropping, dynamic_input_shape, name):
+ self.result_name = name
+
+ dynamic_input_shape_with_batch = list(dynamic_input_shape)
+ dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:]
+ dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch)
+
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape_no_batch),
+ tf.keras.layers.Lambda(
+ lambda x: tf.batch_to_space(x, block_shape, cropping, name=None)
+ ),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class SpaceToDepth:
def __init__(self, block_shape, name):
self.block_shape = block_shape
@@ -1072,6 +1105,28 @@ class TBuilder:
def eval(self, a):
return tf.nn.space_to_depth(a, self.block_shape, name=self.result_name)
+ class DynamicSpaceToDepth:
+ def __init__(self, dynamic_input_shape, name):
+ self.result_name = name
+
+ dynamic_input_shape_with_batch = list(dynamic_input_shape)
+ dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:]
+ dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch)
+
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape_no_batch),
+ tf.keras.layers.Lambda(
+ lambda x: tf.nn.space_to_depth(
+ x, 2, data_format="NHWC", name=None
+ )
+ ),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class DepthToSpace:
def __init__(self, block_shape, name):
self.block_shape = block_shape
@@ -1080,6 +1135,28 @@ class TBuilder:
def eval(self, a):
return tf.nn.depth_to_space(a, self.block_shape, name=self.result_name)
+ class DynamicDepthToSpace:
+ def __init__(self, dynamic_input_shape, name):
+ self.result_name = name
+
+ dynamic_input_shape_with_batch = list(dynamic_input_shape)
+ dynamic_input_shape_no_batch = dynamic_input_shape_with_batch[1:]
+ dynamic_input_shape_no_batch = tuple(dynamic_input_shape_no_batch)
+
+ self.model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input(shape=dynamic_input_shape_no_batch),
+ tf.keras.layers.Lambda(
+ lambda x: tf.nn.depth_to_space(
+ x, 2, data_format="NHWC", name=None
+ )
+ ),
+ ]
+ )
+
+ def eval(self, a):
+ return self.model(a)
+
class OneHot:
def __init__(self, depth, axis, name):
self.depth = depth
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 9d666ab..1b187ba 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -709,6 +709,24 @@ TF_OP_LIST = {
"build_fcn": (TBuilder.LogSoftmax, TGen.tgBasic, ArgGen.agNone),
"types": TYPE_F,
},
+ "dynamic_linear": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.DynamicLinear, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tf": [],
+ "tflite": list(TYPE_F),
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(14, 19)],
+ },
+ # number of operands of tuples which spcifies which dim to set to None
+ # In this case, we have 1 input. So we have 1 tuple
+ # We're setting the first input's first dim to None
+ "dynamic_shape_dim": [
+ (0,),
+ ],
+ },
"matmul": {
"operands": (2, 0),
"build_fcn": (TBuilder.MatMul, TGen.tgMatmul, ArgGen.agNone),
@@ -771,16 +789,71 @@ TF_OP_LIST = {
"build_fcn": (TBuilder.BatchToSpace, TGen.tgBasic, ArgGen.agBatchToSpace),
"types": TYPE_F,
},
+ "dynamic_batch_to_space": {
+ "operands": (1, 0),
+ "build_fcn": (
+ TBuilder.DynamicBatchToSpace,
+ TGen.tgBasic,
+ ArgGen.agBatchToSpace,
+ ),
+ "types": TYPE_F,
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(8, 4, 4, 4)],
+ },
+ # number of operands of tuples which spcifies which dim to set to None
+ # In this case, we have 1 input. So we have 1 tuple
+ # We're setting the first input's 0th dim to None
+ "dynamic_shape_dim": [
+ (0,),
+ ],
+ },
"space_to_depth": {
"operands": (1, 0),
"build_fcn": (TBuilder.SpaceToDepth, TGen.tgBasic, ArgGen.agSpaceToDepth),
"types": TYPE_F,
},
+ "dynamic_space_to_depth": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.DynamicSpaceToDepth, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tf": [],
+ "tflite": list(TYPE_F),
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1, 32, 32, 8)],
+ },
+ # number of operands of tuples which spcifies which dim to set to None
+ # In this case, we have 1 input. So we have 1 tuple
+ # We're setting the first input's third dim to None
+ "dynamic_shape_dim": [
+ (3,),
+ ],
+ },
"depth_to_space": {
"operands": (1, 0),
"build_fcn": (TBuilder.DepthToSpace, TGen.tgBasic, ArgGen.agDepthToSpace),
"types": TYPE_F,
},
+ "dynamic_depth_to_space": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.DynamicDepthToSpace, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tf": [],
+ "tflite": list(TYPE_F),
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1, 1, 1, 4)],
+ },
+ # number of operands of tuples which spcifies which dim to set to None
+ # In this case, we have 1 input. So we have 1 tuple
+ # We're setting the first input's third dim to None
+ "dynamic_shape_dim": [
+ (3,),
+ ],
+ },
"one_hot": {
"operands": (3, 1),
"build_fcn": (TBuilder.OneHot, TGen.tgOneHot, ArgGen.agOneHot),
@@ -1095,9 +1168,22 @@ def run_unit_test(
placeholder_shapes = []
for idx, (name, val) in enumerate(placeholders):
+ input_shape = val.shape
+ try:
+ dynamic_shape_dim_tuples = op["dynamic_shape_dim"]
+ dim_tuple = dynamic_shape_dim_tuples[idx]
+ dim = dim_tuple[0]
+ input_shape = list(val.shape)
+ input_shape[dim] = None
+ dynamic_input_shape = tuple(input_shape)
+
+ addl_args.append(dynamic_input_shape)
+ except KeyError:
+ pass
+
placeholder_names.append(name)
placeholder_signatures = placeholder_signatures + (
- tf.TensorSpec(shape=val.shape, dtype=val.dtype, name=name),
+ tf.TensorSpec(shape=dynamic_input_shape, dtype=val.dtype, name=name),
)
placeholder_npy_filenames.append("{}.npy".format(name.split(":")[0]))
placeholder_shapes.append(val.shape)
@@ -1301,11 +1387,18 @@ def run_unit_test(
assert 0, "unknown tflite interpreter mode {}".format(
args.tflite_kernel_mode
)
- interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
+ # Prototype dynamic_shape testing
+ # Need to resize the input tensors to known shapes when evaluating
+ for idx, val in enumerate(placeholder_vals):
+ interpreter.resize_tensor_input(
+ input_details[idx]["index"], placeholder_shapes[idx]
+ )
+ interpreter.allocate_tensors()
+
assert len(input_details) == len(
placeholder_vals
), "number of placeholder mismatch"