aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-12 16:03:15 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2023-10-26 11:20:00 +0100
commitd41feb7138406832cfe045f41f254180e9c91ef4 (patch)
tree1539f57224123c34044ae1d1ad0e9bc468d26b1f /verif/generator/tosa_error_if.py
parentfc5e34e41afc07ea5ed03e3c5d4b5be92bef7fd7 (diff)
downloadreference_model-d41feb7138406832cfe045f41f254180e9c91ef4.tar.gz
Compliance testing support for MAX_POOL2D & PAD
Added Pseudo Random number generator in generate library. Enabled MAX_POOL2D, PAD FP32 tests to use new generator and compliance. Fixed verify library exact mode to expect reference data as FP64. Simplified tosa_verif_build_tests internal interfaces for new tests. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Icc0ffa924cf38107c3a212efd452c47a650c9d98
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py26
1 files changed, 19 insertions, 7 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index d490cf2..ed1a941 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2653,16 +2653,28 @@ class TosaInvalidValidator:
args = kwargs["args"]
- # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
- stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+ if isinstance(args, dict):
+ args_dict = args
+ else:
+ # Create args_dict from list elements
+ # TODO - Remove this once all NWHC operators agFunctions have been
+ # converted to args_dict output
+
+ # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one)
+ stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1)
+ args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]}
+ # Alias different info for each op
+ args_dict["kernel"] = args[pad_idx + 1]
+ args_dict["out_shape"] = args[pad_idx + 1]
+ args_dict["dilation"] = args[pad_idx + 1]
# Common info for all ops
- strides = args[stride_idx]
- padding = args[pad_idx]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
- kernel_shape = args[pad_idx + 1]
+ kernel_shape = args_dict["kernel"]
h = (
input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
) // strides[0]
@@ -2674,7 +2686,7 @@ class TosaInvalidValidator:
if opName.startswith("transpose_conv2d"):
# transpose_conv2d
- output_shape = args[pad_idx + 1]
+ output_shape = args_dict["out_shape"]
filter_shape = inputShapes[1]
kernel_shape = filter_shape[1:-1]
@@ -2703,7 +2715,7 @@ class TosaInvalidValidator:
if "conv2d" in opName or "conv3d" in opName:
# conv2d, conv3d, depthwise_conv2d
- dilations = args[pad_idx + 1]
+ dilations = args_dict["dilation"]
filter_shape = inputShapes[1]
kernel_shape = (
filter_shape[0:2]