aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-05-05 17:08:04 +0100
committerEric Kunze <eric.kunze@arm.com>2022-05-19 18:41:41 +0000
commit0e6218e22f25901aa208fbec44c9b14e14a68ba7 (patch)
tree6da7aa8972be69352555b5ac2709c607e2abecf8
parent7edb34c02614c48cc9e535c39198711d6692127d (diff)
downloadreference_model-0e6218e22f25901aa208fbec44c9b14e14a68ba7.tar.gz
Update framework test generation for ERROR_IF criteria
Update to tosa_verif_framework_generator to produce valid test ranges for pooling and convolution tests Fix up test filtering to only filter on test name not output directory Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ifff7e7604a37e8680d7237dc2d85cd806b20e384
-rw-r--r--verif/frameworks/arg_gen.py24
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py4
2 files changed, 27 insertions, 1 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 8feb9b2..fa4a652 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -92,6 +92,16 @@ class ArgGen:
):
continue
+ if (
+ (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride_h
+ != 0
+ ) or (
+ (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride_w
+ != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_dilat{}{}".format(
@@ -147,6 +157,14 @@ class ArgGen:
if shapes[1] % dilation_h != 0 or shapes[2] % dilation_w != 0:
continue
+ if (
+ (shapes[1] - 1 - (filter_h - 1) * dilation_h) % stride != 0
+ ) or (
+ (shapes[2] - 1 - (filter_w - 1) * dilation_w) % stride != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_dilat{}{}".format(
@@ -217,6 +235,12 @@ class ArgGen:
):
continue
+ if ((shapes[1] - kernel_h) % stride_h != 0) or (
+ (shapes[2] - kernel_w) % stride_w != 0
+ ):
+ # Not an exact integer output
+ continue
+
arg_list.append(
[
"_st{}{}_pad{}_kern{}{}".format(
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 8457d92..3b5d012 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -1147,7 +1147,9 @@ def build_const_net(
addl_args_tuple = arg_gen_fcn(op, curr_shape, rng)
for desc, addl_args in addl_args_tuple:
- if not filter or filter.search(test_dir + desc):
+ # Only filter on the full test_name, not the output directory
+ _, test_name = os.path.split(test_dir + desc)
+ if not filter or filter.search(test_name):
unit_test_args.append(
[
op_name,