aboutsummaryrefslogtreecommitdiff
path: root/reference_model/test
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2023-08-22 08:25:57 +0100
committerEric Kunze <eric.kunze@arm.com>2023-09-07 16:03:50 +0000
commit7021ef064f7daeca260bb1f1bd61b5bbc6473aa5 (patch)
tree24a488954ab0a7c6e29e811429ad194af67c3880 /reference_model/test
parent391cc5e80559e46081b6aa12c344d820dc685c95 (diff)
downloadreference_model-7021ef064f7daeca260bb1f1bd61b5bbc6473aa5.tar.gz
Rework TOSA verification API
Change verifier API to consume verification configuration in a JSON format and enable appropriate validation to be performed within the verifier code in the reference model. Also update to latest spec changes for main compliance but not yet including bias support. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ceaa1714dd041b00b5b29cd937c8f05e701bc4c
Diffstat (limited to 'reference_model/test')
-rw-r--r--reference_model/test/verify_tests.cpp122
1 files changed, 92 insertions, 30 deletions
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
index f43804f..81f3e8d 100644
--- a/reference_model/test/verify_tests.cpp
+++ b/reference_model/test/verify_tests.cpp
@@ -16,41 +16,103 @@
#include <doctest.h>
#include <array>
-#include <numeric>
+#include <string>
#include <vector>
-TEST_SUITE("verify")
+namespace
{
- TEST_CASE("check_element_accfp32")
+
+class TosaTensor
+{
+public:
+ TosaTensor(std::string name, tosa_datatype_t dataType, std::vector<int32_t> shape)
+ : _name(std::move(name))
+ , _shape(std::move(shape))
+ {
+ _tensor.name = _name.c_str();
+ _tensor.data_type = dataType;
+ _tensor.num_dims = _shape.size();
+ _tensor.shape = _shape.data();
+ };
+
+ const tosa_tensor_t* cTensor() const
{
- const size_t KS = 27;
-
- // Negative (bnd == 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(0.0, 0.0, 1.0, KS).is_valid);
- REQUIRE_FALSE(tosa_validate_element_accfp32(1.0, 0.0, 0.0, KS).is_valid);
- // Negative (bnd > 0.0)
- REQUIRE_FALSE(tosa_validate_element_accfp32(5.0, 5.0, 5.1, KS).is_valid);
-
- // Positive (bnd == 0.0 && ref == 0.0 && imp == 0.0)
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).is_valid);
- REQUIRE(tosa_validate_element_accfp32(0.0, 0.0, 0.0, KS).error == 0.0);
-
- // Positive (bnd > 0.0)
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
- REQUIRE(tosa_validate_element_accfp32(4.0, 4.0, 4.0, KS).error == 0.0);
+ return &_tensor;
}
- TEST_CASE("check_output_error")
+
+private:
+ std::string _name;
+ std::vector<int32_t> _shape;
+ tosa_tensor_t _tensor;
+};
+
+} // namespace
+
+TEST_SUITE_BEGIN("verify");
+
+TEST_CASE("negative - api")
+{
+ std::string json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": "DOT_PRODUCT",
+ "dot_product_info" : {
+ "data_type": "FP32",
+ "s": 2,
+ "ks": 9
+ }
+ }
+ }
+ })";
+
+ SUBCASE("invalid json")
{
- const size_t KS = 27;
- const size_t T = 1024;
-
- // Negative (S!=1 && S!=2 && (abs(err_sum) > 2*sqrt(KS*T)))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 0));
- // Negative (err_sum_sq > 0.4*KS*T))
- REQUIRE_FALSE(tosa_validate_output_error(1560, 112000, KS, T, 1));
- // Positive
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 0));
- REQUIRE(tosa_validate_output_error(10, 254, KS, T, 1));
+ std::string invalid_json_cfg = R"({
+ "tensors" : {
+ "out1" : {
+ "mode": DOT_PRODUCT,
+ },
+ }
+ })";
+
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), invalid_json_cfg.c_str()));
+ }
+ SUBCASE("mismatching dimensions")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 4, 4 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching shapes")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 4, 4, 4 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("mismatching data types")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp16_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
+ }
+ SUBCASE("missing tensor data")
+ {
+ const TosaTensor ref("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor refAbs("out1", tosa_datatype_fp64_t, { 8, 8, 8 });
+ const TosaTensor imp("out1", tosa_datatype_fp32_t, { 8, 8, 8 });
+
+ REQUIRE_FALSE(tvf_verify_data(ref.cTensor(), refAbs.cTensor(), imp.cTensor(), json_cfg.c_str()));
}
}
+
+TEST_SUITE_END(); // verify \ No newline at end of file