aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/generate/generate_dot_product.cc88
-rw-r--r--reference_model/src/generate/generate_dot_product.h3
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc48
-rw-r--r--reference_model/src/generate/generate_utils.cc1
-rw-r--r--reference_model/src/ops/tensor_ops.cc35
-rw-r--r--reference_model/src/verify/verify_dot_product.cc40
-rw-r--r--reference_model/test/generate_tests.cpp118
-rw-r--r--verif/checker/tosa_result_checker.py16
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json5
-rw-r--r--verif/generator/tosa_arg_gen.py22
-rw-r--r--verif/generator/tosa_test_gen.py69
11 files changed, 392 insertions, 53 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index fe829e3..046007e 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -736,6 +736,92 @@ bool generateTransposeConv2D(const TosaReference::GenerateConfig& cfg,
return false;
}
}
+//---------------------------------------------------------------------------//
+// FFT2D //
+//---------------------------------------------------------------------------//
+
+template <typename DataType>
+bool generateFFT2DReal(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ DataType* data,
+ size_t size)
+{
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t H = cfg.shape[1];
+ const uint32_t W = cfg.shape[2];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t x = t % W;
+ uint32_t y = (t / W) % H;
+ uint32_t k = y * W + x;
+
+ data[t] = static_cast<DataType>(generator(k));
+ }
+ return true;
+}
+
+template <typename DataType>
+bool generateFFT2DImag(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ DataType* data,
+ size_t size)
+{
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t H = cfg.shape[1];
+ const uint32_t W = cfg.shape[2];
+
+ // The index expression of ((1*N+n)*H+y)*W+x in the spec equates to
+ // using the values after those used for the Real tensor, but we need
+ // to iterate through all those values to get to the Imaginary data
+ for (int64_t n = 0; n < 2; ++n)
+ {
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t x = t % W;
+ uint32_t y = (t / W) % H;
+ uint32_t k = y * W + x;
+
+ data[t] = static_cast<DataType>(generator(k));
+ }
+ }
+ return true;
+}
+
+bool generateFFT2D(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 3)
+ {
+ WARNING("[Generator][DP][FFT2D] Tensor shape expected 3 dimensions.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_FP32: {
+ float* outData = reinterpret_cast<float*>(data);
+ switch (cfg.inputPos)
+ {
+ case 0:
+ return generateFFT2DReal(cfg, generator, outData, size);
+ case 1:
+ return generateFFT2DImag(cfg, generator, outData, size);
+ default:
+ WARNING("[Generator][DP][FFT2D] Invalid input tensor slot position to operator.");
+ return false;
+ }
+ break;
+ }
+ default:
+ WARNING("[Generator][DP][FFT2D] Only supports FP32.");
+ return false;
+ }
+
+ return true;
+}
} // namespace
namespace TosaReference
@@ -772,6 +858,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
return generateDepthwiseConv2D(cfg, *generator, data, size);
case tosa::Op_TRANSPOSE_CONV2D:
return generateTransposeConv2D(cfg, *generator, data, size);
+ case tosa::Op_FFT2D:
+ return generateFFT2D(cfg, *generator, data, size);
default:
WARNING("[Generator][DP] Unsupported operator.");
return false;
diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h
index cd9d4ba..bf1b1ff 100644
--- a/reference_model/src/generate/generate_dot_product.h
+++ b/reference_model/src/generate/generate_dot_product.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ class IDotProductGenerator
public:
virtual float operator()(uint32_t k) = 0;
virtual ~IDotProductGenerator() = default;
+ virtual uint32_t nextIndex() = 0;
};
/// \brief Dot-product stage generator selector
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index 9ce32ff..b78be71 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@ public:
return pseudo;
}
- uint32_t index()
+ uint32_t nextIndex()
{
return _index;
}
@@ -101,6 +101,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -129,6 +134,10 @@ public:
else
return (_B * _B / (_KS + 1)) * v;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -158,6 +167,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -186,6 +199,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -229,6 +246,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -258,6 +280,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf
float B = getBoundParameter(cfg.dataType, dpinfo.accType);
if (B > 0.f)
{
+ auto param = cfg.inputPos;
+ if (cfg.opType == Op_FFT2D)
+ {
+ // We only use param of zero for FFT2D tensors
+ param = 0;
+ }
// Create the generator
switch (dpinfo.s)
{
case 0:
- return std::make_unique<GeneratorS0>(cfg.inputPos);
+ return std::make_unique<GeneratorS0>(param);
case 1:
- return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS1>(param, dpinfo.ks, B);
case 2:
- return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+ return std::make_unique<GeneratorS2>(param, dpinfo.ks);
case 3:
- return std::make_unique<GeneratorS3>(cfg.inputPos);
+ return std::make_unique<GeneratorS3>(param);
case 4:
- return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS4>(param, dpinfo.ks, B);
case 5:
- return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS5>(param, dpinfo.ks, B);
default:
WARNING("[Generator][DP] Unsupported dot product test series for generator.");
return nullptr;
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index a8b472a..2e40b04 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -54,6 +54,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_ERF, "ERF" },
{ Op::Op_EXP, "EXP" },
{ Op::Op_FLOOR, "FLOOR" },
+ { Op::Op_FFT2D, "FFT2D" },
{ Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" },
{ Op::Op_GATHER, "GATHER" },
{ Op::Op_GREATER, "GREATER" },
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b9e2fbe..8d8dac7 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -1684,7 +1684,8 @@ int OpFFT2d<Dtype>::eval()
in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);
- OutEigenType sum_real, sum_imag, a, sign_val = 1.0;
+ OutEigenType sum_real, sum_imag, sign_val = 1.0;
+ OutEigenType a, a_cos, a_sin, v_ir;
if (attribute->inverse())
{
@@ -1715,11 +1716,33 @@ int OpFFT2d<Dtype>::eval()
{
OutEigenType val_real = in_real_val(n, iy, ix);
OutEigenType val_imag = in_imag_val(n, iy, ix);
- // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+ // Perform the periodic calculation in integer maths to keep
+ // the accuracy of the co-efficients similar for FP32 normal
+ // and FP64 precise mode
+ int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_real_height;
+ int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_real_width;
+
+ // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
a = sign_val * 2 * M_PI *
- ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
- sum_real += val_real * cos(a) + val_imag * sin(a);
- sum_imag += -val_real * sin(a) + val_imag * cos(a);
+ ((OutEigenType)ay / in_real_height + (OutEigenType)ax / in_real_width);
+ // Calculate weight values
+ a_cos = cos(a);
+ a_sin = sin(a);
+ if (g_func_config.abs_mode)
+ {
+ // Bounded op - Use abs weight values
+ a_cos = std::abs(a_cos);
+ a_sin = std::abs(a_sin);
+ // Bounded op - Use abs real value for imaginary calc
+ v_ir = val_real;
+ }
+ else
+ {
+ // Normal op - Use negative real value for imaginary calc
+ v_ir = -val_real;
+ }
+ sum_real += val_real * a_cos + val_imag * a_sin;
+ sum_imag += v_ir * a_sin + val_imag * a_cos;
}
}
this->out_real->getTensor()(n, oy, ox) = sum_real;
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index a036cba..ea50573 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#include "half.hpp"
#include "verifiers.h"
+#include <cfloat>
#include <cmath>
#include <numeric>
#include <optional>
@@ -43,7 +44,8 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = (ref == 0.0) && (imp == 0.0);
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp);
+ WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref,
+ FLT_DIG, imp);
}
err = 0.0;
}
@@ -57,7 +59,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = std::abs(err) <= KS;
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS);
+ WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS);
}
}
@@ -66,8 +68,15 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
// Generic data validation function
template <typename AccType>
-bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+bool validateData(const double* ref,
+ const double* bnd,
+ const AccType* imp,
+ const std::vector<int32_t>& shape,
+ const DotProductVerifyInfo& cfg)
{
+ const size_t T = static_cast<size_t>(numElements(shape));
+ TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+
const int32_t S = cfg.s;
// NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias
// tensor is non-zero
@@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
for (size_t i = 0; i < T; ++i)
{
auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
- TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range");
+ if (!out_err)
+ {
+ auto pos = indexToPosition(i, shape);
+ TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range",
+ positionToString(pos).c_str());
+ }
out_err_sum += out_err.value();
out_err_sumsq += out_err.value() * out_err.value();
}
@@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
{
const double max_bias = 2 * sqrt(KS * T);
// Check error bias magnitude for data sets S which are not positive biased
- TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)",
- out_err_sum, max_bias);
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)",
+ DBL_DIG, out_err_sum, DBL_DIG, max_bias);
}
// Check error variance magnitude
const double max_error = 0.4 * KS * T;
- TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)",
- out_err_sumsq, max_error);
+ TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG,
+ out_err_sumsq, DBL_DIG, max_error);
return true;
}
} // namespace
@@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing");
TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing");
- // Get number of dot-product elements
- const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
- TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+ const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims);
const double* refData = reinterpret_cast<const double*>(ref->data);
const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
@@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
case tosa_datatype_fp32_t: {
const float* impData = reinterpret_cast<const float*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
case tosa_datatype_fp16_t: {
const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
default: {
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
index c45c5b6..4f62ede 100644
--- a/reference_model/test/generate_tests.cpp
+++ b/reference_model/test/generate_tests.cpp
@@ -1224,4 +1224,122 @@ TEST_CASE("positive - FP16 transpose_conv2d dot product (last 3 values)")
}
}
+void fft2d_test_FP32(const std::string tosaName,
+ const size_t tosaElements,
+ const std::string templateJsonCfg,
+ const std::string setStr,
+ const std::vector<uint32_t> lastExpected)
+{
+ std::string jsonCfg = templateJsonCfg;
+ update_json_template(jsonCfg, "_SET_", setStr);
+
+ std::vector<float> buffer(tosaElements);
+ REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 4));
+ // Get values at positions -8, -7 and -6 from the end
+ std::vector<float> last_three_ish(buffer.end() - 8, buffer.end() - 5);
+
+ check_output<float>(last_three_ish, lastExpected);
+}
+
+TEST_CASE("positive - FP32 fft2d dot product (values -8, -7 & -6 from the end)")
+{
+ std::string templateJsonCfg = R"({
+ "tensors" : {
+ "real" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 8, 2, 4 ],
+ "input_pos": 0,
+ "op" : "FFT2D",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 16,
+ "acc_type": "FP32"
+ }
+ },
+ "imag" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 8, 2, 4 ],
+ "input_pos": 1,
+ "op" : "FFT2D",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 16,
+ "acc_type": "FP32"
+ }
+ }
+ }
+ })";
+
+ const std::string tosaNameReal = "real";
+ const std::string tosaNameImag = "imag";
+ const size_t tosaElements = 8 * 2 * 4;
+
+ SUBCASE("fft2d, set 0, real")
+ {
+ std::vector<uint32_t> expected = { 0x0, 0x0, 0x3ee06867 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "0", expected);
+ }
+ SUBCASE("fft2d, set 0, imag")
+ {
+ std::vector<uint32_t> expected = { 0x3e6d1d36, 0x0, 0x0 };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "0", expected);
+ }
+ SUBCASE("fft2d, set 1, real")
+ {
+ // NOTE: Python test script produced 0x5e7219eb - so off by 1
+ std::vector<uint32_t> expected = { 0x5e18358e, 0x5e7219ec, 0x5e2beab2 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "1", expected);
+ }
+ SUBCASE("fft2d, set 1, imag")
+ {
+ std::vector<uint32_t> expected = { 0x5e71fbcc, 0x5e1bd27a, 0x5e46c84a };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "1", expected);
+ }
+ SUBCASE("fft2d, set 2, real")
+ {
+ std::vector<uint32_t> expected = { 0x3f800000, 0x3d704bae, 0x3e4443a6 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "2", expected);
+ }
+ SUBCASE("fft2d, set 2, imag")
+ {
+ std::vector<uint32_t> expected = { 0x3f800000, 0x3dacbd02, 0xbe26be6a };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "2", expected);
+ }
+ SUBCASE("fft2d, set 3, real")
+ {
+ // NOTE: Python test script produced 0x3de257cf, 0x3f144b53 - so off by 1
+ std::vector<uint32_t> expected = { 0x41800000, 0x3de257ce, 0x3f144b54 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "3", expected);
+ }
+ SUBCASE("fft2d, set 3, imag")
+ {
+ std::vector<uint32_t> expected = { 0x41800000, 0x3f86492c, 0xbf5bd4b3 };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "3", expected);
+ }
+ SUBCASE("fft2d, set 4, real")
+ {
+ std::vector<uint32_t> expected = { 0x0, 0x5d8c6475, 0x0 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "4", expected);
+ }
+ SUBCASE("fft2d, set 4, imag")
+ {
+ std::vector<uint32_t> expected = { 0xdca65b4f, 0x5c98b5d2, 0xdd14ddd8 };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "4", expected);
+ }
+ SUBCASE("fft2d, set 5, real")
+ {
+ std::vector<uint32_t> expected = { 0x5cb9bbd4, 0x5d8c0c21, 0x5daa1928 };
+ fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "5", expected);
+ }
+ SUBCASE("fft2d, set 5, imag")
+ {
+ std::vector<uint32_t> expected = { 0x5e708eb3, 0x5e2c1a78, 0x5ddbbc3f };
+ fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "5", expected);
+ }
+}
+
TEST_SUITE_END(); // generate
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 6948378..212c809 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -1,5 +1,5 @@
"""TOSA result checker script."""
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
@@ -55,9 +55,9 @@ def _print_result(color, msg):
def compliance_check(
- imp_result_path,
- ref_result_path,
- bnd_result_path,
+ imp_result_data,
+ ref_result_data,
+ bnd_result_data,
test_name,
compliance_config,
ofm_name,
@@ -78,14 +78,18 @@ def compliance_check(
return (TestResult.INTERNAL_ERROR, 0.0, msg)
success = vlib.verify_data(
- ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
+ ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data
)
if success:
_print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}")
return (TestResult.PASS, 0.0, "")
else:
_print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
- return (TestResult.MISMATCH, 0.0, "Non-compliance results found")
+ return (
+ TestResult.MISMATCH,
+ 0.0,
+ f"Non-compliance results found for {ofm_name}",
+ )
def test_check(
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 5e35e8b..067fab7 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -980,6 +980,7 @@
"profile": [
"tosa-mi"
],
+ "support_for": [ "lazy_data_gen" ],
"generation": {
"standard": {
"generator_args": [
@@ -987,13 +988,13 @@
"--target-dtype",
"fp32",
"--fp-values-range",
- "-2.0,2.0"
+ "-max,max"
],
[
"--target-dtype",
"fp32",
"--fp-values-range",
- "-2.0,2.0",
+ "-max,max",
"--target-shape",
"1,256,64",
"--target-shape",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b4939da..f6a46b4 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2798,9 +2798,27 @@ class TosaArgGen:
def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("inverseTrue", [True]))
- arg_list.append(("inverseFalse", [False]))
+ shape = shapeList[0]
+ dot_products = gtu.product(shape)
+ ks = 2 * shape[1] * shape[2] # 2*H*W
+ for inverse in (True, False):
+ args_dict = {
+ "dot_products": dot_products,
+ "shape": shape,
+ "ks": ks,
+ "acc_type": dtype,
+ "inverse": inverse,
+ }
+ arg_list.append((f"inverse{inverse}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
# Helper function for reshape. Gets some factors of a larger number.
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bfafd23..68a4e94 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -381,8 +381,35 @@ class TosaTestGen:
"""Enhanced build information containing result tensor and associated compliance dict."""
def __init__(self, resultTensor, complianceDict):
- self.resultTensor = resultTensor
- self.complianceDict = complianceDict
+ if isinstance(resultTensor, list):
+ assert complianceDict is None or isinstance(complianceDict, list)
+ self.resultTensorList = resultTensor
+ self.complianceDictList = complianceDict
+ else:
+ self.resultTensorList = [resultTensor]
+ if complianceDict is None:
+ self.complianceDictList = None
+ else:
+ self.complianceDictList = [complianceDict]
+
+ def getComplianceInfo(self):
+ if self.complianceDictList is None:
+ return None
+ else:
+ tens_dict = {}
+ for tens, comp in zip(self.resultTensorList, self.complianceDictList):
+ if comp is not None:
+ tens_dict[tens.name] = comp
+
+ if tens_dict:
+ # Have some compliance data, so return the info
+ compliance = {
+ "version": "0.1",
+ "tensors": tens_dict,
+ }
+ else:
+ compliance = None
+ return compliance
def build_unary(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2491,12 +2518,16 @@ class TosaTestGen:
def build_fft2d(
self,
op,
- val1,
- val2,
- inverse,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 2
+ val1, val2 = inputs
+ inverse = args_dict["inverse"]
+
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
@@ -2537,7 +2568,16 @@ class TosaTestGen:
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(
+ op, val1.dtype, args_dict, res, error_name
+ )
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
@@ -2933,13 +2973,11 @@ class TosaTestGen:
if result:
# The test is valid, serialize it
- if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
- # Add the compliance meta data
- # NOTE: This currently expects only one result output
- tensMeta["compliance"] = {
- "version": "0.1",
- "tensors": {result.resultTensor.name: result.complianceDict},
- }
+ if isinstance(result, TosaTestGen.BuildInfo):
+ # Add the compliance meta data (if any)
+ compliance = result.getComplianceInfo()
+ if compliance:
+ tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
else:
# The test is not valid
@@ -4708,7 +4746,7 @@ class TosaTestGen:
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
@@ -4723,6 +4761,9 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
"rfft2d": {
"op": Op.RFFT2D,