diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2023-08-22 08:25:57 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-09-07 16:03:50 +0000 |
commit | 7021ef064f7daeca260bb1f1bd61b5bbc6473aa5 (patch) | |
tree | 24a488954ab0a7c6e29e811429ad194af67c3880 /reference_model/test | |
parent | 391cc5e80559e46081b6aa12c344d820dc685c95 (diff) | |
download | reference_model-7021ef064f7daeca260bb1f1bd61b5bbc6473aa5.tar.gz |
Rework TOSA verification API
Change verifier API to consume verification configuration in a JSON
format and enable appropriate validation to be performed within the
verifier code in the reference model.
Also update to latest spec changes for main compliance but not yet
including bias support.
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I0ceaa1714dd041b00b5b29cd937c8f05e701bc4c
Diffstat (limited to 'reference_model/test')
-rw-r--r-- | reference_model/test/verify_tests.cpp | 122 |
1 files changed, 92 insertions, 30 deletions
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp index f43804f..81f3e8d 100644 --- a/reference_model/test/verify_tests.cpp +++ b/reference_model/test/verify_tests.cpp @@ -16,41 +16,103 @@ #include <doctest.h> #include <array> -#include <numeric> +#include <string> #include <vector> -TEST_SUITE("verify") +namespace { - TEST_CASE("check_element_accfp32") + +class TosaTensor +{ +public: + TosaTensor(std::string name, tosa_datatype_t dataType, std::vector<int32_t> shape) + : _name(std::move(name)) + , _shape(std::move(shape)) + { + _tensor.name = _name.c_str(); + _tensor.data_type = dataType; + _tensor.num_dims = _shape.size(); + _tensor.shape = _shape.data(); + }; + + const tosa_tensor_t* cTensor() const { - const size_t KS = 27; - - // Negative (bnd == 0.0) - REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid); - REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid); - // Negative (bnd > 0.0) - REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid); - - // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0) - REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid); - REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0); - - // Positive (bnd > 0.0) - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); - REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0); + return &_tensor; } - TEST_CASE("check_output_error") + +private: + std::string _name; + std::vector<int32_t> _shape; + tosa_tensor_t _tensor; +}; + +} // namespace + +TEST_SUITE_BEGIN("verify"); + +TEST_CASE("negative - api") +{ + std::string json_cfg = R"({ + "tensors" : { + "out1" : { + "mode": "DOT_PRODUCT", + "dot_product_info" : { + "data_type": "FP32", + "s": 2, + "ks": 9 + } + } + } + })"; + + SUBCASE("invalid json") { - const size_t KS = 27; - const size_t T = 1024; - - // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T))) - REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0)); - // Negative (err_sum_sq > 0.4*KS*T)) - REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1)); - // Positive - REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0)); - REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1)); + std::string invalid_json_cfg = R"({ + "tensors" : { + "out1" : { + "mode": DOT_PRODUCT, + }, + } + })"; + + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), invalid_json_cfg.c_str())); + } + SUBCASE("mismatching dimensions") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 4, 4 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 4, 4 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("mismatching shapes") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 4, 4, 4 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("mismatching data types") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp16_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); + } + SUBCASE("missing tensor data") + { + const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 }); + const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 }); + + REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str())); } } + +TEST_SUITE_END(); // verify
\ No newline at end of file |