aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
Diffstat (limited to 'verif')
-rw-r--r--verif/checker/tosa_result_checker.py16
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json5
-rw-r--r--verif/generator/tosa_arg_gen.py22
-rw-r--r--verif/generator/tosa_test_gen.py69
4 files changed, 88 insertions, 24 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 6948378..212c809 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -1,5 +1,5 @@
"""TOSA result checker script."""
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
@@ -55,9 +55,9 @@ def _print_result(color, msg):
def compliance_check(
- imp_result_path,
- ref_result_path,
- bnd_result_path,
+ imp_result_data,
+ ref_result_data,
+ bnd_result_data,
test_name,
compliance_config,
ofm_name,
@@ -78,14 +78,18 @@ def compliance_check(
return (TestResult.INTERNAL_ERROR, 0.0, msg)
success = vlib.verify_data(
- ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
+ ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data
)
if success:
_print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
return (TestResult.PASS, 0.0, "")
else:
_print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
- return (TestResult.MISMATCH, 0.0, "Non-compliance results found")
+ return (
+ TestResult.MISMATCH,
+ 0.0,
+ f"Non-compliance results found for {ofm_name}",
+ )
def test_check(
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 5e35e8b..067fab7 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -980,6 +980,7 @@
"profile": [
"tosa-mi"
],
+ "support_for": [ "lazy_data_gen" ],
"generation": {
"standard": {
"generator_args": [
@@ -987,13 +988,13 @@
"--target-dtype",
"fp32",
"--fp-values-range",
- "-2.0,2.0"
+ "-max,max"
],
[
"--target-dtype",
"fp32",
"--fp-values-range",
- "-2.0,2.0",
+ "-max,max",
"--target-shape",
"1,256,64",
"--target-shape",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b4939da..f6a46b4 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2798,9 +2798,27 @@ class TosaArgGen:
def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("inverseTrue", [True]))
- arg_list.append(("inverseFalse", [False]))
+ shape = shapeList[0]
+ dot_products = gtu.product(shape)
+ ks = 2 * shape[1] * shape[2] # 2*H*W
+ for inverse in (True, False):
+ args_dict = {
+ "dot_products": dot_products,
+ "shape": shape,
+ "ks": ks,
+ "acc_type": dtype,
+ "inverse": inverse,
+ }
+ arg_list.append((f"inverse{inverse}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
# Helper function for reshape. Gets some factors of a larger number.
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bfafd23..68a4e94 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -381,8 +381,35 @@ class TosaTestGen:
"""Enhanced build information containing result tensor and associated compliance dict."""
def __init__(self, resultTensor, complianceDict):
- self.resultTensor = resultTensor
- self.complianceDict = complianceDict
+ if isinstance(resultTensor, list):
+ assert complianceDict is None or isinstance(complianceDict, list)
+ self.resultTensorList = resultTensor
+ self.complianceDictList = complianceDict
+ else:
+ self.resultTensorList = [resultTensor]
+ if complianceDict is None:
+ self.complianceDictList = None
+ else:
+ self.complianceDictList = [complianceDict]
+
+ def getComplianceInfo(self):
+ if self.complianceDictList is None:
+ return None
+ else:
+ tens_dict = {}
+ for tens, comp in zip(self.resultTensorList, self.complianceDictList):
+ if comp is not None:
+ tens_dict[tens.name] = comp
+
+ if tens_dict:
+ # Have some compliance data, so return the info
+ compliance = {
+ "version": "0.1",
+ "tensors": tens_dict,
+ }
+ else:
+ compliance = None
+ return compliance
def build_unary(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2491,12 +2518,16 @@ class TosaTestGen:
def build_fft2d(
self,
op,
- val1,
- val2,
- inverse,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 2
+ val1, val2 = inputs
+ inverse = args_dict["inverse"]
+
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
@@ -2537,7 +2568,16 @@ class TosaTestGen:
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(
+ op, val1.dtype, args_dict, res, error_name
+ )
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
@@ -2933,13 +2973,11 @@ class TosaTestGen:
if result:
# The test is valid, serialize it
- if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
- # Add the compliance meta data
- # NOTE: This currently expects only one result output
- tensMeta["compliance"] = {
- "version": "0.1",
- "tensors": {result.resultTensor.name: result.complianceDict},
- }
+ if isinstance(result, TosaTestGen.BuildInfo):
+ # Add the compliance meta data (if any)
+ compliance = result.getComplianceInfo()
+ if compliance:
+ tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
else:
# The test is not valid
@@ -4708,7 +4746,7 @@ class TosaTestGen:
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
@@ -4723,6 +4761,9 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
"rfft2d": {
"op": Op.RFFT2D,