aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_fixed_data.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_fixed_data.cc')
-rw-r--r--reference_model/src/generate/generate_fixed_data.cc41
1 files changed, 28 insertions, 13 deletions
diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc
index 3d4ee3e..b0b6c81 100644
--- a/reference_model/src/generate/generate_fixed_data.cc
+++ b/reference_model/src/generate/generate_fixed_data.cc
@@ -20,8 +20,22 @@
#include <type_traits>
#include <vector>
+namespace
+{
+template <typename OutType>
+bool copyFixedData(const int64_t elements, const std::vector<int32_t> inData, OutType* outData)
+{
+ for (auto t = 0; t < elements; t++)
+ {
+ outData[t] = inData[t];
+ }
+ return true;
+}
+} // namespace
+
namespace TosaReference
{
+
bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size)
{
// Check we support the operator
@@ -31,22 +45,23 @@ bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size)
return false;
}
+ std::vector<int32_t> inData = cfg.fixedDataInfo.data;
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ if (T != static_cast<int64_t>(inData.size()))
+ {
+ WARNING("[Generator][FD] Given data size %d does not match output size %d.", inData.size(), T);
+ return false;
+ }
+
switch (cfg.dataType)
{
case DType::DType_SHAPE: {
- int32_t* outData = reinterpret_cast<int32_t*>(data);
- std::vector<int32_t> inData = cfg.fixedDataInfo.data;
- const auto T = TosaReference::numElementsFromShape(cfg.shape);
- if (T != static_cast<int64_t>(inData.size()))
- {
- WARNING("[Generator][FD] Size does not match.");
- return false;
- }
- for (auto t = 0; t < T; t++)
- {
- outData[t] = inData[t];
- }
- return true;
+ int32_t* outData = reinterpret_cast<int32_t*>(data);
+ return copyFixedData(T, inData, outData);
+ }
+ case DType::DType_INT8: {
+ int8_t* outData = reinterpret_cast<int8_t*>(data);
+ return copyFixedData(T, inData, outData);
}
default:
WARNING("[Generator][FD] Unsupported type.");