From 0a042997ac24fee1a338e806caf18bd8dfba28f3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 28 Feb 2024 13:20:05 +0000 Subject: Testing support for MUL with shift as input Always create the shift as a tensor for all types in testing. In the reference model, set the shift operand to be available for all types, but only read in the shift tensor for i32. Signed-off-by: Jeremy Johnson Signed-off-by: TatWai Chong Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538 --- .../src/generate/generate_fixed_data.cc | 41 +++++++++++++++------- 1 file changed, 28 insertions(+), 13 deletions(-) (limited to 'reference_model/src/generate/generate_fixed_data.cc') diff --git a/reference_model/src/generate/generate_fixed_data.cc b/reference_model/src/generate/generate_fixed_data.cc index 3d4ee3e..b0b6c81 100644 --- a/reference_model/src/generate/generate_fixed_data.cc +++ b/reference_model/src/generate/generate_fixed_data.cc @@ -20,8 +20,22 @@ #include #include +namespace +{ +template +bool copyFixedData(const int64_t elements, const std::vector inData, OutType* outData) +{ + for (auto t = 0; t < elements; t++) + { + outData[t] = inData[t]; + } + return true; +} +} // namespace + namespace TosaReference { + bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size) { // Check we support the operator @@ -31,22 +45,23 @@ bool generateFixedData(const GenerateConfig& cfg, void* data, size_t size) return false; } + std::vector inData = cfg.fixedDataInfo.data; + const auto T = TosaReference::numElementsFromShape(cfg.shape); + if (T != static_cast(inData.size())) + { + WARNING("[Generator][FD] Given data size %d does not match output size %d.", inData.size(), T); + return false; + } + switch (cfg.dataType) { case DType::DType_SHAPE: { - int32_t* outData = reinterpret_cast(data); - std::vector inData = cfg.fixedDataInfo.data; - const auto T = TosaReference::numElementsFromShape(cfg.shape); - if (T != static_cast(inData.size())) - { - WARNING("[Generator][FD] Size does not match."); - return false; - } - for (auto t = 0; t < T; t++) - { - outData[t] = inData[t]; - } - return true; + int32_t* outData = reinterpret_cast(data); + return copyFixedData(T, inData, outData); + } + case DType::DType_INT8: { + int8_t* outData = reinterpret_cast(data); + return copyFixedData(T, inData, outData); } default: WARNING("[Generator][FD] Unsupported type."); -- cgit v1.2.1