aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 1b187ba..3a9c0ca 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
@@ -1168,22 +1168,22 @@ def run_unit_test(
placeholder_shapes = []
for idx, (name, val) in enumerate(placeholders):
- input_shape = val.shape
+ input_shape = tuple(val.shape)
+
try:
dynamic_shape_dim_tuples = op["dynamic_shape_dim"]
dim_tuple = dynamic_shape_dim_tuples[idx]
dim = dim_tuple[0]
- input_shape = list(val.shape)
+ input_shape = list(input_shape)
input_shape[dim] = None
- dynamic_input_shape = tuple(input_shape)
- addl_args.append(dynamic_input_shape)
+ addl_args.append(tuple(input_shape))
except KeyError:
pass
placeholder_names.append(name)
placeholder_signatures = placeholder_signatures + (
- tf.TensorSpec(shape=dynamic_input_shape, dtype=val.dtype, name=name),
+ tf.TensorSpec(shape=input_shape, dtype=val.dtype, name=name),
)
placeholder_npy_filenames.append("{}.npy".format(name.split(":")[0]))
placeholder_shapes.append(val.shape)