diff options
Diffstat (limited to 'src/backends/reference/workloads/Merger.cpp')
-rw-r--r-- | src/backends/reference/workloads/Merger.cpp | 49 |
1 files changed, 11 insertions, 38 deletions
diff --git a/src/backends/reference/workloads/Merger.cpp b/src/backends/reference/workloads/Merger.cpp index 8877ee2284..e0b70ee5cb 100644 --- a/src/backends/reference/workloads/Merger.cpp +++ b/src/backends/reference/workloads/Merger.cpp @@ -5,43 +5,19 @@ #include "Merger.hpp" #include "RefWorkloadUtils.hpp" +#include "Decoders.hpp" +#include "Encoders.hpp" namespace armnn { -template <> -void CopyValue<float>(const float& source, const TensorInfo& sourceInfo, float& dest, const TensorInfo& destInfo) -{ - dest = source; -} - -template <> -void CopyValue<uint8_t>(const uint8_t& source, const TensorInfo& sourceInfo, uint8_t& dest, const TensorInfo& destInfo) -{ - if (sourceInfo.GetQuantizationScale() != destInfo.GetQuantizationScale() || - sourceInfo.GetQuantizationOffset() != destInfo.GetQuantizationOffset()) - { - // Dequantize value according to sourceInfo params - float dequantizedValue = armnn::Dequantize<uint8_t>(source, - sourceInfo.GetQuantizationScale(), - sourceInfo.GetQuantizationOffset()); - - // Quantize again according to destInfo paramns - dest = armnn::Quantize<uint8_t>(dequantizedValue, - destInfo.GetQuantizationScale(), - destInfo.GetQuantizationOffset()); - } - else - { - dest = source; - } -} - -template <typename DataType> void Merger(const MergerQueueDescriptor& data) { const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]); + std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo0, data.m_Outputs[0]->Map()); + Encoder<float>& encoder = *encoderPtr; + for (unsigned int index = 0 ; index < outputInfo0.GetNumElements(); ++index) { unsigned int indices[MaxNumOfTensorDimensions] = { 0 }; @@ -80,6 +56,9 @@ void Merger(const MergerQueueDescriptor& data) if (insideView) { + std::unique_ptr<Decoder<float>> decoderPtr = + MakeDecoder<float>(inputInfo, data.m_Inputs[viewIdx]->Map()); + Decoder<float>& decoder = *decoderPtr; unsigned int inIndex = 0; unsigned int dimensionStride = 1; @@ -88,11 +67,8 @@ void Merger(const MergerQueueDescriptor& data) inIndex += dimensionStride * (indices[i] - view.m_Origin[i]); dimensionStride *= inputInfo.GetShape()[i]; } - - CopyValue<DataType>((GetInputTensorData<DataType>(viewIdx, data))[inIndex], - GetTensorInfo(data.m_Inputs[viewIdx]), - (GetOutputTensorData<DataType>(0, data))[index], - outputInfo0); + decoder += inIndex; + encoder.Set(decoder.Get()); //What should we do if input views overlap on the output tensor? //We could error, take the average, or shm else... @@ -100,11 +76,8 @@ void Merger(const MergerQueueDescriptor& data) break; } } + ++encoder; } } -template void Merger<float>(const MergerQueueDescriptor& data); - -template void Merger<uint8_t>(const MergerQueueDescriptor& data); - } //namespace armnn |