author     Jeremy Johnson <jeremy.johnson@arm.com>   2023-10-18 17:22:21 +0100
committer  Eric Kunze <eric.kunze@arm.com>           2023-11-02 23:22:09 +0000
commit     d1a08ce27ef8d0f6cf77e1b864610aade06edc5c (patch)
tree       777992f45d240361f898b1d21902c2a46c58235f
parent     b0b9e33c3500bd8dc9b12ef012d4234b1245247a (diff)
download   reference_model-d1a08ce27ef8d0f6cf77e1b864610aade06edc5c.tar.gz
Compliance mode testing for CONV2D
Added CONV2D data generation. Updated the verify dot product check to the latest specification. Updated the test generator and Python datagenerator library to create const files during test generation. Added support for compliance test sets to conformance test_select.

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I5be3b761a1e3ef259c058e493877cd5a89d5778b
-rw-r--r--  reference_model/src/generate/generate_dot_product.cc | 115
-rw-r--r--  reference_model/src/generate/generate_dot_product_states.cc | 2
-rw-r--r--  reference_model/src/generate/generate_utils.cc | 1
-rw-r--r--  reference_model/src/generate/generate_utils.h | 2
-rw-r--r--  reference_model/src/verify/verify_dot_product.cc | 52
-rw-r--r--  reference_model/src/verify/verify_utils.cc | 7
-rw-r--r--  reference_model/src/verify/verify_utils.h | 36
-rw-r--r--  reference_model/test/generate_tests.cpp | 162
-rw-r--r--  scripts/schemavalidation/datagen-config.schema.json | 7
-rw-r--r--  verif/conformance/test_select.py | 26
-rw-r--r--  verif/conformance/tosa_main_profile_ops_info.json | 1
-rw-r--r--  verif/generator/datagenerator.py | 59
-rw-r--r--  verif/generator/tosa_arg_gen.py | 108
-rw-r--r--  verif/generator/tosa_test_gen.py | 130
-rw-r--r--  verif/generator/tosa_utils.py | 14
-rw-r--r--  verif/tests/test_tosa_datagenerator.py | 14
16 files changed, 599 insertions(+), 137 deletions(-)
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index cbfac4b..e6815ad 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -76,6 +76,119 @@ bool generateMatMul(const TosaReference::GenerateConfig& cfg,
return true;
}
+//---------------------------------------------------------------------------//
+// Conv2D //
+//---------------------------------------------------------------------------//
+
+bool generateConv2DInput(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.dotProductInfo.kernel.size() != 2 || cfg.dotProductInfo.kernel[0] <= 0 || cfg.dotProductInfo.kernel[1] <= 0)
+ {
+ WARNING("[Generator][DP][Conv2D][Input] Missing or incorrect kernel size information.");
+ return false;
+ }
+ if (cfg.shape.size() != 4)
+ {
+ WARNING("[Generator][DP][Conv2D][Input] Tensor shape expected 4 dimensions.");
+ return false;
+ }
+
+ float* input = reinterpret_cast<float*>(data);
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t IH = cfg.shape[1];
+ const uint32_t IW = cfg.shape[2];
+ const uint32_t IC = cfg.shape[3];
+ const uint32_t KH = cfg.dotProductInfo.kernel[0];
+ const uint32_t KW = cfg.dotProductInfo.kernel[1];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t ic = t % IC;
+ uint32_t ix = (t / IC) % IW;
+ uint32_t iy = ((t / IC) / IW) % IH;
+ uint32_t k = ((iy % KH) * KW + (ix % KW)) * IC + ic;
+
+ input[t] = generator(k);
+ }
+ return true;
+}
+
+bool generateConv2DWeight(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 4)
+ {
+ WARNING("[Generator][DP][Conv2D][Weight] Tensor shape expected 4 dimensions.");
+ return false;
+ }
+
+ float* weight = reinterpret_cast<float*>(data);
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t KH = cfg.shape[1];
+ const uint32_t KW = cfg.shape[2];
+ const uint32_t IC = cfg.shape[3];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t ic = t % IC;
+ uint32_t kx = (t / IC) % KW;
+ uint32_t ky = ((t / IC) / KW) % KH;
+ uint32_t k = (ky + KW * kx) * IC + ic;
+
+ weight[t] = generator(k);
+ }
+ return true;
+}
+
+bool generateConv2DBias(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 1)
+ {
+ WARNING("[Generator][DP][Conv2D][Bias] Tensor shape expected 1 dimension.");
+ return false;
+ }
+
+ float* bias = reinterpret_cast<float*>(data);
+ const uint32_t T = cfg.shape[0];
+
+ for (uint32_t t = 0; t < T; ++t)
+ {
+ bias[t] = generator(2);
+ }
+ return true;
+}
+
+bool generateConv2D(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.dataType != DType::DType_FP32)
+ {
+ WARNING("[Generator][DP][Conv2D] Only supports FP32.");
+ return false;
+ }
+ switch (cfg.inputPos)
+ {
+ case 0:
+ return generateConv2DInput(cfg, generator, data, size);
+ case 1:
+ return generateConv2DWeight(cfg, generator, data, size);
+ case 2:
+ return generateConv2DBias(cfg, generator, data, size);
+ default:
+ WARNING("[Generator][DP][Conv2D] Invalid input tensor slot position to operator.");
+ return false;
+ }
+}
} // namespace
namespace TosaReference
@@ -95,6 +208,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
{
case tosa::Op_MATMUL:
return generateMatMul(cfg, *generator, data, size);
+ case tosa::Op_CONV2D:
+ return generateConv2D(cfg, *generator, data, size);
default:
WARNING("[Generator][DP] Unsupported operator.");
return false;
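
The Conv2D generators above map each flattened tensor element back onto its dot-product term index k before asking the set generator for a value; the input mapping folds the spatial position modulo the kernel size so every output position sees the same sequence of terms. A minimal Python sketch of the input-tensor mapping, for illustration only (the function name and shapes are not part of the commit):

def conv2d_input_k_indices(shape, kernel):
    # Reproduce the k index used for each flattened NHWC input element.
    N, IH, IW, IC = shape
    KH, KW = kernel
    ks = []
    for t in range(N * IH * IW * IC):
        ic = t % IC
        ix = (t // IC) % IW
        iy = (t // IC // IW) % IH
        ks.append(((iy % KH) * KW + (ix % KW)) * IC + ic)
    return ks

# Example: a 1x4x4x2 input with a 2x2 kernel only ever uses k in [0, KH*KW*IC) = [0, 8)
print(sorted(set(conv2d_input_k_indices((1, 4, 4, 2), (2, 2)))))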
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index 649e55e..53bef3a 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -242,7 +242,7 @@ public:
if (_p != P2)
return (_B / std::sqrt(_KS + 1)) * s;
else
- return (_B * _B / (_KS + 1)) * s;
+ return 0.f;
}
private:
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index bcbf9d7..d3bb076 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -41,6 +41,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_MATMUL, "MATMUL" },
{ Op::Op_MAX_POOL2D, "MAX_POOL2D" },
{ Op::Op_PAD, "PAD" },
+ { Op::Op_CONV2D, "CONV2D" },
})
} // namespace tosa
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 0239e98..7c55f1d 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -52,7 +52,7 @@ struct DotProductInfo
int32_t ks;
DType accType;
int32_t axis;
- std::array<int32_t, 2> kernel;
+ std::vector<int32_t> kernel;
};
/// \brief Pseudo random generator meta-data
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index 2a1d273..233c072 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -14,6 +14,7 @@
#include "func_debug.h"
#include "verifiers.h"
+#include "verify_utils.h"
#include <cmath>
#include <numeric>
@@ -24,22 +25,9 @@ namespace TosaReference
{
namespace
{
-
-// Accumulator precision
-template <typename T>
-struct AccPrecision;
-#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
-template <>
-struct AccPrecision<float>
-{
- static constexpr double precision = (double)(1 << 24);
- static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
-};
-#undef two_m42
-
// Generic element validation function
template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+std::optional<double> validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS)
{
double err = 0.0;
bool is_valid = true;
@@ -47,7 +35,11 @@ std::optional<double> validateElement(double ref, double bnd, AccType imp, size_
if (bnd == 0.0)
{
is_valid = (ref == 0.0) && (imp == 0.0);
- err = 0.0;
+ if (!is_valid)
+ {
+ WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp);
+ }
+ err = 0.0;
}
else if (std::isinf(static_cast<AccType>(bnd)))
{
@@ -58,11 +50,15 @@ std::optional<double> validateElement(double ref, double bnd, AccType imp, size_
else
{
// 0.0 < bnd < infinity
- const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
- const double imp_fp64 = static_cast<double>(imp);
- const double acc_prec_fp64 = AccPrecision<AccType>::precision;
- err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
- is_valid = std::abs(err) <= KS;
+ const double out_err_bnd =
+ std::max(bnd * exp2(-1 - AccPrecision<AccType>::normal_frac), AccPrecision<AccType>::normal_min);
+ const double imp_fp64 = static_cast<double>(imp);
+ err = (imp_fp64 - ref) / out_err_bnd;
+ is_valid = std::abs(err) <= KS;
+ if (!is_valid)
+ {
+ WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS);
+ }
}
return is_valid ? std::optional(err) : std::nullopt;
@@ -73,7 +69,8 @@ template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<A
bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
{
const int32_t S = cfg.s;
- // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ // NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias
+ // tensor is non-zero
const int32_t KS = cfg.ks;
double out_err_sum = 0.0;
@@ -81,7 +78,7 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
for (size_t i = 0; i < T; ++i)
{
- auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range");
out_err_sum += out_err.value();
out_err_sumsq += out_err.value() * out_err.value();
@@ -89,11 +86,16 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
if (S >= 3 && S <= 5)
{
+ const double max_bias = 2 * sqrt(KS * T);
+ out_err_sum = std::abs(out_err_sum);
// Check error bias magnitude for data sets S which are not positive biased
- TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "[DP] Bias magnitude is out of range");
+ TOSA_REF_REQUIRE(out_err_sum <= max_bias, "[DP] Bias magnitude (%g) is out of range (%g)", out_err_sum,
+ max_bias);
}
// Check error variance magnitude
- TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "[DP] Error variance magnitude is out of range");
+ const double max_error = 0.4 * KS * T;
+ TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)",
+ out_err_sumsq, max_error);
return true;
}
} // namespace
@@ -107,7 +109,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
// Get number of dot-product elements
const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
- TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+ TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
const double* refData = reinterpret_cast<const double*>(ref->data);
const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
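
The updated element check no longer scales by an accumulator precision constant; instead the error is normalised by an output error bound out_err_bnd = max(bnd * 2^(-1 - normal_frac), normal_min) and required to satisfy |err| <= KS. A rough Python equivalent of the 0 < bnd < infinity branch, assuming FP32 accumulation (normal_frac = 23, normal_min = 2^-126):

import math

def validate_element_fp32(ref, bnd, imp, KS):
    # Returns the normalised error, or None when it is outside the allowed range.
    out_err_bnd = max(bnd * math.ldexp(1.0, -1 - 23), math.ldexp(1.0, -126))
    err = (float(imp) - ref) / out_err_bnd
    return err if abs(err) <= KS else None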
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index ee11c41..43ecbe7 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -140,4 +140,11 @@ DType mapToDType(tosa_datatype_t dataType)
return DType_UNKNOWN;
}
+
+// Like const_exp2 but for use during runtime
+double exp2(int32_t n)
+{
+ TOSA_REF_REQUIRE(-1022 <= n && n <= 1023, " Invalid exponent value (%d)", n);
+ return const_exp2(n);
+}
} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index bbe4b4e..486ce19 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -23,10 +23,10 @@
#include <optional>
#include <vector>
-#define TOSA_REF_REQUIRE(COND, MESSAGE) \
+#define TOSA_REF_REQUIRE(COND, MESSAGE, ...) \
if (!(COND)) \
{ \
- WARNING("[Verifier]" MESSAGE "."); \
+ WARNING("[Verifier]" MESSAGE ".", ##__VA_ARGS__); \
return false; \
}
@@ -95,6 +95,38 @@ int64_t numElements(const std::vector<int32_t>& shape);
/// \brief Map API data-type to DType
DType mapToDType(tosa_datatype_t dataType);
+/// \brief Raise a value by the power of N or -N
+// For use during compile time - as no range check
+constexpr double const_exp2(int32_t n)
+{
+ double v = 1.0;
+ while (n > 0)
+ {
+ v = v * 2.0;
+ n--;
+ }
+ while (n < 0)
+ {
+ v = v / 2.0;
+ n++;
+ }
+ return v;
+}
+
+/// \brief Same as const_exp2 but with runtime range check of N
+double exp2(int32_t n);
+
+/// \brief Accuracy precision information
+template <typename T>
+struct AccPrecision;
+template <>
+struct AccPrecision<float>
+{
+ static constexpr double normal_min = const_exp2(-126);
+ static constexpr double normal_max = const_exp2(128) - const_exp2(127 - 23);
+ static constexpr int32_t normal_frac = 23;
+};
+
}; // namespace TosaReference
#endif // VERIFY_UTILS_H_
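
As a sanity check, the FP32 AccPrecision values introduced here match the usual single-precision limits; a quick illustrative check in Python (not part of the commit):

import numpy as np

normal_min = 2.0 ** -126                      # const_exp2(-126)
normal_max = 2.0 ** 128 - 2.0 ** (127 - 23)   # const_exp2(128) - const_exp2(127 - 23)

assert normal_min == np.finfo(np.float32).tiny  # smallest normal FP32
assert normal_max == np.finfo(np.float32).max   # largest finite FP32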
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
index c24a369..6173372 100644
--- a/reference_model/test/generate_tests.cpp
+++ b/reference_model/test/generate_tests.cpp
@@ -286,6 +286,168 @@ TEST_CASE("positive - FP32 matmul dot product (first 3 values)")
matmul_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 1, expected);
}
}
+
+void conv2d_test_FP32(const std::string tosaName[3],
+ const size_t tosaElements[3],
+ const std::string templateJsonCfg,
+ const std::string setStr,
+ int32_t param,
+ const std::vector<uint32_t> lastExpected)
+{
+ std::string jsonCfg = templateJsonCfg;
+ update_json_template(jsonCfg, "_SET_", setStr);
+
+ std::vector<float> buffer(tosaElements[param]);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName[param].c_str(), (void*)buffer.data(), tosaElements[param] * 4));
+ std::vector<float> last_three(buffer.end() - std::min<int>(3, buffer.size()), buffer.end());
+ check_output<float>(last_three, lastExpected);
+}
+
+TEST_CASE("positive - FP32 conv2d dot product (last 3 values)")
+{
+ std::string templateJsonCfg = R"({
+ "tensors" : {
+ "input" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 1, 8, 2, 4 ],
+ "input_pos": 0,
+ "op" : "CONV2D",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 16,
+ "acc_type": "FP32",
+ "kernel": [2, 2]
+ }
+ },
+ "weight" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "CONSTANT",
+ "shape" : [ 2, 2, 2, 4 ],
+ "input_pos": 1,
+ "op" : "CONV2D",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 16,
+ "acc_type": "FP32"
+ }
+ },
+ "bias" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "CONSTANT",
+ "shape" : [ 2 ],
+ "input_pos": 2,
+ "op" : "CONV2D",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 16,
+ "acc_type": "FP32"
+ }
+ }
+
+ }
+ })";
+
+ const std::string tosaName[3] = { "input", "weight", "bias" };
+ const size_t tosaElements[3] = { (1 * 8 * 2 * 4), (2 * 2 * 2 * 4), 2 };
+
+ SUBCASE("conv2d, set 0, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0xbf28bfda, 0xbe99cd47 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "0", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 0, param 1")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x3f648dfd, 0xbd4cb21c };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "0", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 0, param 2")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "0", 2, lastExpected);
+ }
+ SUBCASE("conv2d, set 1, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0x5e6f0400, 0x5e2f78e5, 0x5e62318d };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "1", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 1, param 1")
+ {
+ // NOTE: Python test script produced 0x5e6960b0 - so off by 1
+ std::vector<uint32_t> lastExpected = { 0x5e6960af, 0x5e6d0ca9, 0x5e0b8561 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "1", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 1, param 2")
+ {
+ // NOTE: Python test script produced 0x7cf260d0, 0x7d355432 - so off by 1
+ std::vector<uint32_t> lastExpected = { 0x7cf260d1, 0x7d355431 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "1", 2, lastExpected);
+ }
+ SUBCASE("conv2d, set 2, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0x3e7da8e9, 0x3df76a57, 0xbe338212 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "2", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 2, param 1")
+ {
+ std::vector<uint32_t> lastExpected = { 0x3daabbc5, 0xbe2f8909, 0xbdb806ec };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "2", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 2, param 2")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "2", 2, lastExpected);
+ }
+ SUBCASE("conv2d, set 3, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0xbee77fe5, 0x402141c5, 0xbda1b2ed };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "3", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 3, param 1")
+ {
+ // NOTE: Python test script produced 0xbe9947ac - so off by 1
+ std::vector<uint32_t> lastExpected = { 0x3f91e619, 0x3e9ac66b, 0xbe9947ad };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "3", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 3, param 2")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "3", 2, lastExpected);
+ }
+ SUBCASE("conv2d, set 4, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0xdd7e8575, 0x0, 0xde569ff3 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "4", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 4, param 1")
+ {
+ std::vector<uint32_t> lastExpected = { 0x5e2d6921, 0x5e13a014, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "4", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 4, param 2")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "4", 2, lastExpected);
+ }
+ SUBCASE("conv2d, set 5, param 0")
+ {
+ std::vector<uint32_t> lastExpected = { 0x5e719fb9, 0x5e6b329c, 0xdd7617d4 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 0, lastExpected);
+ }
+ SUBCASE("conv2d, set 5, param 1")
+ {
+ std::vector<uint32_t> lastExpected = { 0xde42f57a, 0x5dd68799, 0xde2ddfcb };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 1, lastExpected);
+ }
+ SUBCASE("conv2d, set 5, param 2")
+ {
+ std::vector<uint32_t> lastExpected = { 0x0, 0x0 };
+ conv2d_test_FP32(tosaName, tosaElements, templateJsonCfg, "5", 2, lastExpected);
+ }
+}
TEST_CASE("positive - pseudo random")
{
std::string templateJsonCfg = R"({
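
The expected values in the conv2d subcases above are the raw IEEE-754 bit patterns of the last three generated FP32 values, and the NOTE comments record where the Python reference implementation lands one unit in the last place away. A small illustrative helper for inspecting such patterns:

import struct

def bits_to_float(bits: int) -> float:
    # Reinterpret a 32-bit pattern as an IEEE-754 single-precision value.
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(bits_to_float(0x5E6960AF), bits_to_float(0x5E6960B0))  # adjacent values, 1 ULP apart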
diff --git a/scripts/schemavalidation/datagen-config.schema.json b/scripts/schemavalidation/datagen-config.schema.json
index 01f9fad..68789f6 100644
--- a/scripts/schemavalidation/datagen-config.schema.json
+++ b/scripts/schemavalidation/datagen-config.schema.json
@@ -85,7 +85,8 @@
},
"ks": {
"description": "kernel size for this dot product operation",
- "type": "integer"
+ "type": "integer",
+ "minimum": 0
},
"acc_type": {
"description": "operator accumulator type (like tensor data_type)",
@@ -93,9 +94,9 @@
},
"kernel": {
"type": "array",
- "description": "kernel x, y sizes (for avg_pool2d)",
+ "description": "kernel x, y (and z) sizes",
"minItems": 2,
- "maxItems": 2,
+ "maxItems": 3,
"items": {
"description": "kernel dimension",
"type": "integer",
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index b7bbfc3..faefc85 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -125,6 +125,8 @@ class Operator:
# Working set of param_names - updated for negative tests
wks_param_names = None
+ COMPLIANCE_SETS = ("_s0", "_s1", "_s2", "_s3", "_s4", "_s5")
+
def __init__(
self,
test_dir: Path,
@@ -258,7 +260,15 @@ class Operator:
if (not negative and "ERRORIF" not in str(path)) or (
negative and "ERRORIF" in str(path)
):
- yield path
+ # Check for compliance test set paths
+ suffix = path.name[-3:]
+ if suffix in Operator.COMPLIANCE_SETS:
+ if suffix != Operator.COMPLIANCE_SETS[0]:
+ # Only return one of the test sets
+ continue
+ yield path.with_name(path.name[:-3])
+ else:
+ yield path
@classmethod
def get_test_paths(cls, test_dir: Path, negative):
@@ -343,7 +353,12 @@ class Operator:
for k in path_params:
unused_values[k].discard(path_params[k])
logger.debug(f"FOUND wanted: {path.name}")
- yield path
+ if path.exists():
+ yield path
+ else:
+ # Compliance test series - expand to all sets
+ for s in Operator.COMPLIANCE_SETS:
+ yield path.with_name(f"{path.name}{s}")
# search for tests that match any unused parameter values
for n, path in enumerate(sorted(list(unused_paths))):
@@ -359,7 +374,12 @@ class Operator:
unused_values[p].discard(path_params[p])
sparsity = self.sparsity[k] if k in self.sparsity else 0
logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
- yield path
+ if path.exists():
+ yield path
+ else:
+ # Compliance test series - expand to all sets
+ for s in Operator.COMPLIANCE_SETS:
+ yield path.with_name(f"{path.name}{s}")
break
if not self.ignore_missing:
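
test_select now treats the six per-set directories produced for a compliance test (suffixes _s0 through _s5) as one logical test: during path discovery only the _s0 entry is kept, with its suffix stripped, and if a selected test path no longer exists on disk it is expanded back out to all six set directories. A small sketch of that round trip (illustrative only):

from pathlib import Path

COMPLIANCE_SETS = ("_s0", "_s1", "_s2", "_s3", "_s4", "_s5")

def collapse(paths):
    # Yield one suffix-free path per compliance series, plus ordinary paths unchanged.
    for path in paths:
        suffix = path.name[-3:]
        if suffix in COMPLIANCE_SETS:
            if suffix == COMPLIANCE_SETS[0]:
                yield path.with_name(path.name[:-3])
        else:
            yield path

def expand(path):
    # Expand a selected (non-existent) series path back to the per-set directories.
    return [path.with_name(f"{path.name}{s}") for s in COMPLIANCE_SETS]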
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 9c18879..a090479 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -598,6 +598,7 @@
"profile": [
"tosa-mi"
],
+ "support_for": [ "lazy_data_gen" ],
"generation": {
"standard": {
"negative_dim_range": "1,10",
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py
index 408c83e..0d59084 100644
--- a/verif/generator/datagenerator.py
+++ b/verif/generator/datagenerator.py
@@ -6,7 +6,7 @@ import json
from pathlib import Path
import numpy as np
-from schemavalidation import schemavalidation
+import schemavalidation.schemavalidation as sch
class GenerateError(Exception):
@@ -14,7 +14,15 @@ class GenerateError(Exception):
class GenerateLibrary:
- """Python interface to the C generate library."""
+ """Python interface to the C generate library.
+
+ Simple usage to write out all input files:
+ set_config(test_desc)
+ write_numpy_files(test_path)
+
+ To get data buffers (for const data):
+ get_tensor_data(tensor_name)
+ """
def __init__(self, generate_lib_path):
"""Find the library and set up the interface."""
@@ -22,6 +30,8 @@ class GenerateLibrary:
if not self.lib_path.is_file():
raise GenerateError(f"Could not find generate library - {self.lib_path}")
+ self.schema_validator = sch.TestDescSchemaValidator()
+
self.test_desc = None
self.json_config = None
self.lib = ct.cdll.LoadLibrary(self.lib_path)
@@ -51,8 +61,7 @@ class GenerateLibrary:
raise GenerateError("No meta/data_gen section found in desc.json")
# Validate the config versus the schema
- tdsv = schemavalidation.TestDescSchemaValidator()
- tdsv.validate_config(test_desc)
+ self.schema_validator.validate_config(test_desc)
self.test_desc = test_desc
self.json_config = test_desc["meta"]["data_gen"]
@@ -72,25 +81,25 @@ class GenerateLibrary:
return buffer, size_bytes
- def _data_gen_write(
- self, test_path: Path, json_bytes: bytes, ifm_name: str, ifm_file: str
- ):
- """Generate the named tensor data and save it in numpy format."""
+ def _data_gen_array(self, json_config: str, tensor_name: str):
+ """Generate the named tensor data and return a numpy array."""
try:
- tensor = self.json_config["tensors"][ifm_name]
+ tensor = json_config["tensors"][tensor_name]
dtype = tensor["data_type"]
shape = tuple(tensor["shape"])
except KeyError as e:
raise GenerateError(
- f"Missing data in desc.json for input {ifm_name} - {repr(e)}"
+ f"Missing data in json config for input {tensor_name} - {repr(e)}"
)
buffer, size_bytes = self._create_buffer(dtype, shape)
buffer_ptr = ct.cast(buffer, ct.c_void_p)
+ json_bytes = bytes(json.dumps(json_config), "utf8")
+
result = self.tgd_generate_data(
ct.c_char_p(json_bytes),
- ct.c_char_p(bytes(ifm_name, "utf8")),
+ ct.c_char_p(bytes(tensor_name, "utf8")),
buffer_ptr,
ct.c_size_t(size_bytes),
)
@@ -100,11 +109,19 @@ class GenerateLibrary:
arr = np.ctypeslib.as_array(buffer)
arr = np.reshape(arr, shape)
+ return arr
+
+ def _data_gen_write(
+ self, test_path: Path, json_config: str, ifm_name: str, ifm_file: str
+ ):
+ """Generate the named tensor data and save it in numpy format."""
+ arr = self._data_gen_array(json_config, ifm_name)
+
file_name = test_path / ifm_file
np.save(file_name, arr)
def write_numpy_files(self, test_path: Path):
- """Write out all the specified tensors to numpy data files."""
+ """Write out all the desc.json input tensors to numpy data files."""
if self.test_desc is None or self.json_config is None:
raise GenerateError("Cannot write numpy files as no config set up")
@@ -114,12 +131,10 @@ class GenerateLibrary:
except KeyError as e:
raise GenerateError(f"Missing data in desc.json - {repr(e)}")
- json_bytes = bytes(json.dumps(self.json_config), "utf8")
-
failures = []
for iname, ifile in zip(ifm_names, ifm_files):
try:
- self._data_gen_write(test_path, json_bytes, iname, ifile)
+ self._data_gen_write(test_path, self.json_config, iname, ifile)
except GenerateError as e:
failures.append(
f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
@@ -128,6 +143,20 @@ class GenerateLibrary:
if len(failures) > 0:
raise GenerateError("\n".join(failures))
+ def get_tensor_data(self, tensor_name: str, json_config=None):
+ """Get a numpy array for a named tensor in the data_gen meta data."""
+ if json_config is None:
+ if self.json_config is None:
+ raise GenerateError("Cannot get tensor data as no config set up")
+ json_config = self.json_config
+ else:
+ # Validate the given config
+ self.schema_validator.validate_config(
+ json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
+ )
+
+ return self._data_gen_array(json_config, tensor_name)
+
def main(argv=None):
"""Simple command line interface for the data generator."""
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index f7837a0..32f4341 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -638,9 +638,9 @@ class TosaTensorValuesGen:
if (
error_name is not None
or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
- or opName in ("avg_pool2d",)
+ or "data_gen" not in testGen.TOSA_OP_LIST[opName]
):
- # Fall back to original path when dealing with unsupported types
+ # Fall back to original path when dealing with unsupported types or ops
# First turn off lazy data gen so we always produce data
lazy_data_gen = testGen.args.lazy_data_gen
@@ -660,7 +660,11 @@ class TosaTensorValuesGen:
# Create data generator meta-data
dg_type = argsDict["dg_type"]
- dg_tens_meta = {}
+ tens_data = {
+ "version": "0.1",
+ "tensors": {},
+ }
+ dg_tens_meta = tens_data["tensors"]
tens_ser_list = []
for idx, shape in enumerate(shapeList):
@@ -669,15 +673,12 @@ class TosaTensorValuesGen:
tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
tens_meta["shape"] = [int(i) for i in shape]
tens_meta["input_pos"] = idx
- tens_meta["op"] = opName.upper()
+ tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
if idx < pCount:
tens_meta["input_type"] = "VARIABLE"
- tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None)
else:
tens_meta["input_type"] = "CONSTANT"
- tens = testGen.ser.addConst(shape, dtypeList[idx], None)
- tens_ser_list.append(tens)
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
info = {}
@@ -691,23 +692,55 @@ class TosaTensorValuesGen:
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
info = {}
info["s"] = argsDict["s"]
- info["ks"] = argsDict["ks"]
- for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
- if key in argsDict:
- if key.endswith("_type"):
- info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"]
- else:
- info[key] = argsDict[key]
+ info["ks"] = int(argsDict["ks"])
+ if "acc_type" in argsDict:
+ # Convert type number into JSON name
+ info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
+ "json"
+ ]
+ if "kernel" in argsDict:
+ info["kernel"] = [int(k) for k in argsDict["kernel"]]
+ if "axis" in argsDict:
+ info["axis"] = int(argsDict["axis"])
tens_meta["dot_product_info"] = info
else:
# TODO - other data gen type
assert False, "TODO: support other data gen types"
+
+ # Using the finished generate config meta data - generate the data if
+ # needed and assign a tensor name from the serializer
+
+ # Need to generate data when not lazy or for the bias tensor as we need
+ # to work out if the bias data is non-zero for compliance
+ if not testGen.args.lazy_data_gen or (
+ idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
+ ):
+ # Give this tensor a temporary name until we get one from the serializer
+ temp_name = f"placeholder_{idx}"
+ dg_tens_meta[temp_name] = tens_meta
+ # Create data now using the temporary name to access meta details
+ data = testGen.dgl.get_tensor_data(temp_name, tens_data)
+ # Remove the item as we will give it the correct name later
+ del dg_tens_meta[temp_name]
+
+ if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
+ # The KS value used by compliance verification is altered when the
+ # bias data is non-zero
+ if max(abs(data)) > 0.0:
+ argsDict["ksb"] = argsDict["ks"] + 1
+
+ if testGen.args.lazy_data_gen:
+ data = None
+
+ if tens_meta["input_type"] == "VARIABLE":
+ tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
+ else:
+ tens = testGen.ser.addConst(shape, dtypeList[idx], data)
+
+ tens_ser_list.append(tens)
+ # Add the meta data to the list using the serializer tensor name
dg_tens_meta[tens.name] = tens_meta
- tens_data = {
- "version": "0.1",
- "tensors": dg_tens_meta,
- }
return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
@staticmethod
@@ -1206,8 +1239,11 @@ class TosaArgGen:
accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
- # Check the rank
+ # Op type checks
conv3d = opName.startswith("conv3d")
+ depthwise = opName.startswith("depthwise")
+
+ # Check the rank
rank = 5 if conv3d else 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == rank
@@ -1215,8 +1251,12 @@ class TosaArgGen:
# kernel rank omits channels
k_rank = rank - 2
- k_pos = 0 if opName.startswith("depthwise") else 1
+ k_pos = 0 if depthwise else 1
k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
+ # compliance size - KS
+ k_size = gtu.product(k_shape)
+ if not depthwise:
+ k_size *= ifm_shape[-1]
if not testGen.args.level8k:
# Generate comprehensive argument lists
@@ -1363,6 +1403,24 @@ class TosaArgGen:
# Test will consume too much memory - skip it
continue
+ # Compliance - number of dot product calculations
+ if depthwise:
+ # TODO - add support
+ dots = 0
+ else:
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, filter_shape[0])
+ )
+ args_dict = {
+ "acc_type": accum_dtype,
+ "stride": s,
+ "pad": p,
+ "dilation": d,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ }
+
# Support for larger values than 9 needs different delimiter
delim = "" if max(s + p + d) <= 9 else "x"
arg_list.append(
@@ -1373,11 +1431,19 @@ class TosaArgGen:
delim.join([str(x) for x in p]),
delim.join([str(x) for x in d]),
),
- [accum_dtype, s, p, d],
+ args_dict,
)
)
n += 1
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtypes[0],
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
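
The key compliance detail in this change is that the bias tensor (input position 2) is generated even in lazy mode, so the generator can tell whether any bias value is non-zero; when it is, the KS recorded for verification is bumped to KS + 1 (stored as "ksb" in argsDict). The effective rule, as a one-line Python sketch:

import numpy as np

def effective_ks(ks: int, bias: np.ndarray) -> int:
    # KS used by dot-product verification: KS + 1 when the bias data is non-zero.
    return ks + 1 if np.max(np.abs(bias)) > 0.0 else ks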
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 17cbd8f..54b624e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -56,11 +56,9 @@ class TosaTestGen:
self.random_fp_high = max(args.tensor_fp_value_range)
# JSON schema validation
self.descSchemaValidator = TestDescSchemaValidator()
- # Data generator library when not generating the data later
- if not args.lazy_data_gen:
- self.dgl = GenerateLibrary(args.generate_lib_path)
- else:
- self.dgl = None
+ # Data generator library is sometimes needed for compliance set up
+ # even if we are generating the data later (lazy_data_generation)
+ self.dgl = GenerateLibrary(args.generate_lib_path)
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -108,11 +106,6 @@ class TosaTestGen:
fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
json.dump(metaData["data_gen"], fd)
fd.write(')";\n\n')
- else:
- # Generate the data
- self.dgl.set_config(desc)
- self.dgl.write_numpy_files(path)
-
if "compliance" in metaData:
# Output datagen meta data as CPP data
path_md = path / f"{testName}_meta_compliance.cpp"
@@ -293,9 +286,15 @@ class TosaTestGen:
low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
)
- def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
- if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
- # No compliance for error tests or other data types currently
+ def tensorComplianceMetaData(
+ self, op, inputType, argsDict, outputTensor, errorName
+ ):
+ if (
+ errorName
+ or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
+ or not gtu.dtypeIsSupportedByCompliance(inputType)
+ ):
+ # No compliance for error tests or unsupported types currently
return None
# Create compliance meta data for expected output tensor
@@ -308,7 +307,9 @@ class TosaTestGen:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
"s": argsDict["s"],
- "ks": argsDict["ks"],
+ "ks": int(argsDict["ksb"])
+ if "ksb" in argsDict
+ else int(argsDict["ks"]),
}
elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
@@ -741,31 +742,30 @@ class TosaTestGen:
error_name,
qinfo,
)
- if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, inputs[0].dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 4
- result_tens = OutputShaper.conv2dOp(
+ result_tensor = OutputShaper.conv2dOp(
self.ser,
self.rng,
ifm,
@@ -784,12 +784,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -802,7 +802,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -812,7 +812,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -820,22 +820,29 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv3d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
self.ser,
@@ -960,17 +967,19 @@ class TosaTestGen:
def build_depthwise_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
result_tens = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
@@ -1121,12 +1130,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -1431,12 +1437,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -2911,7 +2914,7 @@ class TosaTestGen:
"build_fcn": (
build_conv2d,
TosaTensorGen.tgConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2931,6 +2934,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -2941,7 +2947,7 @@ class TosaTestGen:
"build_fcn": (
build_conv3d,
TosaTensorGen.tgConv3D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2972,7 +2978,7 @@ class TosaTestGen:
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 14afaa7..7fc5b52 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -51,15 +51,21 @@ class DataGenType(IntEnum):
OP_SPECIAL = 4
-# Additional (optional) data for dot product data generator
-DG_DOT_PRODUCT_OPTIONAL_INFO = ("acc_type", "kernel", "axis")
-
-
def dtypeIsSupportedByCompliance(dtype):
"""Types supported by the new data generation and compliance flow."""
+ if isinstance(dtype, list) or isinstance(dtype, tuple):
+ dtype = dtype[0]
return dtype in (DType.FP32,)
+def getOpNameFromOpListName(opName):
+ """Get the op name from a TOSA_OP_LIST name that can have suffixes."""
+ for name in ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d"):
+ if opName.startswith(name):
+ return name
+ return opName
+
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.
diff --git a/verif/tests/test_tosa_datagenerator.py b/verif/tests/test_tosa_datagenerator.py
index ba0235c..4f3d7fd 100644
--- a/verif/tests/test_tosa_datagenerator.py
+++ b/verif/tests/test_tosa_datagenerator.py
@@ -114,3 +114,17 @@ def test_generate_dot_product_check_fail_names():
for f in json_config["ifm_file"]:
file = TEST_DIR / f
assert not file.is_file()
+
+
+@pytest.mark.postcommit
+def test_generate_tensor_data_check():
+ glib = GenerateLibrary(GENERATE_LIB_PATH)
+ assert glib
+
+ json_config = JSON_DATAGEN_DOT_PRODUCT["meta"]["data_gen"]
+
+ for n in JSON_DATAGEN_DOT_PRODUCT["ifm_name"]:
+ arr = glib.get_tensor_data(n, json_config)
+
+ assert arr.shape == tuple(json_config["tensors"][n]["shape"])
+ assert arr.dtype == np.float32