aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_fixed_data.cc
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-28 13:20:05 +0000
committerTatWai Chong <tatwai.chong@arm.com>2024-03-01 13:16:56 -0800
commit0a042997ac24fee1a338e806caf18bd8dfba28f3 (patch)
tree1cfe325d7d775b778873a3940407e68d39c80a48 /reference_model/src/generate/generate_fixed_data.cc
parent3195a665e3f96809a67b4cb04a57330d2bfeb0de (diff)
downloadreference_model-0a042997ac24fee1a338e806caf18bd8dfba28f3.tar.gz
Testing support for MUL with shift as input
Always create the shift as a tensor for all types in testing. In the reference model, set the shift operand to be available for all types, but only read in the shift tensor for i32. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
Diffstat (limited to 'reference_model/src/generate/generate_fixed_data.cc')
-rw-r--r--reference_model/src/generate/generate_fixed_data.cc41
1 files changed, 28 insertions, 13 deletions
diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc
index 3d4ee3e..b0b6c81 100644
--- a/reference_model/src/generate/generate_fixed_data.cc
+++ b/reference_model/src/generate/generate_fixed_data.cc
@@ -20,8 +20,22 @@
#include <type_traits>
#include <vector>
+namespace
+{
+template <typename OutType>
+bool copyFixedData(const int64_t elements, const std::vector<int32_t> inData, OutType* outData)
+{
+ for (auto t = 0; t < elements; t++)
+ {
+ outData[t] = inData[t];
+ }
+ return true;
+}
+} // namespace
+
namespace TosaReference
{
+
bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size)
{
// Check we support the operator
@@ -31,22 +45,23 @@ bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size)
return false;
}
+ std::vector<int32_t> inData = cfg.fixedDataInfo.data;
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ if (T != static_cast<int64_t>(inData.size()))
+ {
+ WARNING("[Generator][FD] Given data size %d does not match output size %d.", inData.size(), T);
+ return false;
+ }
+
switch (cfg.dataType)
{
case DType::DType_SHAPE: {
- int32_t* outData = reinterpret_cast<int32_t*>(data);
- std::vector<int32_t> inData = cfg.fixedDataInfo.data;
- const auto T = TosaReference::numElementsFromShape(cfg.shape);
- if (T != static_cast<int64_t>(inData.size()))
- {
- WARNING("[Generator][FD] Size does not match.");
- return false;
- }
- for (auto t = 0; t < T; t++)
- {
- outData[t] = inData[t];
- }
- return true;
+ int32_t* outData = reinterpret_cast<int32_t*>(data);
+ return copyFixedData(T, inData, outData);
+ }
+ case DType::DType_INT8: {
+ int8_t* outData = reinterpret_cast<int8_t*>(data);
+ return copyFixedData(T, inData, outData);
}
default:
WARNING("[Generator][FD] Unsupported type.");