diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-12 16:03:15 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-26 11:20:00 +0100 |
commit | d41feb7138406832cfe045f41f254180e9c91ef4 (patch) | |
tree | 1539f57224123c34044ae1d1ad0e9bc468d26b1f /verif/generator/tosa_error_if.py | |
parent | fc5e34e41afc07ea5ed03e3c5d4b5be92bef7fd7 (diff) | |
download | reference_model-d41feb7138406832cfe045f41f254180e9c91ef4.tar.gz |
Compliance testing support for MAX_POOL2D & PAD
Added Pseudo Random number generator in generate library.
Enabled MAX_POOL2D, PAD FP32 tests to use new generator and compliance.
Fixed verify library exact mode to expect reference data as FP64.
Simplified tosa_verif_build_tests internal interfaces for new tests.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Icc0ffa924cf38107c3a212efd452c47a650c9d98
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index d490cf2..ed1a941 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -2653,16 +2653,28 @@ class TosaInvalidValidator: args = kwargs["args"] - # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one) - stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1) + if isinstance(args, dict): + args_dict = args + else: + # Create args_dict from list elements + # TODO - Remove this once all NWHC operators agFunctions have been + # converted to args_dict output + + # Skip accum_dtype arg (apart from MaxPool2D that doesn't have one) + stride_idx, pad_idx = (1, 2) if opName != "max_pool2d" else (0, 1) + args_dict = {"stride": args[stride_idx], "pad": args[pad_idx]} + # Alias different info for each op + args_dict["kernel"] = args[pad_idx + 1] + args_dict["out_shape"] = args[pad_idx + 1] + args_dict["dilation"] = args[pad_idx + 1] # Common info for all ops - strides = args[stride_idx] - padding = args[pad_idx] + strides = args_dict["stride"] + padding = args_dict["pad"] if opName.endswith("pool2d"): # avg_pool2d, max_pool2d - kernel_shape = args[pad_idx + 1] + kernel_shape = args_dict["kernel"] h = ( input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0] ) // strides[0] @@ -2674,7 +2686,7 @@ class TosaInvalidValidator: if opName.startswith("transpose_conv2d"): # transpose_conv2d - output_shape = args[pad_idx + 1] + output_shape = args_dict["out_shape"] filter_shape = inputShapes[1] kernel_shape = filter_shape[1:-1] @@ -2703,7 +2715,7 @@ class TosaInvalidValidator: if "conv2d" in opName or "conv3d" in opName: # conv2d, conv3d, depthwise_conv2d - dilations = args[pad_idx + 1] + dilations = args_dict["dilation"] filter_shape = inputShapes[1] kernel_shape = ( filter_shape[0:2] |