aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorevacha01 <evan.chandler@arm.com>2024-02-07 11:21:55 +0000
committerevacha01 <evan.chandler@arm.com>2024-03-07 12:06:38 +0000
commit9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d (patch)
tree55647ee0216800b621bd0b27277c6f895929ef3d
parent6e1e2bc06bff785e87577f24064bbc846300f8fd (diff)
downloadreference_model-9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d.tar.gz
FULL data gen mode for FP16
Signed-off-by: evacha01 <evan.chandler@arm.com> Change-Id: I81bb322132daf25328a40342edc62d8e1db9edd6
-rw-r--r--reference_model/CMakeLists.txt2
-rw-r--r--reference_model/src/generate/generate_entry.cc5
-rw-r--r--reference_model/src/generate/generate_full_range.cc59
-rw-r--r--reference_model/src/generate/generate_full_range.h34
-rw-r--r--reference_model/src/generate/generate_utils.cc21
-rw-r--r--reference_model/src/generate/generate_utils.h15
-rw-r--r--reference_model/src/verify/verify_abs_error.cc13
-rw-r--r--reference_model/src/verify/verify_utils.cc25
-rw-r--r--reference_model/test/generate_tests.cpp53
-rw-r--r--scripts/schemavalidation/datagen-config.schema.json26
-rw-r--r--verif/generator/tosa_arg_gen.py56
-rw-r--r--verif/generator/tosa_test_gen.py18
-rw-r--r--verif/generator/tosa_utils.py35
13 files changed, 291 insertions, 71 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 0f806fc..b780781 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -74,6 +74,7 @@ set(CXX_SOURCE
src/generate/generate_dot_product.cc
src/generate/generate_pseudo_random.cc
src/generate/generate_fixed_data.cc
+ src/generate/generate_full_range.cc
src/generate/generate_entry.cc
src/generate/generate_utils.cc
src/verify/verify_abs_error.cc
@@ -177,6 +178,7 @@ add_library(tosa_reference_generate_lib SHARED
src/generate/generate_dot_product.cc
src/generate/generate_pseudo_random.cc
src/generate/generate_fixed_data.cc
+ src/generate/generate_full_range.cc
src/generate/generate_entry.cc
src/generate/generate_utils.cc
src/generate/generate_config.cc
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
index 91b2fc7..6f797b3 100644
--- a/reference_model/src/generate/generate_entry.cc
+++ b/reference_model/src/generate/generate_entry.cc
@@ -16,6 +16,7 @@
#include "generate_dot_product.h"
#include "generate_fixed_data.h"
+#include "generate_full_range.h"
#include "generate_pseudo_random.h"
#include "generate_utils.h"
@@ -41,6 +42,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size)
return generateFixedData(cfg, data, size);
break;
}
+ case GeneratorType::FullRange: {
+ return generateFullRange(cfg, data, size);
+ break;
+ }
default: {
WARNING("[Generator] Unsupported generation mode.");
break;
diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc
new file mode 100644
index 0000000..d2a89da
--- /dev/null
+++ b/reference_model/src/generate/generate_full_range.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate_full_range.h"
+#include "half.hpp"
+
+namespace
+{
+
+template <typename DataType>
+bool generate(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size)
+{
+ const TosaReference::FullRangeInfo& frinfo = cfg.fullRangeInfo;
+ DataType value = frinfo.startVal;
+
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ for (auto t = 0; t < T; ++t)
+ {
+ data[t] = value;
+ value++;
+ }
+ return true;
+}
+} // namespace
+
+namespace TosaReference
+{
+bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size)
+{
+ // Check we support the operator
+ if (cfg.opType == Op::Op_UNKNOWN)
+ {
+ WARNING("[Generator][PR] Unknown operator.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_FP16: {
+ uint16_t* outData = reinterpret_cast<uint16_t*>(data);
+ return generate(cfg, outData, size);
+ }
+ default:
+ WARNING("[Generator][PR] Unsupported type.");
+ return false;
+ }
+}
+} // namespace TosaReference \ No newline at end of file
diff --git a/reference_model/src/generate/generate_full_range.h b/reference_model/src/generate/generate_full_range.h
new file mode 100644
index 0000000..df24160
--- /dev/null
+++ b/reference_model/src/generate/generate_full_range.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_FULL_RANGE_H_
+#define GENERATE_FULL_RANGE_H_
+
+#include "generate_utils.h"
+
+namespace TosaReference
+{
+
+/// \brief Perform full range generation
+///
+/// \param cfg Generator related meta-data
+/// \param data Buffer to generate the data to
+/// \param size Size of the buffer
+///
+/// \return True on successful generation
+bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_FULL_RANGE_H_ \ No newline at end of file
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d0d0194..f31b443 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::Unknown, "UNKNOWN" },
{ GeneratorType::PseudoRandom, "PSEUDO_RANDOM" },
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
- { GeneratorType::OpFullRange, "OP_FULL_RANGE" },
- { GeneratorType::OpBoundary, "OP_BOUNDARY" },
- { GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FullRange, "FULL_RANGE" },
+ { GeneratorType::Boundary, "BOUNDARY" },
+ { GeneratorType::Special, "SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
j.at("data").get_to(fixedDataInfo.data);
}
+void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
+{
+ if (j.contains("start_val"))
+ {
+ j.at("start_val").get_to(fullRangeInfo.startVal);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
}
+
+ //Set up defaults for fullRangeInfo
+ cfg.fullRangeInfo.startVal = 0;
+ if (j.contains("full_range_info"))
+ {
+ j.at("full_range_info").get_to(cfg.fullRangeInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 697b404..8ce9b0e 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -31,9 +31,9 @@ enum class GeneratorType
Unknown,
PseudoRandom,
DotProduct,
- OpFullRange,
- OpBoundary,
- OpSpecial,
+ FullRange,
+ Boundary,
+ Special,
FixedData,
};
@@ -74,6 +74,14 @@ struct FixedDataInfo
std::vector<int32_t> data;
};
+/// \brief Op specific generator meta-data
+struct FullRangeInfo
+{
+ FullRangeInfo() = default;
+
+ uint16_t startVal;
+};
+
/// \brief Generator configuration
struct GenerateConfig
{
@@ -86,6 +94,7 @@ struct GenerateConfig
DotProductInfo dotProductInfo;
PseudoRandomInfo pseudoRandomInfo;
FixedDataInfo fixedDataInfo;
+ FullRangeInfo fullRangeInfo;
};
/// \brief Parse the generator config when given in JSON form
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index 125045e..64f86a3 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -30,12 +30,17 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg
{
const auto cfg = reinterpret_cast<const AbsErrorVerifyInfo*>(cfgPtr);
- double valBound = std::abs(referenceValue) * boundsValue;
- if (cfg->lowerBound > 0)
+ double errorBound = 0.0;
+ if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0)
{
- valBound = std::max(cfg->lowerBound, valBound);
+ double valBound = std::abs(referenceValue) * boundsValue;
+ if (cfg->lowerBound > 0)
+ {
+ valBound = std::max(cfg->lowerBound, valBound);
+ }
+ errorBound = exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
}
- return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
+ return errorBound;
}
} // namespace
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 50a98e5..d4657b3 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -356,21 +356,23 @@ bool validateData(const double* referenceData,
TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation");
std::string warning, worstWarning;
- double difference, worstDifference = 0.0;
- size_t worstPosition;
- bool compliant = true;
+ double worstDifference = 0.0;
+ // Set to invalid index
+ size_t worstIndex = T;
+ bool compliant = true;
for (size_t i = 0; i < T; ++i)
{
- double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
- double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
- bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
+ double difference = 0.0;
+ double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
+ double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
+ bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
if (!valid)
{
compliant = false;
if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference))
{
- worstPosition = i;
+ worstIndex = i;
worstDifference = difference;
worstWarning.assign(warning);
if (std::isnan(difference))
@@ -379,11 +381,18 @@ bool validateData(const double* referenceData,
break;
}
}
+ else if (std::abs(difference) == 0.0)
+ {
+ auto pos = indexToPosition(i, shape);
+ WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(),
+ positionToString(pos).c_str());
+ return false;
+ }
}
}
if (!compliant)
{
- auto pos = indexToPosition(worstPosition, shape);
+ auto pos = indexToPosition(worstIndex, shape);
WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(),
worstWarning.c_str());
}
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
index 564af4a..294705c 100644
--- a/reference_model/test/generate_tests.cpp
+++ b/reference_model/test/generate_tests.cpp
@@ -1556,4 +1556,57 @@ TEST_CASE("positive - FP32 rfft2d dot product (values -8, -7 & -6 from the end)"
}
}
+TEST_CASE("positive - FP16 full range")
+{
+ std::string templateJsonCfg = R"({
+ "tensors" : {
+ "input0" : {
+ "generator": "FULL_RANGE",
+ "data_type": "FP16",
+ "input_type": "VARIABLE",
+ "shape" : [ 48, 49, 47 ],
+ "input_pos": 0,
+ "op" : "CEIL",
+ "full_range_info": {
+ "start_val": _START_
+ }
+ }
+ }
+ })";
+
+ const std::string tosaName = "input0";
+ const size_t tosaElements = 48 * 49 * 47;
+
+ SUBCASE("ceil - startVal 0")
+ {
+ std::string jsonCfg = templateJsonCfg;
+ update_json_template(jsonCfg, "_START_", "0");
+
+ std::vector<half_float::half> buffer(tosaElements);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 2));
+ std::vector<uint16_t> expected = { 0, 1, 2 };
+ check_output<half_float::half>(buffer, expected);
+
+ std::vector<half_float::half> last_three(buffer.end() - std::min<int>(3, buffer.size()), buffer.end());
+ // To calculate last_expected: last value = tosaElements % 65535 - 1 + startVal
+ std::vector<uint16_t> last_expected = { 45005, 45006, 45007 };
+ check_output<half_float::half>(last_three, last_expected);
+ }
+ SUBCASE("ceil - startVal 100")
+ {
+ std::string jsonCfg = templateJsonCfg;
+ update_json_template(jsonCfg, "_START_", "100");
+
+ std::vector<half_float::half> buffer(tosaElements);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 2));
+ std::vector<uint16_t> expected = { 100, 101, 102 };
+ check_output<half_float::half>(buffer, expected);
+
+ std::vector<half_float::half> last_three(buffer.end() - std::min<int>(3, buffer.size()), buffer.end());
+ // To calculate last_expected: last value = tosaElements % 65535 - 1 + startVal
+ std::vector<uint16_t> last_expected = { 45105, 45106, 45107 };
+ check_output<half_float::half>(last_three, last_expected);
+ }
+}
+
TEST_SUITE_END(); // generate
diff --git a/scripts/schemavalidation/datagen-config.schema.json b/scripts/schemavalidation/datagen-config.schema.json
index a74d79f..19e8b62 100644
--- a/scripts/schemavalidation/datagen-config.schema.json
+++ b/scripts/schemavalidation/datagen-config.schema.json
@@ -22,7 +22,7 @@
"type": "object",
"properties": {
"generator": {
- "description": "data generator name - PSEUDO_RANDOM, DOT_PRODUCT or OP_SPECIFIC",
+ "description": "data generator name - PSEUDO_RANDOM, DOT_PRODUCT, FULL_RANGE, BOUNDARY, or SPECIAL",
"type": "string"
},
"data_type": {
@@ -120,33 +120,19 @@
"acc_type"
]
},
- "op_specific_info": {
- "description": "info required for the OP_SPECIFIC generator",
+ "full_range_info": {
+ "description": "info required for the FULL_RANGE generator",
"type": "object",
"properties":
{
- "sub_generator": {
- "description": "sub generator type for this op - FULL, SPECIAL or BOUNDARY",
- "type": "string"
- },
- "offset": {
- "description": "starting offset within the test data",
+ "start_val": {
+ "description": "starting value of the test data",
"type": "integer",
"minimum": 0
- },
- "attributes": {
- "description": "attribute data from the op needed to compute the data",
- "type": "object",
- "properties": {
- "TBD": {
- "description": "Probably needed for RESCALE and MUL",
- "type": "string"
- }
- }
}
},
"additionalProperties": false,
- "required": [ "sub_generator" ]
+ "required": [ ]
},
"fixed_data_info": {
"description": "info required for FIXED_DATA generator",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index cbfffae..c596645 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -828,6 +828,12 @@ class TosaTensorValuesGen:
if "axis" in argsDict:
info["axis"] = int(argsDict["axis"])
tens_meta["dot_product_info"] = info
+ elif dg_type == gtu.DataGenType.FULL_RANGE:
+ info = {}
+ info["start_val"] = int(
+ testGen.randInt(0, gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["fullset"])
+ )
+ tens_meta["full_range_info"] = info
else:
# TODO - other data gen type
assert False, "TODO: support other data gen types"
@@ -1795,7 +1801,7 @@ class TosaArgGen:
pass
@staticmethod
- def _add_data_generators(testGen, opName, dtype, arg_list, error_name):
+ def _add_data_generators(testGen, opName, shapeList, dtype, arg_list, error_name):
"""Add extra tests for each type of data generator for this op."""
if (
error_name is None
@@ -1820,7 +1826,16 @@ class TosaArgGen:
new_arg_list = []
for dg_type in dataGenTypesList:
for arg_str, args_dict in arg_list:
- args_dict["dg_type"] = dg_type
+
+ if dg_type == gtu.DataGenType.FULL_RANGE:
+ tensor_size = gtu.product(shapeList[0])
+ if tensor_size >= gtu.DTYPE_ATTRIBUTES[dtype]["fullset"]:
+ # Large enough tensor data size for full range, add a single test
+ num_test_sets = 0
+ else:
+ # Not enough data size for full range of values, revert to random numbers
+ dg_type = gtu.DataGenType.PSEUDO_RANDOM
+
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
if error_name is None:
num_test_sets = (
@@ -1829,6 +1844,7 @@ class TosaArgGen:
else 0
)
else:
+ # Add single test for pseudo random
num_test_sets = 0
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
@@ -1852,13 +1868,16 @@ class TosaArgGen:
if num_test_sets > 0:
for s in range(0, num_test_sets):
- new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
- new_args_dict = args_dict.copy()
- new_args_dict["s"] = s
- new_arg_list.append((new_arg_str, new_args_dict))
+ set_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
+ set_args_dict = args_dict.copy()
+ set_args_dict["s"] = s
+ set_args_dict["dg_type"] = dg_type
+ new_arg_list.append((set_arg_str, set_args_dict))
else:
# Default is a single test
- new_arg_list.append((arg_str, args_dict))
+ new_args_dict = args_dict.copy()
+ new_args_dict["dg_type"] = dg_type
+ new_arg_list.append((arg_str, new_args_dict))
return new_arg_list
@@ -1869,6 +1888,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
[("", {})],
error_name,
@@ -1883,6 +1903,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
[("", {"num_test_sets": 3})],
error_name,
@@ -1921,6 +1942,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2160,6 +2182,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtypes[0],
arg_list,
error_name,
@@ -2194,6 +2217,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
input_dtype,
arg_list,
error_name,
@@ -2246,6 +2270,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2402,6 +2427,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtypes[0],
arg_list,
error_name,
@@ -2482,6 +2508,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2685,6 +2712,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2774,6 +2802,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2925,6 +2954,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
inDtype,
arg_list,
error_name,
@@ -2947,6 +2977,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2967,6 +2998,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -2994,6 +3026,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3019,6 +3052,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3091,6 +3125,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3137,6 +3172,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3179,6 +3215,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3214,6 +3251,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3547,6 +3585,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3586,6 +3625,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3606,6 +3646,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
@@ -3624,6 +3665,7 @@ class TosaArgGen:
arg_list = TosaArgGen._add_data_generators(
testGen,
opName,
+ shapeList,
dtype,
arg_list,
error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 415858c..a1f54c6 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -365,7 +365,7 @@ class TosaTestGen:
if "ksb" in argsDict
else int(argsDict["ks"]),
}
- elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
+ elif argsDict["dg_type"] == gtu.DataGenType.SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
elif "compliance" in op and "ulp" in op["compliance"]:
mode = gtu.ComplianceMode.ULP
@@ -3959,7 +3959,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
},
"bitwise_not": {
@@ -3996,7 +3996,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
"compliance": {"ulp": 0.5},
},
@@ -4055,7 +4055,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
},
"floor": {
@@ -4075,7 +4075,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
"compliance": {"ulp": 0.5},
},
@@ -4096,7 +4096,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
"compliance": {"ulp": 5},
},
@@ -4137,7 +4137,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
},
"reciprocal": {
@@ -4157,7 +4157,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
"compliance": {"ulp": 1.0},
},
@@ -4178,7 +4178,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
"data_gen": {
- "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ "fp": (gtu.DataGenType.FULL_RANGE,),
},
"compliance": {"ulp": 2},
},
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 384463f..6558bf8 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -13,22 +13,23 @@ MAX_RESIZE_DIMENSION = 16384
# Data type information dictionary
# - str: filename abbreviation
# - width: number of bytes needed for type
+# - fullset: precalculated number of possible values in the data type's range, equal to 2^width
# - json: JSON type string
DTYPE_ATTRIBUTES = {
- DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
- DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
- DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
- DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
- DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
- DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
- DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
- DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
- DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
- DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
- DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
- DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
- DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
- DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
+ DType.BOOL: {"str": "b", "width": 1, "fullset": 2, "json": "BOOL"},
+ DType.INT4: {"str": "i4", "width": 4, "fullset": 16, "json": "INT4"},
+ DType.INT8: {"str": "i8", "width": 8, "fullset": 256, "json": "INT8"},
+ DType.UINT8: {"str": "u8", "width": 8, "fullset": 256, "json": "UINT8"},
+ DType.INT16: {"str": "i16", "width": 16, "fullset": 65536, "json": "INT16"},
+ DType.UINT16: {"str": "u16", "width": 16, "fullset": 65536, "json": "UINT16"},
+ DType.INT32: {"str": "i32", "width": 32, "fullset": 1 << 32, "json": "INT32"},
+ DType.INT48: {"str": "i48", "width": 48, "fullset": 1 << 48, "json": "INT48"},
+ DType.SHAPE: {"str": "s", "width": 64, "fullset": 1 << 64, "json": "SHAPE"},
+ DType.FP16: {"str": "f16", "width": 16, "fullset": 65536, "json": "FP16"},
+ DType.BF16: {"str": "bf16", "width": 16, "fullset": 65536, "json": "BF16"},
+ DType.FP32: {"str": "f32", "width": 32, "fullset": 1 << 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "fullset": 256, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "fullset": 256, "json": "FP8E5M2"},
}
@@ -49,9 +50,9 @@ class DataGenType(IntEnum):
PSEUDO_RANDOM = 0
DOT_PRODUCT = 1
- OP_BOUNDARY = 2
- OP_FULLSET = 3
- OP_SPECIAL = 4
+ BOUNDARY = 2
+ FULL_RANGE = 3
+ SPECIAL = 4
FIXED_DATA = 5