diff options
Diffstat (limited to 'reference_model/test/verify_tests.cpp')
-rw-r--r-- | reference_model/test/verify_tests.cpp | 122 |
1 files changed, 92 insertions, 30 deletions
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp index f43804f..81f3e8d 100644 --- a/reference_model/test/verify_tests.cpp +++ b/reference_model/test/verify_tests.cpp @@ -16,41 +16,103 @@ #include <doctest.h> #include <array> -#include <numeric> +#include <string> #include <vector> -TEST_SUITE("verify") +namespace { - TEST_CASE("check_element_accfp32") + +class TosaTensor +{ +public: + TosaTensor(std::string name, tosa_datatype_t dataType, std::vector<int32_t> shape) + : _name(std::move(name)) + , _shape(std::move(shape)) + { + _tensor.name = _name.c_str(); + _tensor.data_type = dataType; + _tensor.num_dims = _shape.size(); + _tensor.shape = _shape.data(); + }; + + const tosa_tensor_t* cTensor() const { - const size_t KS = 27; - - // Negative (bnd == 0.0) - REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid); - REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid); - // Negative (bnd > 0.0) - REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid); - - // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0) - REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid); - REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0); - - // Positive (bnd > 0.0) - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); + return &_tensor; } - TEST_CASE("check_output_error") + +private: + std::string _name; + std::vector<int32_t> _shape; + tosa_tensor_t _tensor; +}; + +} // namespace + +TEST_SUITE_BEGIN("verify"); + +TEST_CASE("negative - api") +{ + std::string json_cfg = R"({ + "tensors" : { + "out1" : { + "mode": "DOT_PRODUCT", + "dot_product_info" : { + "data_type": "FP32", + "s": 2, + "ks": 9 + } + } + } + })"; + + SUBCASE("invalid json") { - const size_t KS = 27; - const size_t T = 1024; - - // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T))) - REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0)); - // Negative (err_sum_sq > 0.4*KS*T)) - REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1)); - // Positive - REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0)); - REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1)); + std::string invalid_json_cfg = R"({ + "tensors" : { + "out1" : { + "mode": DOT_PRODUCT, + }, + } + })"; + + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), invalid_json_cfg.c_str())); + } + SUBCASE("mismatching dimensions") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 4, 4 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 4, 4 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("mismatching shapes") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 4, 4, 4 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("mismatching data types") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp16_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("missing tensor data") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); } } + +TEST_SUITE_END(); // verify
\ No newline at end of file |