aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-04 14:17:55 +0100
committerEric Kunze <eric.kunze@arm.com>2023-10-04 18:45:40 +0000
commitb20b0c9cb4c85bb9a3c901d5acaf421d84656850 (patch)
tree8af9d6338b62bc65e7e4292427f06a4ef0346312
parent12ee1a79374b451602784fd6dc8f63886bf2a997 (diff)
downloadreference_model-b20b0c9cb4c85bb9a3c901d5acaf421d84656850.tar.gz
Add initial TOSA MI generator support
Add support for dot-product MatMul - test set 0 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: Ifd15b42570014b634f59c94a1fd1cd56bac79ea4 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--reference_model/CMakeLists.txt21
-rw-r--r--reference_model/include/generate.h45
-rw-r--r--reference_model/include/verify.h6
-rw-r--r--reference_model/src/generate/generate_dot_product.cc100
-rw-r--r--reference_model/src/generate/generate_dot_product.h46
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc119
-rw-r--r--reference_model/src/generate/generate_entry.cc75
-rw-r--r--reference_model/src/generate/generate_utils.cc137
-rw-r--r--reference_model/src/generate/generate_utils.h80
-rw-r--r--reference_model/test/generate_tests.cpp115
10 files changed, 743 insertions, 1 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index cc2a5e3..94e612d 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -71,6 +71,10 @@ set(CXX_SOURCE
src/operators.cc
src/subgraph_traverser.cc
src/tensor.cc
+ src/generate/generate_dot_product_states.cc
+ src/generate/generate_dot_product.cc
+ src/generate/generate_entry.cc
+ src/generate/generate_utils.cc
src/verify/verify_dot_product.cc
src/verify/verify_entry.cc
src/verify/verify_exact.cc
@@ -130,6 +134,7 @@ list(APPEND PUBLIC_HEADERS
include/dtype.h
include/func_config.h
include/func_debug.h
+ include/generate.h
include/graph_status.h
include/model_common.h
include/model_runner.h
@@ -158,6 +163,21 @@ target_include_directories(tosa_reference_verify_lib
${PRIVATE_INCLUDE_DIRS}
)
+# Build TOSA generator library
+add_library(tosa_reference_generate_lib SHARED
+ src/generate/generate_dot_product_states.cc
+ src/generate/generate_dot_product.cc
+ src/generate/generate_entry.cc
+ src/generate/generate_utils.cc
+ src/func_debug.cc
+)
+target_include_directories(tosa_reference_generate_lib
+ PUBLIC
+ ${PUBLIC_INCLUDE_DIRS}
+ PRIVATE
+ ${PRIVATE_INCLUDE_DIRS}
+)
+
# Build TOSA Refererence Model executable
if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
set(CXX_SOURCE_EX src/main.cpp)
@@ -193,6 +213,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
# Sources only required for unit tests.
set(CXX_SOURCE_TESTS
+ test/generate_tests.cpp
test/model_runner_tests.cpp
test/verify_tests.cpp
${DOCTEST_DIR}/doctest.h
diff --git a/reference_model/include/generate.h b/reference_model/include/generate.h
new file mode 100644
index 0000000..32562a0
--- /dev/null
+++ b/reference_model/include/generate.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//===----------------------------------------------------------------------===//
+//
+// Data generation functionality as per TOSA Specification (5.2)
+//
+//===----------------------------------------------------------------------===//
+#ifndef GENERATE_H
+#define GENERATE_H
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif /* __cplusplus */
+
+ /// \brief Perform input data generation for a given tensor
+ ///
+ /// A configuration provides context about the type of generator to be used e.g. Pseudo-random
+ /// alongside with information on the operator and the slot that the tensor is consumed by.
+ ///
+ /// \param config_json JSON configuration of the tensor that we need to generate data for
+ /// \param tensor_name Name of the tensor to extract generator information
+ /// \param data User-provided buffer to store the data to
+ /// \param size Size of the provided buffer in bytes
+ /// \return
+ bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size);
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // GENERATE_H \ No newline at end of file
diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h
index e449ff7..36e1d7b 100644
--- a/reference_model/include/verify.h
+++ b/reference_model/include/verify.h
@@ -17,6 +17,8 @@
// Output Verification : Section 1.8.2
//
//===----------------------------------------------------------------------===//
+#ifndef VERIFY_H
+#define VERIFY_H
#include "types.h"
@@ -44,4 +46,6 @@ extern "C"
#ifdef __cplusplus
}
-#endif /* __cplusplus */ \ No newline at end of file
+#endif /* __cplusplus */
+
+#endif // VERIFY_H \ No newline at end of file
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
new file mode 100644
index 0000000..90710ba
--- /dev/null
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate_dot_product.h"
+
+namespace
+{
+//---------------------------------------------------------------------------//
+// MatMul //
+//---------------------------------------------------------------------------//
+
+void generateMatMulA(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ float* a = reinterpret_cast<float*>(data);
+ const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
+ const uint32_t C = cfg.shape[2];
+
+ for (uint32_t t = 0; t < T; ++t)
+ {
+ a[t] = generator(t % C); // k = c
+ }
+}
+
+void generateMatMulB(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ float* b = reinterpret_cast<float*>(data);
+ const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2];
+ const uint32_t C = cfg.shape[1];
+ const uint32_t W = cfg.shape[2];
+
+ for (uint32_t t = 0; t < T; ++t)
+ {
+ b[t] = generator((t / W) % C); // k = c
+ }
+}
+
+bool generateMatMul(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 3)
+ {
+ WARNING("[Generator][DP][MatMul] Tensor shape expected 3 dimensions.");
+ return false;
+ }
+ if (cfg.inputPos > 1 || cfg.inputPos < 0)
+ {
+ WARNING("[Generator][DP][MatMul] Invalid input tensor slot position to operator.");
+ return false;
+ }
+
+ (cfg.inputPos == 0) ? generateMatMulA(cfg, generator, data, size) : generateMatMulB(cfg, generator, data, size);
+
+ return true;
+}
+} // namespace
+
+namespace TosaReference
+{
+
+bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
+{
+ auto generator = pickDotProductGenerator(cfg);
+ if (!generator)
+ {
+ WARNING("[Generator][DP] Requested generator could not be created!");
+ return 0;
+ }
+
+ // Select which generator to use
+ switch (cfg.opType)
+ {
+ case tosa::Op_MATMUL:
+ return generateMatMul(cfg, *generator, data, size);
+ default:
+ WARNING("[Generator][DP] Unsupported operator");
+ return false;
+ }
+
+ return false;
+}
+} // namespace TosaReference \ No newline at end of file
diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h
new file mode 100644
index 0000000..3d4ecc6
--- /dev/null
+++ b/reference_model/src/generate/generate_dot_product.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_DOT_PRODUCT_H_
+#define GENERATE_DOT_PRODUCT_H_
+
+#include "generate_utils.h"
+
+#include <memory>
+
+namespace TosaReference
+{
+
+/// \brief Generic dot-product generator interface
+class IDotProductGenerator
+{
+public:
+ virtual float operator()(uint32_t k) = 0;
+};
+
+/// \brief Dot-product stage generator selector
+std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg);
+
+/// \brief Perform dot-product based generation
+///
+/// \param cfg Generator related meta-data
+/// \param data Buffer to generate the data to
+/// \param size Size of the buffet
+///
+/// \return True on successful generation
+bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_DOT_PRODUCT_H_
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
new file mode 100644
index 0000000..cd9ffba
--- /dev/null
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate_dot_product.h"
+#include "generate_utils.h"
+
+#include <cstdint>
+
+namespace
+{
+
+// Input index global variables
+inline constexpr uint32_t P0 = 0;
+inline constexpr uint32_t P1 = 1;
+
+// Unused helper function
+template <typename... Args>
+inline void unused(Args&&...)
+{}
+
+// Primitive generator class
+//
+// Yields a new value on function operator access and increases the
+// index by one
+class PrimitiveGenerator
+{
+public:
+ PrimitiveGenerator(uint32_t S)
+ : _S(S)
+ , _m(0)
+ , _r(0)
+ , _index(0)
+ {
+ _m = (8 * _S + 1) * 0x705A5E75;
+ _r = _m + 1;
+ }
+
+ [[nodiscard]] float operator()()
+ {
+ _r = _r * _m + 1;
+ float sign = (_r >> 31) == 0 ? +1 : -1;
+ float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF);
+ ++_index;
+
+ return pseudo;
+ }
+
+ uint32_t index()
+ {
+ return _index;
+ }
+
+private:
+ uint32_t _S;
+ uint32_t _m;
+ uint32_t _r;
+ uint32_t _index;
+};
+
+//----------------------------------------------------------------------------//
+// State generators
+//----------------------------------------------------------------------------//
+
+// S0 generator
+class GeneratorS0 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS0(uint32_t p)
+ : _p(p)
+ , _s0(0) // set_data(2*S)
+ , _s1(1) // set_data(2*S+1)
+ {}
+ float operator()(uint32_t k) override
+ {
+ unused(k);
+ const float s0 = _s0();
+ const float s1 = _s1();
+ if (_p == P0)
+ return s0 < 0.f ? 0.f : s1;
+ else
+ return s0 < 0.f ? s1 : 0.f;
+ }
+
+private:
+ uint32_t _p;
+ PrimitiveGenerator _s0;
+ PrimitiveGenerator _s1;
+};
+
+} // namespace
+
+namespace TosaReference
+{
+
+std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
+{
+ const DotProductInfo& dpinfo = cfg.dotProductInfo;
+ switch (dpinfo.s)
+ {
+ case 0:
+ return std::make_unique<GeneratorS0>(cfg.inputPos);
+ default:
+ return nullptr;
+ }
+ return nullptr;
+}
+
+} // namespace TosaReference \ No newline at end of file
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
new file mode 100644
index 0000000..95dbe8f
--- /dev/null
+++ b/reference_model/src/generate/generate_entry.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate.h"
+
+#include "generate_dot_product.h"
+#include "generate_utils.h"
+
+#include "func_debug.h"
+#include "model_common.h"
+
+namespace TosaReference
+{
+
+bool generate(const GenerateConfig& cfg, void* data, size_t size)
+{
+ switch (cfg.generatorType)
+ {
+ case GeneratorType::DotProduct: {
+ return generateDotProduct(cfg, data, size);
+ break;
+ }
+ default: {
+ WARNING("[Generator] Unsupported generation mode.");
+ break;
+ }
+ }
+ return false;
+}
+
+} // namespace TosaReference
+
+extern "C"
+{
+ bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size)
+ {
+ // Check inputs for nullptr
+ if (!config_json || !tensor_name || !data)
+ {
+ WARNING("[Generator] One of the inputs is missing.");
+ return false;
+ }
+
+ // Check JSON config validity
+ auto cfg = TosaReference::parseGenerateConfig(config_json, tensor_name);
+ if (!cfg)
+ {
+ WARNING("[Generator] Invalid json config.");
+ return false;
+ }
+
+ // Check size
+ const size_t totalBytesNeeded =
+ TosaReference::numElementsFromShape(cfg->shape) * TosaReference::elementSizeFromType(cfg->dataType);
+ if (totalBytesNeeded > size)
+ {
+ WARNING("[Generator] Not enough space in provided buffer.");
+ return false;
+ }
+
+ // Run generator
+ return generate(cfg.value(), data, size);
+ }
+} // extern "C"
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
new file mode 100644
index 0000000..c52f051
--- /dev/null
+++ b/reference_model/src/generate/generate_utils.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate_utils.h"
+
+#include <nlohmann/json.hpp>
+
+#include <algorithm>
+
+namespace tosa
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(DType,
+ {
+ { DType::DType_BOOL, "BOOL" },
+ { DType::DType_INT4, "INT4" },
+ { DType::DType_INT8, "INT8" },
+ { DType::DType_INT16, "INT16" },
+ { DType::DType_INT32, "INT32" },
+ { DType::DType_INT48, "INT48" },
+ { DType::DType_FP16, "FP16" },
+ { DType::DType_BF16, "BF16" },
+ { DType::DType_FP32, "FP32" },
+ })
+
+NLOHMANN_JSON_SERIALIZE_ENUM(Op,
+ {
+ { Op::Op_MATMUL, "MATMUL" },
+ })
+
+} // namespace tosa
+
+namespace TosaReference
+{
+
+NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
+ {
+ { GeneratorType::PseudoRandom, "PSEUDO_RANDOM" },
+ { GeneratorType::DotProduct, "DOT_PRODUCT" },
+ { GeneratorType::OpFullRange, "OP_FULL_RANGE" },
+ { GeneratorType::OpBoundary, "OP_BOUNDARY" },
+ { GeneratorType::OpSpecial, "OP_SPECIAL" },
+ })
+
+NLOHMANN_JSON_SERIALIZE_ENUM(InputType,
+ {
+ { InputType::Variable, "VARIABLE" },
+ { InputType::Constant, "CONSTANT" },
+ })
+
+void from_json(const nlohmann::json& j, DotProductInfo& dotProductInfo)
+{
+ j.at("s").get_to(dotProductInfo.s);
+ j.at("ks").get_to(dotProductInfo.ks);
+ j.at("acc_type").get_to(dotProductInfo.accType);
+ if (j.contains("kernel"))
+ {
+ j.at("kernel").get_to(dotProductInfo.kernel);
+ }
+ if (j.contains("axis"))
+ {
+ j.at("axis").get_to(dotProductInfo.axis);
+ }
+}
+
+void from_json(const nlohmann::json& j, GenerateConfig& cfg)
+{
+ j.at("data_type").get_to(cfg.dataType);
+ j.at("input_type").get_to(cfg.inputType);
+ j.at("shape").get_to(cfg.shape);
+ j.at("input_pos").get_to(cfg.inputPos);
+ j.at("op").get_to(cfg.opType);
+ j.at("generator").get_to(cfg.generatorType);
+ if (j.contains("dot_product_info"))
+ {
+ j.at("dot_product_info").get_to(cfg.dotProductInfo);
+ }
+}
+
+std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
+{
+ if (!tensorName)
+ return std::nullopt;
+
+ auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false);
+
+ if (jsonCfg.is_discarded())
+ return std::nullopt;
+ if (!jsonCfg.contains("tensors"))
+ return std::nullopt;
+
+ const auto& tensors = jsonCfg["tensors"];
+ if (!tensors.contains(tensorName))
+ return std::nullopt;
+
+ const auto& namedTensor = tensors[tensorName];
+ return namedTensor.get<GenerateConfig>();
+}
+
+int64_t numElementsFromShape(const std::vector<int32_t>& shape)
+{
+ return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
+}
+
+size_t elementSizeFromType(DType type)
+{
+ switch (type)
+ {
+ case DType::DType_BOOL:
+ case DType::DType_UINT8:
+ case DType::DType_INT8:
+ return 1;
+ case DType::DType_UINT16:
+ case DType::DType_INT16:
+ case DType::DType_FP16:
+ case DType::DType_BF16:
+ return 2;
+ case DType::DType_INT32:
+ case DType::DType_FP32:
+ return 4;
+ default:
+ return 0;
+ }
+ return 0;
+}
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
new file mode 100644
index 0000000..2d5b7f8
--- /dev/null
+++ b/reference_model/src/generate/generate_utils.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_UTILS_H_
+#define GENERATE_UTILS_H_
+
+#include "dtype.h"
+
+#include <array>
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+namespace TosaReference
+{
+
+/// \brief Supported generator types
+enum class GeneratorType
+{
+ PseudoRandom,
+ DotProduct,
+ OpFullRange,
+ OpBoundary,
+ OpSpecial,
+};
+
+/// \brief Supported input types
+enum class InputType
+{
+ Variable,
+ Constant,
+};
+
+/// \brief Dot-product generator meta-data
+struct DotProductInfo
+{
+ DotProductInfo() = default;
+
+ int32_t s;
+ int32_t ks;
+ DType accType;
+ int32_t axis;
+ std::array<int32_t, 2> kernel;
+};
+
+/// \brief Generator configuration
+struct GenerateConfig
+{
+ GeneratorType generatorType;
+ DType dataType;
+ InputType inputType;
+ std::vector<int32_t> shape;
+ int32_t inputPos;
+ tosa::Op opType;
+ DotProductInfo dotProductInfo;
+};
+
+/// \brief Parse the generator config when given in JSON form
+std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName);
+
+/// \brief Extract number of total elements
+int64_t numElementsFromShape(const std::vector<int32_t>& shape);
+
+/// \brief Size in bytes of a given type
+size_t elementSizeFromType(DType type);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_UTILS_H_
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
new file mode 100644
index 0000000..88dc979
--- /dev/null
+++ b/reference_model/test/generate_tests.cpp
@@ -0,0 +1,115 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "generate.h"
+
+#include <doctest.h>
+
+#include <array>
+#include <string>
+#include <vector>
+
+TEST_SUITE_BEGIN("generate");
+
+TEST_CASE("negative - api")
+{
+ std::string json_cfg = R"({
+ "tensors" : {
+ "in1" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 4, 8, 8 ],
+ "input_pos": 0,
+ "op" : "MATMUL",
+ "dot_product_info": {
+ "s": 0,
+ "ks": 10,
+ "acc_type": "FP32"
+ }
+ }
+ }
+ })";
+
+ const std::string tosaName = "in1";
+ const size_t tosaElements = 4 * 8 * 8;
+ const size_t tosaSize = tosaElements * 4;
+
+ SUBCASE("missing input")
+ {
+ REQUIRE_FALSE(tgd_generate_data(NULL, NULL, NULL, 0));
+ }
+ SUBCASE("invalid json")
+ {
+ std::string invalid_json_cfg = R"({
+ "tensors" : {
+ "in1" : {
+ "generator": DOT_PRODUCT,
+ },
+ }
+ })";
+
+ std::vector<float> buffer(tosaElements);
+ REQUIRE_FALSE(tgd_generate_data(invalid_json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
+ }
+ SUBCASE("invalid json - mismatching name")
+ {
+ std::string invalidName = "notFound1";
+
+ std::vector<float> buffer(tosaElements);
+ REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), invalidName.c_str(), (void*)buffer.data(), tosaSize));
+ }
+ SUBCASE("mismatching size")
+ {
+ size_t smallElements = 4 * 8 * 7;
+ size_t smallSize = smallElements * 4;
+
+ std::vector<float> buffer(smallElements);
+ REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), smallSize));
+ }
+}
+
+TEST_CASE("positive - dot product")
+{
+ std::string json_cfg = R"({
+ "tensors" : {
+ "in1" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 4, 8, 8 ],
+ "input_pos": 0,
+ "op" : "MATMUL",
+ "dot_product_info": {
+ "s": 0,
+ "ks": 10,
+ "acc_type": "FP32"
+ }
+ }
+ }
+ })";
+
+ const std::string tosaName = "in1";
+ const size_t tosaElements = 4 * 8 * 8;
+ const size_t tosaSize = tosaElements * 4;
+
+ SUBCASE("matmul")
+ {
+ std::vector<float> buffer(tosaElements);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
+ REQUIRE(buffer[0] == (float)-0.950864);
+ REQUIRE(buffer[1] == 0.f);
+ }
+}
+
+TEST_SUITE_END(); // generate \ No newline at end of file