aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2023-09-18 16:32:45 -0700
committerEric Kunze <eric.kunze@arm.com>2023-09-28 18:26:39 +0000
commitf9c0ceea99e197ab14f779eb51c5e1479dbeb4dd (patch)
tree9484444d75f38c533c214c6568671437b7fbddf4
parent41ebe72588b20b912eb8c9e082b2d66b37564ad3 (diff)
downloadreference_model-f9c0ceea99e197ab14f779eb51c5e1479dbeb4dd.tar.gz
Add 0-rank tensor support for concat in framework test
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Iff77091e4a57f487431ffbf7ac1c89301a153c8b
-rw-r--r--verif/frameworks/arg_gen.py4
-rw-r--r--verif/frameworks/tensor_gen.py8
-rw-r--r--verif/frameworks/test_builder.py12
-rw-r--r--verif/frameworks/test_gen_utils.py3
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py12
5 files changed, 34 insertions, 5 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index a25c205..c385274 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -45,6 +45,10 @@ class ArgGen:
@staticmethod
def agAxes(op, shapes, rng):
axes = []
+ if shapes == ():
+ axes.append(["_axis_0", [0]])
+ return axes
+
for i in range(-len(shapes), len(shapes), 1):
if i >= 0:
axes.append(["_axis_{}".format(i), [i]])
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index d0c0a0b..4370215 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -41,8 +41,12 @@ class TGen:
RAND_SHIFT_FACTOR = 0.5
if dtype == tf.float32:
- return np.float32(
- (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+ return (
+ np.float32(
+ (rng.random(size=shape) - RAND_SHIFT_FACTOR) * RAND_SCALE_FACTOR
+ )
+ if shape != ()
+ else np.float32(rng.random())
)
if dtype == tf.float16:
return np.float16(
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 3554e40..7b20cef 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -136,7 +136,11 @@ class TBuilder:
self.result_name = name
def eval(self, a, b):
- return tf.concat([a, b], self.axis, name=self.result_name)
+ return (
+ tf.concat([a, b], self.axis, name=self.result_name)
+ if a.shape != ()
+ else tf.stack([a, b], name=self.result_name)
+ )
class BitwiseAnd:
def __init__(self, name):
@@ -767,7 +771,11 @@ class TBuilder:
self.result_name = name
def eval(self, a, b, c, d):
- return tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+ return (
+ tf.concat([a, b, c, d], axis=self.axis, name=self.result_name)
+ if a.shape != ()
+ else tf.stack([a, b, c, d], name=self.result_name)
+ )
class Stack:
def __init__(self, axis, name):
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
index 6a59848..f31ac63 100644
--- a/verif/frameworks/test_gen_utils.py
+++ b/verif/frameworks/test_gen_utils.py
@@ -9,6 +9,9 @@ import tensorflow as tf
# Get a string name for a given shape
def get_shape_str(shape, dtype):
shape_name = None
+ if len(shape) == 0:
+ shape_name = "0"
+
for dim in shape:
shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim)
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ffe373b..9d666ab 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -174,6 +174,11 @@ TF_OP_LIST = {
"operands": (2, 0),
"build_fcn": (TBuilder.Concat, TGen.tgBasic, ArgGen.agAxes),
"types": TYPE_FI,
+ "rank": (0, 4),
+ "custom_shapes": {
+ "custom_shape_only": False,
+ "shape_list": [()],
+ },
},
"bitwise_and": {
"operands": (2, 0),
@@ -635,6 +640,11 @@ TF_OP_LIST = {
"operands": (4, 0),
"build_fcn": (TBuilder.Concatv2, TGen.tgBasic, ArgGen.agAxes),
"types": TYPE_FI,
+ "rank": (0, 4),
+ "custom_shapes": {
+ "custom_shape_only": False,
+ "shape_list": [()],
+ },
},
"stack": {
"operands": (4, 0),
@@ -1473,7 +1483,7 @@ def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_
shape_list = custom_shapes["shape_list"]
else:
shape_list = shape_list.copy()
- shape_list.append(custom_shapes["shape_list"])
+ shape_list.extend(custom_shapes["shape_list"])
except KeyError:
pass