diff options
-rw-r--r-- | reference_model/src/generate/generate_dot_product.cc | 88 | ||||
-rw-r--r-- | reference_model/src/generate/generate_dot_product.h | 3 | ||||
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 48 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 1 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 35 | ||||
-rw-r--r-- | reference_model/src/verify/verify_dot_product.cc | 40 | ||||
-rw-r--r-- | reference_model/test/generate_tests.cpp | 118 | ||||
-rw-r--r-- | verif/checker/tosa_result_checker.py | 16 | ||||
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 5 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 22 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 69 |
11 files changed, 392 insertions, 53 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc index fe829e3..046007e 100644 --- a/reference_model/src/generate/generate_dot_product.cc +++ b/reference_model/src/generate/generate_dot_product.cc @@ -736,6 +736,92 @@ bool generateTransposeConv2D(const TosaReference::GenerateConfig& cfg, return false; } } +//---------------------------------------------------------------------------// +// FFT2D // +//---------------------------------------------------------------------------// + +template <typename DataType> +bool generateFFT2DReal(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast<DataType>(generator(k)); + } + return true; +} + +template <typename DataType> +bool generateFFT2DImag(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + // The index expression of ((1*N+n)*H+y)*W+x in the spec equates to + // using the values after those used for the Real tensor, but we need + // to iterate through all those values to get to the Imaginary data + for (int64_t n = 0; n < 2; ++n) + { + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast<DataType>(generator(k)); + } + } + return true; +} + +bool generateFFT2D(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + if (cfg.shape.size() != 3) + { + WARNING("[Generator][DP][FFT2D] Tensor shape expected 3 dimensions."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP32: { + float* outData = reinterpret_cast<float*>(data); + switch (cfg.inputPos) + { + case 0: + return generateFFT2DReal(cfg, generator, outData, size); + case 1: + return generateFFT2DImag(cfg, generator, outData, size); + default: + WARNING("[Generator][DP][FFT2D] Invalid input tensor slot position to operator."); + return false; + } + break; + } + default: + WARNING("[Generator][DP][FFT2D] Only supports FP32."); + return false; + } + + return true; +} } // namespace namespace TosaReference @@ -772,6 +858,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size) return generateDepthwiseConv2D(cfg, *generator, data, size); case tosa::Op_TRANSPOSE_CONV2D: return generateTransposeConv2D(cfg, *generator, data, size); + case tosa::Op_FFT2D: + return generateFFT2D(cfg, *generator, data, size); default: WARNING("[Generator][DP] Unsupported operator."); return false; diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h index cd9d4ba..bf1b1ff 100644 --- a/reference_model/src/generate/generate_dot_product.h +++ b/reference_model/src/generate/generate_dot_product.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ class IDotProductGenerator public: virtual float operator()(uint32_t k) = 0; virtual ~IDotProductGenerator() = default; + virtual uint32_t nextIndex() = 0; }; /// \brief Dot-product stage generator selector diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index 9ce32ff..b78be71 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ public: return pseudo; } - uint32_t index() + uint32_t nextIndex() { return _index; } @@ -101,6 +101,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -129,6 +134,10 @@ public: else return (_B * _B / (_KS + 1)) * v; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -158,6 +167,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -186,6 +199,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -229,6 +246,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -258,6 +280,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf float B = getBoundParameter(cfg.dataType, dpinfo.accType); if (B > 0.f) { + auto param = cfg.inputPos; + if (cfg.opType == Op_FFT2D) + { + // We only use param of zero for FFT2D tensors + param = 0; + } // Create the generator switch (dpinfo.s) { case 0: - return std::make_unique<GeneratorS0>(cfg.inputPos); + return std::make_unique<GeneratorS0>(param); case 1: - return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS1>(param, dpinfo.ks, B); case 2: - return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks); + return std::make_unique<GeneratorS2>(param, dpinfo.ks); case 3: - return std::make_unique<GeneratorS3>(cfg.inputPos); + return std::make_unique<GeneratorS3>(param); case 4: - return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS4>(param, dpinfo.ks, B); case 5: - return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS5>(param, dpinfo.ks, B); default: WARNING("[Generator][DP] Unsupported dot product test series for generator."); return nullptr; diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index a8b472a..2e40b04 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -54,6 +54,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_ERF, "ERF" }, { Op::Op_EXP, "EXP" }, { Op::Op_FLOOR, "FLOOR" }, + { Op::Op_FFT2D, "FFT2D" }, { Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" }, { Op::Op_GATHER, "GATHER" }, { Op::Op_GREATER, "GREATER" }, diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b9e2fbe..8d8dac7 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -1684,7 +1684,8 @@ int OpFFT2d<Dtype>::eval() in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width, out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width); - OutEigenType sum_real, sum_imag, a, sign_val = 1.0; + OutEigenType sum_real, sum_imag, sign_val = 1.0; + OutEigenType a, a_cos, a_sin, v_ir; if (attribute->inverse()) { @@ -1715,11 +1716,33 @@ int OpFFT2d<Dtype>::eval() { OutEigenType val_real = in_real_val(n, iy, ix); OutEigenType val_imag = in_imag_val(n, iy, ix); - // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType + // Perform the periodic calculation in integer maths to keep + // the accuracy of the co-efficients similar for FP32 normal + // and FP64 precise mode + int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_real_height; + int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_real_width; + + // Use explicit cast to ensure intermediate calculations are completed using OutEigenType a = sign_val * 2 * M_PI * - ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width); - sum_real += val_real * cos(a) + val_imag * sin(a); - sum_imag += -val_real * sin(a) + val_imag * cos(a); + ((OutEigenType)ay / in_real_height + (OutEigenType)ax / in_real_width); + // Calculate weight values + a_cos = cos(a); + a_sin = sin(a); + if (g_func_config.abs_mode) + { + // Bounded op - Use abs weight values + a_cos = std::abs(a_cos); + a_sin = std::abs(a_sin); + // Bounded op - Use abs real value for imaginary calc + v_ir = val_real; + } + else + { + // Normal op - Use negative real value for imaginary calc + v_ir = -val_real; + } + sum_real += val_real * a_cos + val_imag * a_sin; + sum_imag += v_ir * a_sin + val_imag * a_cos; } } this->out_real->getTensor()(n, oy, ox) = sum_real; diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc index a036cba..ea50573 100644 --- a/reference_model/src/verify/verify_dot_product.cc +++ b/reference_model/src/verify/verify_dot_product.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "half.hpp" #include "verifiers.h" +#include <cfloat> #include <cmath> #include <numeric> #include <optional> @@ -43,7 +44,8 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT is_valid = (ref == 0.0) && (imp == 0.0); if (!is_valid) { - WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp); + WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref, + FLT_DIG, imp); } err = 0.0; } @@ -57,7 +59,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT is_valid = std::abs(err) <= KS; if (!is_valid) { - WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS); + WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS); } } @@ -66,8 +68,15 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT // Generic data validation function template <typename AccType> -bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg) +bool validateData(const double* ref, + const double* bnd, + const AccType* imp, + const std::vector<int32_t>& shape, + const DotProductVerifyInfo& cfg) { + const size_t T = static_cast<size_t>(numElements(shape)); + TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const int32_t S = cfg.s; // NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias // tensor is non-zero @@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size for (size_t i = 0; i < T; ++i) { auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS); - TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range"); + if (!out_err) + { + auto pos = indexToPosition(i, shape); + TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range", + positionToString(pos).c_str()); + } out_err_sum += out_err.value(); out_err_sumsq += out_err.value() * out_err.value(); } @@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size { const double max_bias = 2 * sqrt(KS * T); // Check error bias magnitude for data sets S which are not positive biased - TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)", - out_err_sum, max_bias); + TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)", + DBL_DIG, out_err_sum, DBL_DIG, max_bias); } // Check error variance magnitude const double max_error = 0.4 * KS * T; - TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)", - out_err_sumsq, max_error); + TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG, + out_err_sumsq, DBL_DIG, max_error); return true; } } // namespace @@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing"); TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing"); - // Get number of dot-product elements - const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims)); - TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims); const double* refData = reinterpret_cast<const double*>(ref->data); const double* refBndData = reinterpret_cast<const double*>(refBnd->data); @@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* case tosa_datatype_fp32_t: { const float* impData = reinterpret_cast<const float*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } case tosa_datatype_fp16_t: { const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } default: { diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp index c45c5b6..4f62ede 100644 --- a/reference_model/test/generate_tests.cpp +++ b/reference_model/test/generate_tests.cpp @@ -1224,4 +1224,122 @@ TEST_CASE("positive - FP16 transpose_conv2d dot product (last 3 values)") } } +void fft2d_test_FP32(const std::string tosaName, + const size_t tosaElements, + const std::string templateJsonCfg, + const std::string setStr, + const std::vector<uint32_t> lastExpected) +{ + std::string jsonCfg = templateJsonCfg; + update_json_template(jsonCfg, "_SET_", setStr); + + std::vector<float> buffer(tosaElements); + REQUIRE(tgd_generate_data(jsonCfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaElements * 4)); + // Get values at positions -8, -7 and -6 from the end + std::vector<float> last_three_ish(buffer.end() - 8, buffer.end() - 5); + + check_output<float>(last_three_ish, lastExpected); +} + +TEST_CASE("positive - FP32 fft2d dot product (values -8, -7 & -6 from the end)") +{ + std::string templateJsonCfg = R"({ + "tensors" : { + "real" : { + "generator": "DOT_PRODUCT", + "data_type": "FP32", + "input_type": "VARIABLE", + "shape" : [ 8, 2, 4 ], + "input_pos": 0, + "op" : "FFT2D", + "dot_product_info": { + "s": _SET_, + "ks": 16, + "acc_type": "FP32" + } + }, + "imag" : { + "generator": "DOT_PRODUCT", + "data_type": "FP32", + "input_type": "VARIABLE", + "shape" : [ 8, 2, 4 ], + "input_pos": 1, + "op" : "FFT2D", + "dot_product_info": { + "s": _SET_, + "ks": 16, + "acc_type": "FP32" + } + } + } + })"; + + const std::string tosaNameReal = "real"; + const std::string tosaNameImag = "imag"; + const size_t tosaElements = 8 * 2 * 4; + + SUBCASE("fft2d, set 0, real") + { + std::vector<uint32_t> expected = { 0x0, 0x0, 0x3ee06867 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "0", expected); + } + SUBCASE("fft2d, set 0, imag") + { + std::vector<uint32_t> expected = { 0x3e6d1d36, 0x0, 0x0 }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "0", expected); + } + SUBCASE("fft2d, set 1, real") + { + // NOTE: Python test script produced 0x5e7219eb - so off by 1 + std::vector<uint32_t> expected = { 0x5e18358e, 0x5e7219ec, 0x5e2beab2 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "1", expected); + } + SUBCASE("fft2d, set 1, imag") + { + std::vector<uint32_t> expected = { 0x5e71fbcc, 0x5e1bd27a, 0x5e46c84a }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "1", expected); + } + SUBCASE("fft2d, set 2, real") + { + std::vector<uint32_t> expected = { 0x3f800000, 0x3d704bae, 0x3e4443a6 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "2", expected); + } + SUBCASE("fft2d, set 2, imag") + { + std::vector<uint32_t> expected = { 0x3f800000, 0x3dacbd02, 0xbe26be6a }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "2", expected); + } + SUBCASE("fft2d, set 3, real") + { + // NOTE: Python test script produced 0x3de257cf, 0x3f144b53 - so off by 1 + std::vector<uint32_t> expected = { 0x41800000, 0x3de257ce, 0x3f144b54 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "3", expected); + } + SUBCASE("fft2d, set 3, imag") + { + std::vector<uint32_t> expected = { 0x41800000, 0x3f86492c, 0xbf5bd4b3 }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "3", expected); + } + SUBCASE("fft2d, set 4, real") + { + std::vector<uint32_t> expected = { 0x0, 0x5d8c6475, 0x0 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "4", expected); + } + SUBCASE("fft2d, set 4, imag") + { + std::vector<uint32_t> expected = { 0xdca65b4f, 0x5c98b5d2, 0xdd14ddd8 }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "4", expected); + } + SUBCASE("fft2d, set 5, real") + { + std::vector<uint32_t> expected = { 0x5cb9bbd4, 0x5d8c0c21, 0x5daa1928 }; + fft2d_test_FP32(tosaNameReal, tosaElements, templateJsonCfg, "5", expected); + } + SUBCASE("fft2d, set 5, imag") + { + std::vector<uint32_t> expected = { 0x5e708eb3, 0x5e2c1a78, 0x5ddbbc3f }; + fft2d_test_FP32(tosaNameImag, tosaElements, templateJsonCfg, "5", expected); + } +} + TEST_SUITE_END(); // generate diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 6948378..212c809 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -1,5 +1,5 @@ """TOSA result checker script.""" -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import json @@ -55,9 +55,9 @@ def _print_result(color, msg): def compliance_check( - imp_result_path, - ref_result_path, - bnd_result_path, + imp_result_data, + ref_result_data, + bnd_result_data, test_name, compliance_config, ofm_name, @@ -78,14 +78,18 @@ def compliance_check( return (TestResult.INTERNAL_ERROR, 0.0, msg) success = vlib.verify_data( - ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path + ofm_name, compliance_config, imp_result_data, ref_result_data, bnd_result_data ) if success: _print_result(LogColors.GREEN, f"Compliance Results PASS {test_name}") return (TestResult.PASS, 0.0, "") else: _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}") - return (TestResult.MISMATCH, 0.0, "Non-compliance results found") + return ( + TestResult.MISMATCH, + 0.0, + f"Non-compliance results found for {ofm_name}", + ) def test_check( diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 5e35e8b..067fab7 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -980,6 +980,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "generator_args": [ @@ -987,13 +988,13 @@ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0" + "-max,max" ], [ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--target-shape", "1,256,64", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b4939da..f6a46b4 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2798,9 +2798,27 @@ class TosaArgGen: def agFFT2d(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] - arg_list.append(("inverseTrue", [True])) - arg_list.append(("inverseFalse", [False])) + shape = shapeList[0] + dot_products = gtu.product(shape) + ks = 2 * shape[1] * shape[2] # 2*H*W + for inverse in (True, False): + args_dict = { + "dot_products": dot_products, + "shape": shape, + "ks": ks, + "acc_type": dtype, + "inverse": inverse, + } + arg_list.append((f"inverse{inverse}", args_dict)) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list # Helper function for reshape. Gets some factors of a larger number. diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index bfafd23..68a4e94 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -381,8 +381,35 @@ class TosaTestGen: """Enhanced build information containing result tensor and associated compliance dict.""" def __init__(self, resultTensor, complianceDict): - self.resultTensor = resultTensor - self.complianceDict = complianceDict + if isinstance(resultTensor, list): + assert complianceDict is None or isinstance(complianceDict, list) + self.resultTensorList = resultTensor + self.complianceDictList = complianceDict + else: + self.resultTensorList = [resultTensor] + if complianceDict is None: + self.complianceDictList = None + else: + self.complianceDictList = [complianceDict] + + def getComplianceInfo(self): + if self.complianceDictList is None: + return None + else: + tens_dict = {} + for tens, comp in zip(self.resultTensorList, self.complianceDictList): + if comp is not None: + tens_dict[tens.name] = comp + + if tens_dict: + # Have some compliance data, so return the info + compliance = { + "version": "0.1", + "tensors": tens_dict, + } + else: + compliance = None + return compliance def build_unary( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -2491,12 +2518,16 @@ class TosaTestGen: def build_fft2d( self, op, - val1, - val2, - inverse, + inputs, + args_dict, validator_fcns=None, error_name=None, + qinfo=None, ): + assert len(inputs) == 2 + val1, val2 = inputs + inverse = args_dict["inverse"] + results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) input_names = [val1.name, val2.name] @@ -2537,7 +2568,16 @@ class TosaTestGen: attr.FFTAttribute(inverse, local_bound) self.ser.addOperator(op["op"], input_names, output_names, attr) - return results + + compliance = [] + for res in results: + compliance.append( + self.tensorComplianceMetaData( + op, val1.dtype, args_dict, res, error_name + ) + ) + + return TosaTestGen.BuildInfo(results, compliance) def build_rfft2d( self, @@ -2933,13 +2973,11 @@ class TosaTestGen: if result: # The test is valid, serialize it - if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict: - # Add the compliance meta data - # NOTE: This currently expects only one result output - tensMeta["compliance"] = { - "version": "0.1", - "tensors": {result.resultTensor.name: result.complianceDict}, - } + if isinstance(result, TosaTestGen.BuildInfo): + # Add the compliance meta data (if any) + compliance = result.getComplianceInfo() + if compliance: + tensMeta["compliance"] = compliance self.serialize("test", tensMeta) else: # The test is not valid @@ -4708,7 +4746,7 @@ class TosaTestGen: "build_fcn": ( build_fft2d, TosaTensorGen.tgFFT2d, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agFFT2d, ), "types": [DType.FP32], @@ -4723,6 +4761,9 @@ class TosaTestGen: TosaErrorValidator.evFFTInputShapeMismatch, TosaErrorValidator.evFFTOutputShapeMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, }, "rfft2d": { "op": Op.RFFT2D, |