From 9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d Mon Sep 17 00:00:00 2001 From: evacha01 Date: Wed, 7 Feb 2024 11:21:55 +0000 Subject: FULL data gen mode for FP16 Signed-off-by: evacha01 Change-Id: I81bb322132daf25328a40342edc62d8e1db9edd6 --- reference_model/CMakeLists.txt | 2 + reference_model/src/generate/generate_entry.cc | 5 ++ .../src/generate/generate_full_range.cc | 59 ++++++++++++++++++++++ reference_model/src/generate/generate_full_range.h | 34 +++++++++++++ reference_model/src/generate/generate_utils.cc | 21 ++++++-- reference_model/src/generate/generate_utils.h | 15 ++++-- reference_model/src/verify/verify_abs_error.cc | 13 +++-- reference_model/src/verify/verify_utils.cc | 25 ++++++--- reference_model/test/generate_tests.cpp | 53 +++++++++++++++++++ .../schemavalidation/datagen-config.schema.json | 26 +++------- verif/generator/tosa_arg_gen.py | 56 +++++++++++++++++--- verif/generator/tosa_test_gen.py | 18 +++---- verif/generator/tosa_utils.py | 35 ++++++------- 13 files changed, 291 insertions(+), 71 deletions(-) create mode 100644 reference_model/src/generate/generate_full_range.cc create mode 100644 reference_model/src/generate/generate_full_range.h diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index 0f806fc..b780781 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -74,6 +74,7 @@ set(CXX_SOURCE src/generate/generate_dot_product.cc src/generate/generate_pseudo_random.cc src/generate/generate_fixed_data.cc + src/generate/generate_full_range.cc src/generate/generate_entry.cc src/generate/generate_utils.cc src/verify/verify_abs_error.cc @@ -177,6 +178,7 @@ add_library(tosa_reference_generate_lib SHARED src/generate/generate_dot_product.cc src/generate/generate_pseudo_random.cc src/generate/generate_fixed_data.cc + src/generate/generate_full_range.cc src/generate/generate_entry.cc src/generate/generate_utils.cc src/generate/generate_config.cc diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc index 91b2fc7..6f797b3 100644 --- a/reference_model/src/generate/generate_entry.cc +++ b/reference_model/src/generate/generate_entry.cc @@ -16,6 +16,7 @@ #include "generate_dot_product.h" #include "generate_fixed_data.h" +#include "generate_full_range.h" #include "generate_pseudo_random.h" #include "generate_utils.h" @@ -41,6 +42,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size) return generateFixedData(cfg, data, size); break; } + case GeneratorType::FullRange: { + return generateFullRange(cfg, data, size); + break; + } default: { WARNING("[Generator] Unsupported generation mode."); break; diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc new file mode 100644 index 0000000..d2a89da --- /dev/null +++ b/reference_model/src/generate/generate_full_range.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_full_range.h" +#include "half.hpp" + +namespace +{ + +template +bool generate(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size) +{ + const TosaReference::FullRangeInfo& frinfo = cfg.fullRangeInfo; + DataType value = frinfo.startVal; + + const auto T = TosaReference::numElementsFromShape(cfg.shape); + for (auto t = 0; t < T; ++t) + { + data[t] = value; + value++; + } + return true; +} +} // namespace + +namespace TosaReference +{ +bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size) +{ + // Check we support the operator + if (cfg.opType == Op::Op_UNKNOWN) + { + WARNING("[Generator][PR] Unknown operator."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP16: { + uint16_t* outData = reinterpret_cast(data); + return generate(cfg, outData, size); + } + default: + WARNING("[Generator][PR] Unsupported type."); + return false; + } +} +} // namespace TosaReference \ No newline at end of file diff --git a/reference_model/src/generate/generate_full_range.h b/reference_model/src/generate/generate_full_range.h new file mode 100644 index 0000000..df24160 --- /dev/null +++ b/reference_model/src/generate/generate_full_range.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERATE_FULL_RANGE_H_ +#define GENERATE_FULL_RANGE_H_ + +#include "generate_utils.h" + +namespace TosaReference +{ + +/// \brief Perform full range generation +/// +/// \param cfg Generator related meta-data +/// \param data Buffer to generate the data to +/// \param size Size of the buffer +/// +/// \return True on successful generation +bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size); + +}; // namespace TosaReference + +#endif // GENERATE_FULL_RANGE_H_ \ No newline at end of file diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index d0d0194..f31b443 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, { GeneratorType::Unknown, "UNKNOWN" }, { GeneratorType::PseudoRandom, "PSEUDO_RANDOM" }, { GeneratorType::DotProduct, "DOT_PRODUCT" }, - { GeneratorType::OpFullRange, "OP_FULL_RANGE" }, - { GeneratorType::OpBoundary, "OP_BOUNDARY" }, - { GeneratorType::OpSpecial, "OP_SPECIAL" }, + { GeneratorType::FullRange, "FULL_RANGE" }, + { GeneratorType::Boundary, "BOUNDARY" }, + { GeneratorType::Special, "SPECIAL" }, { GeneratorType::FixedData, "FIXED_DATA" }, }) @@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo) j.at("data").get_to(fixedDataInfo.data); } +void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo) +{ + if (j.contains("start_val")) + { + j.at("start_val").get_to(fullRangeInfo.startVal); + } +} + void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("data_type").get_to(cfg.dataType); @@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("fixed_data_info").get_to(cfg.fixedDataInfo); } + + //Set up defaults for fullRangeInfo + cfg.fullRangeInfo.startVal = 0; + if (j.contains("full_range_info")) + { + j.at("full_range_info").get_to(cfg.fullRangeInfo); + } } std::optional parseGenerateConfig(const char* json, const char* tensorName) diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h index 697b404..8ce9b0e 100644 --- a/reference_model/src/generate/generate_utils.h +++ b/reference_model/src/generate/generate_utils.h @@ -31,9 +31,9 @@ enum class GeneratorType Unknown, PseudoRandom, DotProduct, - OpFullRange, - OpBoundary, - OpSpecial, + FullRange, + Boundary, + Special, FixedData, }; @@ -74,6 +74,14 @@ struct FixedDataInfo std::vector data; }; +/// \brief Op specific generator meta-data +struct FullRangeInfo +{ + FullRangeInfo() = default; + + uint16_t startVal; +}; + /// \brief Generator configuration struct GenerateConfig { @@ -86,6 +94,7 @@ struct GenerateConfig DotProductInfo dotProductInfo; PseudoRandomInfo pseudoRandomInfo; FixedDataInfo fixedDataInfo; + FullRangeInfo fullRangeInfo; }; /// \brief Parse the generator config when given in JSON form diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index 125045e..64f86a3 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -30,12 +30,17 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg { const auto cfg = reinterpret_cast(cfgPtr); - double valBound = std::abs(referenceValue) * boundsValue; - if (cfg->lowerBound > 0) + double errorBound = 0.0; + if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0) { - valBound = std::max(cfg->lowerBound, valBound); + double valBound = std::abs(referenceValue) * boundsValue; + if (cfg->lowerBound > 0) + { + valBound = std::max(cfg->lowerBound, valBound); + } + errorBound = exp2(-AccPrecision::normal_frac / cfg->normalDivisor) * valBound; } - return exp2(-AccPrecision::normal_frac / cfg->normalDivisor) * valBound; + return errorBound; } } // namespace diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 50a98e5..d4657b3 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -356,21 +356,23 @@ bool validateData(const double* referenceData, TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation"); std::string warning, worstWarning; - double difference, worstDifference = 0.0; - size_t worstPosition; - bool compliant = true; + double worstDifference = 0.0; + // Set to invalid index + size_t worstIndex = T; + bool compliant = true; for (size_t i = 0; i < T; ++i) { - double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i]; - double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr); - bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning); + double difference = 0.0; + double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i]; + double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr); + bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning); if (!valid) { compliant = false; if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference)) { - worstPosition = i; + worstIndex = i; worstDifference = difference; worstWarning.assign(warning); if (std::isnan(difference)) @@ -379,11 +381,18 @@ bool validateData(const double* referenceData, break; } } + else if (std::abs(difference) == 0.0) + { + auto pos = indexToPosition(i, shape); + WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(), + positionToString(pos).c_str()); + return false; + } } } if (!compliant) { - auto pos = indexToPosition(worstPosition, shape); + auto pos = indexToPosition(worstIndex, shape); WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(), worstWarning.c_str()); } diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp index 564af4a..294705c 100644 --- a/reference_model/test/generate_tests.cpp +++ b/reference_model/test/generate_tests.cpp @@ -1556,4 +1556,57 @@ TEST_CASE("positive - FP32 rfft2d dot product (values -8, -7 & -6 from the end)" } } +TEST_CASE("positive - FP16 full range") +{ + std::string templateJsonCfg = R"({ + "tensors" : { + "input0" : { + "generator": "FULL_RANGE", + "data_type": "FP16", + "input_type": "VARIABLE", + "shape" : [ 48, 49, 47 ], + "input_pos": 0, + "op" : "CEIL", + "full_range_info": { + "start_val": _START_ + } + } + } + })"; + + const std::string tosaName = "input0"; + const size_t tosaElements = 48 * 49 * 47; + + SUBCASE("ceil - startVal 0") + { + std::string jsonCfg = templateJsonCfg; + update_json_template(jsonCfg, "_START_", "0"); + + std::vector buffer(tosaElements); + REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 2)); + std::vector expected = { 0, 1, 2 }; + check_output(buffer, expected); + + std::vector last_three(buffer.end() - std::min(3, buffer.size()), buffer.end()); + // To calculate last_expected: last value = tosaElements % 65535 - 1 + startVal + std::vector last_expected = { 45005, 45006, 45007 }; + check_output(last_three, last_expected); + } + SUBCASE("ceil - startVal 100") + { + std::string jsonCfg = templateJsonCfg; + update_json_template(jsonCfg, "_START_", "100"); + + std::vector buffer(tosaElements); + REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 2)); + std::vector expected = { 100, 101, 102 }; + check_output(buffer, expected); + + std::vector last_three(buffer.end() - std::min(3, buffer.size()), buffer.end()); + // To calculate last_expected: last value = tosaElements % 65535 - 1 + startVal + std::vector last_expected = { 45105, 45106, 45107 }; + check_output(last_three, last_expected); + } +} + TEST_SUITE_END(); // generate diff --git a/scripts/schemavalidation/datagen-config.schema.json b/scripts/schemavalidation/datagen-config.schema.json index a74d79f..19e8b62 100644 --- a/scripts/schemavalidation/datagen-config.schema.json +++ b/scripts/schemavalidation/datagen-config.schema.json @@ -22,7 +22,7 @@ "type": "object", "properties": { "generator": { - "description": "data generator name - PSEUDO_RANDOM, DOT_PRODUCT or OP_SPECIFIC", + "description": "data generator name - PSEUDO_RANDOM, DOT_PRODUCT, FULL_RANGE, BOUNDARY, or SPECIAL", "type": "string" }, "data_type": { @@ -120,33 +120,19 @@ "acc_type" ] }, - "op_specific_info": { - "description": "info required for the OP_SPECIFIC generator", + "full_range_info": { + "description": "info required for the FULL_RANGE generator", "type": "object", "properties": { - "sub_generator": { - "description": "sub generator type for this op - FULL, SPECIAL or BOUNDARY", - "type": "string" - }, - "offset": { - "description": "starting offset within the test data", + "start_val": { + "description": "starting value of the test data", "type": "integer", "minimum": 0 - }, - "attributes": { - "description": "attribute data from the op needed to compute the data", - "type": "object", - "properties": { - "TBD": { - "description": "Probably needed for RESCALE and MUL", - "type": "string" - } - } } }, "additionalProperties": false, - "required": [ "sub_generator" ] + "required": [ ] }, "fixed_data_info": { "description": "info required for FIXED_DATA generator", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index cbfffae..c596645 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -828,6 +828,12 @@ class TosaTensorValuesGen: if "axis" in argsDict: info["axis"] = int(argsDict["axis"]) tens_meta["dot_product_info"] = info + elif dg_type == gtu.DataGenType.FULL_RANGE: + info = {} + info["start_val"] = int( + testGen.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"]) + ) + tens_meta["full_range_info"] = info else: # TODO - other data gen type assert False, "TODO: support other data gen types" @@ -1795,7 +1801,7 @@ class TosaArgGen: pass @staticmethod - def _add_data_generators(testGen, opName, dtype, arg_list, error_name): + def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name): """Add extra tests for each type of data generator for this op.""" if ( error_name is None @@ -1820,7 +1826,16 @@ class TosaArgGen: new_arg_list = [] for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: - args_dict["dg_type"] = dg_type + + if dg_type == gtu.DataGenType.FULL_RANGE: + tensor_size = gtu.product(shapeList[0]) + if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]: + # Large enough tensor data size for full range, add a single test + num_test_sets = 0 + else: + # Not enough data size for full range of values, revert to random numbers + dg_type = gtu.DataGenType.PSEUDO_RANDOM + if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: num_test_sets = ( @@ -1829,6 +1844,7 @@ class TosaArgGen: else 0 ) else: + # Add single test for pseudo random num_test_sets = 0 elif dg_type == gtu.DataGenType.DOT_PRODUCT: @@ -1852,13 +1868,16 @@ class TosaArgGen: if num_test_sets > 0: for s in range(0, num_test_sets): - new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}" - new_args_dict = args_dict.copy() - new_args_dict["s"] = s - new_arg_list.append((new_arg_str, new_args_dict)) + set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}" + set_args_dict = args_dict.copy() + set_args_dict["s"] = s + set_args_dict["dg_type"] = dg_type + new_arg_list.append((set_arg_str, set_args_dict)) else: # Default is a single test - new_arg_list.append((arg_str, args_dict)) + new_args_dict = args_dict.copy() + new_args_dict["dg_type"] = dg_type + new_arg_list.append((arg_str, new_args_dict)) return new_arg_list @@ -1869,6 +1888,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, [("", {})], error_name, @@ -1883,6 +1903,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, [("", {"num_test_sets": 3})], error_name, @@ -1921,6 +1942,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2160,6 +2182,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtypes[0], arg_list, error_name, @@ -2194,6 +2217,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, input_dtype, arg_list, error_name, @@ -2246,6 +2270,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2402,6 +2427,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtypes[0], arg_list, error_name, @@ -2482,6 +2508,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2685,6 +2712,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2774,6 +2802,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2925,6 +2954,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, inDtype, arg_list, error_name, @@ -2947,6 +2977,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2967,6 +2998,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -2994,6 +3026,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3019,6 +3052,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3091,6 +3125,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3137,6 +3172,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3179,6 +3215,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3214,6 +3251,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3547,6 +3585,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3586,6 +3625,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3606,6 +3646,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, @@ -3624,6 +3665,7 @@ class TosaArgGen: arg_list = TosaArgGen._add_data_generators( testGen, opName, + shapeList, dtype, arg_list, error_name, diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 415858c..a1f54c6 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -365,7 +365,7 @@ class TosaTestGen: if "ksb" in argsDict else int(argsDict["ks"]), } - elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL: + elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL: mode = gtu.ComplianceMode.FP_SPECIAL elif "compliance" in op and "ulp" in op["compliance"]: mode = gtu.ComplianceMode.ULP @@ -3959,7 +3959,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, }, "bitwise_not": { @@ -3996,7 +3996,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, "compliance": {"ulp": 0.5}, }, @@ -4055,7 +4055,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, }, "floor": { @@ -4075,7 +4075,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, "compliance": {"ulp": 0.5}, }, @@ -4096,7 +4096,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, "compliance": {"ulp": 5}, }, @@ -4137,7 +4137,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, }, "reciprocal": { @@ -4157,7 +4157,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, "compliance": {"ulp": 1.0}, }, @@ -4178,7 +4178,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), "data_gen": { - "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + "fp": (gtu.DataGenType.FULL_RANGE,), }, "compliance": {"ulp": 2}, }, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 384463f..6558bf8 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -13,22 +13,23 @@ MAX_RESIZE_DIMENSION = 16384 # Data type information dictionary # - str: filename abbreviation # - width: number of bytes needed for type +# - fullset: precalculated number of possible values in the data type's range, equal to 2^width # - json: JSON type string DTYPE_ATTRIBUTES = { - DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"}, - DType.INT4: {"str": "i4", "width": 4, "json": "INT4"}, - DType.INT8: {"str": "i8", "width": 8, "json": "INT8"}, - DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"}, - DType.INT16: {"str": "i16", "width": 16, "json": "INT16"}, - DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"}, - DType.INT32: {"str": "i32", "width": 32, "json": "INT32"}, - DType.INT48: {"str": "i48", "width": 48, "json": "INT48"}, - DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"}, - DType.FP16: {"str": "f16", "width": 16, "json": "FP16"}, - DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"}, - DType.FP32: {"str": "f32", "width": 32, "json": "FP32"}, - DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"}, - DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"}, + DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"}, + DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"}, + DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"}, + DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"}, + DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"}, + DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"}, + DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"}, + DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"}, + DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"}, + DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"}, + DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"}, + DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"}, + DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"}, + DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"}, } @@ -49,9 +50,9 @@ class DataGenType(IntEnum): PSEUDO_RANDOM = 0 DOT_PRODUCT = 1 - OP_BOUNDARY = 2 - OP_FULLSET = 3 - OP_SPECIAL = 4 + BOUNDARY = 2 + FULL_RANGE = 3 + SPECIAL = 4 FIXED_DATA = 5 -- cgit v1.2.1