diff options
-rw-r--r-- | reference_model/CMakeLists.txt | 21 | ||||
-rw-r--r-- | reference_model/include/generate.h | 45 | ||||
-rw-r--r-- | reference_model/include/verify.h | 6 | ||||
-rw-r--r-- | reference_model/src/generate/generate_dot_product.cc | 100 | ||||
-rw-r--r-- | reference_model/src/generate/generate_dot_product.h | 46 | ||||
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 119 | ||||
-rw-r--r-- | reference_model/src/generate/generate_entry.cc | 75 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 137 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.h | 80 | ||||
-rw-r--r-- | reference_model/test/generate_tests.cpp | 115 |
10 files changed, 743 insertions, 1 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index cc2a5e3..94e612d 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -71,6 +71,10 @@ set(CXX_SOURCE src/operators.cc src/subgraph_traverser.cc src/tensor.cc + src/generate/generate_dot_product_states.cc + src/generate/generate_dot_product.cc + src/generate/generate_entry.cc + src/generate/generate_utils.cc src/verify/verify_dot_product.cc src/verify/verify_entry.cc src/verify/verify_exact.cc @@ -130,6 +134,7 @@ list(APPEND PUBLIC_HEADERS include/dtype.h include/func_config.h include/func_debug.h + include/generate.h include/graph_status.h include/model_common.h include/model_runner.h @@ -158,6 +163,21 @@ target_include_directories(tosa_reference_verify_lib ${PRIVATE_INCLUDE_DIRS} ) +# Build TOSA generator library +add_library(tosa_reference_generate_lib SHARED + src/generate/generate_dot_product_states.cc + src/generate/generate_dot_product.cc + src/generate/generate_entry.cc + src/generate/generate_utils.cc + src/func_debug.cc +) +target_include_directories(tosa_reference_generate_lib + PUBLIC + ${PUBLIC_INCLUDE_DIRS} + PRIVATE + ${PRIVATE_INCLUDE_DIRS} +) + # Build TOSA Refererence Model executable if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) set(CXX_SOURCE_EX src/main.cpp) @@ -193,6 +213,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS) # Sources only required for unit tests. set(CXX_SOURCE_TESTS + test/generate_tests.cpp test/model_runner_tests.cpp test/verify_tests.cpp ${DOCTEST_DIR}/doctest.h diff --git a/reference_model/include/generate.h b/reference_model/include/generate.h new file mode 100644 index 0000000..32562a0 --- /dev/null +++ b/reference_model/include/generate.h @@ -0,0 +1,45 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// +// +// Data generation functionality as per TOSA Specification (5.2) +// +//===----------------------------------------------------------------------===// +#ifndef GENERATE_H +#define GENERATE_H + +#include <stddef.h> + +#ifdef __cplusplus +extern "C" +{ +#endif /* __cplusplus */ + + /// \brief Perform input data generation for a given tensor + /// + /// A configuration provides context about the type of generator to be used e.g. Pseudo-random + /// alongside with information on the operator and the slot that the tensor is consumed by. + /// + /// \param config_json JSON configuration of the tensor that we need to generate data for + /// \param tensor_name Name of the tensor to extract generator information + /// \param data User-provided buffer to store the data to + /// \param size Size of the provided buffer in bytes + /// \return + bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif // GENERATE_H
\ No newline at end of file diff --git a/reference_model/include/verify.h b/reference_model/include/verify.h index e449ff7..36e1d7b 100644 --- a/reference_model/include/verify.h +++ b/reference_model/include/verify.h @@ -17,6 +17,8 @@ // Output Verification : Section 1.8.2 // //===----------------------------------------------------------------------===// +#ifndef VERIFY_H +#define VERIFY_H #include "types.h" @@ -44,4 +46,6 @@ extern "C" #ifdef __cplusplus } -#endif /* __cplusplus */
\ No newline at end of file +#endif /* __cplusplus */ + +#endif // VERIFY_H
\ No newline at end of file diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc new file mode 100644 index 0000000..90710ba --- /dev/null +++ b/reference_model/src/generate/generate_dot_product.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_dot_product.h" + +namespace +{ +//---------------------------------------------------------------------------// +// MatMul // +//---------------------------------------------------------------------------// + +void generateMatMulA(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + float* a = reinterpret_cast<float*>(data); + const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2]; + const uint32_t C = cfg.shape[2]; + + for (uint32_t t = 0; t < T; ++t) + { + a[t] = generator(t % C); // k = c + } +} + +void generateMatMulB(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + float* b = reinterpret_cast<float*>(data); + const uint32_t T = cfg.shape[0] * cfg.shape[1] * cfg.shape[2]; + const uint32_t C = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + for (uint32_t t = 0; t < T; ++t) + { + b[t] = generator((t / W) % C); // k = c + } +} + +bool generateMatMul(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + if (cfg.shape.size() != 3) + { + WARNING("[Generator][DP][MatMul] Tensor shape expected 3 dimensions."); + return false; + } + if (cfg.inputPos > 1 || cfg.inputPos < 0) + { + WARNING("[Generator][DP][MatMul] Invalid input tensor slot position to operator."); + return false; + } + + (cfg.inputPos == 0) ? generateMatMulA(cfg, generator, data, size) : generateMatMulB(cfg, generator, data, size); + + return true; +} +} // namespace + +namespace TosaReference +{ + +bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size) +{ + auto generator = pickDotProductGenerator(cfg); + if (!generator) + { + WARNING("[Generator][DP] Requested generator could not be created!"); + return 0; + } + + // Select which generator to use + switch (cfg.opType) + { + case tosa::Op_MATMUL: + return generateMatMul(cfg, *generator, data, size); + default: + WARNING("[Generator][DP] Unsupported operator"); + return false; + } + + return false; +} +} // namespace TosaReference
\ No newline at end of file diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h new file mode 100644 index 0000000..3d4ecc6 --- /dev/null +++ b/reference_model/src/generate/generate_dot_product.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERATE_DOT_PRODUCT_H_ +#define GENERATE_DOT_PRODUCT_H_ + +#include "generate_utils.h" + +#include <memory> + +namespace TosaReference +{ + +/// \brief Generic dot-product generator interface +class IDotProductGenerator +{ +public: + virtual float operator()(uint32_t k) = 0; +}; + +/// \brief Dot-product stage generator selector +std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg); + +/// \brief Perform dot-product based generation +/// +/// \param cfg Generator related meta-data +/// \param data Buffer to generate the data to +/// \param size Size of the buffet +/// +/// \return True on successful generation +bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size); + +}; // namespace TosaReference + +#endif // GENERATE_DOT_PRODUCT_H_ diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc new file mode 100644 index 0000000..cd9ffba --- /dev/null +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_dot_product.h" +#include "generate_utils.h" + +#include <cstdint> + +namespace +{ + +// Input index global variables +inline constexpr uint32_t P0 = 0; +inline constexpr uint32_t P1 = 1; + +// Unused helper function +template <typename... Args> +inline void unused(Args&&...) +{} + +// Primitive generator class +// +// Yields a new value on function operator access and increases the +// index by one +class PrimitiveGenerator +{ +public: + PrimitiveGenerator(uint32_t S) + : _S(S) + , _m(0) + , _r(0) + , _index(0) + { + _m = (8 * _S + 1) * 0x705A5E75; + _r = _m + 1; + } + + [[nodiscard]] float operator()() + { + _r = _r * _m + 1; + float sign = (_r >> 31) == 0 ? +1 : -1; + float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF); + ++_index; + + return pseudo; + } + + uint32_t index() + { + return _index; + } + +private: + uint32_t _S; + uint32_t _m; + uint32_t _r; + uint32_t _index; +}; + +//----------------------------------------------------------------------------// +// State generators +//----------------------------------------------------------------------------// + +// S0 generator +class GeneratorS0 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS0(uint32_t p) + : _p(p) + , _s0(0) // set_data(2*S) + , _s1(1) // set_data(2*S+1) + {} + float operator()(uint32_t k) override + { + unused(k); + const float s0 = _s0(); + const float s1 = _s1(); + if (_p == P0) + return s0 < 0.f ? 0.f : s1; + else + return s0 < 0.f ? s1 : 0.f; + } + +private: + uint32_t _p; + PrimitiveGenerator _s0; + PrimitiveGenerator _s1; +}; + +} // namespace + +namespace TosaReference +{ + +std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg) +{ + const DotProductInfo& dpinfo = cfg.dotProductInfo; + switch (dpinfo.s) + { + case 0: + return std::make_unique<GeneratorS0>(cfg.inputPos); + default: + return nullptr; + } + return nullptr; +} + +} // namespace TosaReference
\ No newline at end of file diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc new file mode 100644 index 0000000..95dbe8f --- /dev/null +++ b/reference_model/src/generate/generate_entry.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate.h" + +#include "generate_dot_product.h" +#include "generate_utils.h" + +#include "func_debug.h" +#include "model_common.h" + +namespace TosaReference +{ + +bool generate(const GenerateConfig& cfg, void* data, size_t size) +{ + switch (cfg.generatorType) + { + case GeneratorType::DotProduct: { + return generateDotProduct(cfg, data, size); + break; + } + default: { + WARNING("[Generator] Unsupported generation mode."); + break; + } + } + return false; +} + +} // namespace TosaReference + +extern "C" +{ + bool tgd_generate_data(const char* config_json, const char* tensor_name, void* data, size_t size) + { + // Check inputs for nullptr + if (!config_json || !tensor_name || !data) + { + WARNING("[Generator] One of the inputs is missing."); + return false; + } + + // Check JSON config validity + auto cfg = TosaReference::parseGenerateConfig(config_json, tensor_name); + if (!cfg) + { + WARNING("[Generator] Invalid json config."); + return false; + } + + // Check size + const size_t totalBytesNeeded = + TosaReference::numElementsFromShape(cfg->shape) * TosaReference::elementSizeFromType(cfg->dataType); + if (totalBytesNeeded > size) + { + WARNING("[Generator] Not enough space in provided buffer."); + return false; + } + + // Run generator + return generate(cfg.value(), data, size); + } +} // extern "C" diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc new file mode 100644 index 0000000..c52f051 --- /dev/null +++ b/reference_model/src/generate/generate_utils.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_utils.h" + +#include <nlohmann/json.hpp> + +#include <algorithm> + +namespace tosa +{ + +NLOHMANN_JSON_SERIALIZE_ENUM(DType, + { + { DType::DType_BOOL, "BOOL" }, + { DType::DType_INT4, "INT4" }, + { DType::DType_INT8, "INT8" }, + { DType::DType_INT16, "INT16" }, + { DType::DType_INT32, "INT32" }, + { DType::DType_INT48, "INT48" }, + { DType::DType_FP16, "FP16" }, + { DType::DType_BF16, "BF16" }, + { DType::DType_FP32, "FP32" }, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(Op, + { + { Op::Op_MATMUL, "MATMUL" }, + }) + +} // namespace tosa + +namespace TosaReference +{ + +NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, + { + { GeneratorType::PseudoRandom, "PSEUDO_RANDOM" }, + { GeneratorType::DotProduct, "DOT_PRODUCT" }, + { GeneratorType::OpFullRange, "OP_FULL_RANGE" }, + { GeneratorType::OpBoundary, "OP_BOUNDARY" }, + { GeneratorType::OpSpecial, "OP_SPECIAL" }, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(InputType, + { + { InputType::Variable, "VARIABLE" }, + { InputType::Constant, "CONSTANT" }, + }) + +void from_json(const nlohmann::json& j, DotProductInfo& dotProductInfo) +{ + j.at("s").get_to(dotProductInfo.s); + j.at("ks").get_to(dotProductInfo.ks); + j.at("acc_type").get_to(dotProductInfo.accType); + if (j.contains("kernel")) + { + j.at("kernel").get_to(dotProductInfo.kernel); + } + if (j.contains("axis")) + { + j.at("axis").get_to(dotProductInfo.axis); + } +} + +void from_json(const nlohmann::json& j, GenerateConfig& cfg) +{ + j.at("data_type").get_to(cfg.dataType); + j.at("input_type").get_to(cfg.inputType); + j.at("shape").get_to(cfg.shape); + j.at("input_pos").get_to(cfg.inputPos); + j.at("op").get_to(cfg.opType); + j.at("generator").get_to(cfg.generatorType); + if (j.contains("dot_product_info")) + { + j.at("dot_product_info").get_to(cfg.dotProductInfo); + } +} + +std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName) +{ + if (!tensorName) + return std::nullopt; + + auto jsonCfg = nlohmann::json::parse(json, nullptr, /* allow exceptions */ false); + + if (jsonCfg.is_discarded()) + return std::nullopt; + if (!jsonCfg.contains("tensors")) + return std::nullopt; + + const auto& tensors = jsonCfg["tensors"]; + if (!tensors.contains(tensorName)) + return std::nullopt; + + const auto& namedTensor = tensors[tensorName]; + return namedTensor.get<GenerateConfig>(); +} + +int64_t numElementsFromShape(const std::vector<int32_t>& shape) +{ + return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>()); +} + +size_t elementSizeFromType(DType type) +{ + switch (type) + { + case DType::DType_BOOL: + case DType::DType_UINT8: + case DType::DType_INT8: + return 1; + case DType::DType_UINT16: + case DType::DType_INT16: + case DType::DType_FP16: + case DType::DType_BF16: + return 2; + case DType::DType_INT32: + case DType::DType_FP32: + return 4; + default: + return 0; + } + return 0; +} +} // namespace TosaReference diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h new file mode 100644 index 0000000..2d5b7f8 --- /dev/null +++ b/reference_model/src/generate/generate_utils.h @@ -0,0 +1,80 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERATE_UTILS_H_ +#define GENERATE_UTILS_H_ + +#include "dtype.h" + +#include <array> +#include <cstdint> +#include <optional> +#include <vector> + +namespace TosaReference +{ + +/// \brief Supported generator types +enum class GeneratorType +{ + PseudoRandom, + DotProduct, + OpFullRange, + OpBoundary, + OpSpecial, +}; + +/// \brief Supported input types +enum class InputType +{ + Variable, + Constant, +}; + +/// \brief Dot-product generator meta-data +struct DotProductInfo +{ + DotProductInfo() = default; + + int32_t s; + int32_t ks; + DType accType; + int32_t axis; + std::array<int32_t, 2> kernel; +}; + +/// \brief Generator configuration +struct GenerateConfig +{ + GeneratorType generatorType; + DType dataType; + InputType inputType; + std::vector<int32_t> shape; + int32_t inputPos; + tosa::Op opType; + DotProductInfo dotProductInfo; +}; + +/// \brief Parse the generator config when given in JSON form +std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName); + +/// \brief Extract number of total elements +int64_t numElementsFromShape(const std::vector<int32_t>& shape); + +/// \brief Size in bytes of a given type +size_t elementSizeFromType(DType type); + +}; // namespace TosaReference + +#endif // GENERATE_UTILS_H_ diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp new file mode 100644 index 0000000..88dc979 --- /dev/null +++ b/reference_model/test/generate_tests.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "generate.h" + +#include <doctest.h> + +#include <array> +#include <string> +#include <vector> + +TEST_SUITE_BEGIN("generate"); + +TEST_CASE("negative - api") +{ + std::string json_cfg = R"({ + "tensors" : { + "in1" : { + "generator": "DOT_PRODUCT", + "data_type": "FP32", + "input_type": "VARIABLE", + "shape" : [ 4, 8, 8 ], + "input_pos": 0, + "op" : "MATMUL", + "dot_product_info": { + "s": 0, + "ks": 10, + "acc_type": "FP32" + } + } + } + })"; + + const std::string tosaName = "in1"; + const size_t tosaElements = 4 * 8 * 8; + const size_t tosaSize = tosaElements * 4; + + SUBCASE("missing input") + { + REQUIRE_FALSE(tgd_generate_data(NULL, NULL, NULL, 0)); + } + SUBCASE("invalid json") + { + std::string invalid_json_cfg = R"({ + "tensors" : { + "in1" : { + "generator": DOT_PRODUCT, + }, + } + })"; + + std::vector<float> buffer(tosaElements); + REQUIRE_FALSE(tgd_generate_data(invalid_json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize)); + } + SUBCASE("invalid json - mismatching name") + { + std::string invalidName = "notFound1"; + + std::vector<float> buffer(tosaElements); + REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), invalidName.c_str(), (void*)buffer.data(), tosaSize)); + } + SUBCASE("mismatching size") + { + size_t smallElements = 4 * 8 * 7; + size_t smallSize = smallElements * 4; + + std::vector<float> buffer(smallElements); + REQUIRE_FALSE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), smallSize)); + } +} + +TEST_CASE("positive - dot product") +{ + std::string json_cfg = R"({ + "tensors" : { + "in1" : { + "generator": "DOT_PRODUCT", + "data_type": "FP32", + "input_type": "VARIABLE", + "shape" : [ 4, 8, 8 ], + "input_pos": 0, + "op" : "MATMUL", + "dot_product_info": { + "s": 0, + "ks": 10, + "acc_type": "FP32" + } + } + } + })"; + + const std::string tosaName = "in1"; + const size_t tosaElements = 4 * 8 * 8; + const size_t tosaSize = tosaElements * 4; + + SUBCASE("matmul") + { + std::vector<float> buffer(tosaElements); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize)); + REQUIRE(buffer[0] == (float)-0.950864); + REQUIRE(buffer[1] == 0.f); + } +} + +TEST_SUITE_END(); // generate
\ No newline at end of file |