Diffstat (limited to 'src/backends/reference/workloads/RefQuantizeWorkload.cpp')
-rw-r--r--  src/backends/reference/workloads/RefQuantizeWorkload.cpp | 55
1 file changed, 19 insertions(+), 36 deletions(-)
diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.cpp b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
index ab2ee7fc4e..2eef5f33db 100644
--- a/src/backends/reference/workloads/RefQuantizeWorkload.cpp
+++ b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
@@ -5,6 +5,8 @@
#include "RefQuantizeWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+
#include <armnn/TypesUtils.hpp>
@@ -14,14 +16,13 @@ namespace armnn
namespace
{
-template<typename T>
-void QuantizeImpl(const void *input, void *output, size_t numValues, float scale, int offset)
+void QuantizeImpl(Decoder<float>& in, Encoder<float>& out, size_t numValues)
{
-    auto in = static_cast<const float *>(input);
-    auto out = static_cast<T *>(output);
-    for (size_t i = 0; i < numValues; i++, in++, out++)
+    for (unsigned int i = 0; i < numValues; i++)
    {
-        *out = armnn::Quantize<T>(*in, scale, offset);
+        in[i];
+        out[i];
+        out.Set(in.Get());
    }
}
@@ -30,42 +31,24 @@ void QuantizeImpl(const void *input, void *output, size_t numValues, float scale
RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info)
    : BaseWorkload(descriptor, info)
    , m_NumElements(info.m_InputTensorInfos[0].GetNumElements())
-    , m_TargetType(info.m_OutputTensorInfos[0].GetDataType())
-    , m_Scale(info.m_OutputTensorInfos[0].GetQuantizationScale())
-    , m_Offset(info.m_OutputTensorInfos[0].GetQuantizationOffset())
{
}
-void RefQuantizeWorkload::Execute() const
+void RefQuantizeWorkload::PostAllocationConfigure()
{
-    const void* input = m_Data.m_Inputs[0]->Map(true);
-    void* output = m_Data.m_Outputs[0]->Map(true);
+    const TensorInfo& inputInfo = armnn::GetTensorInfo(m_Data.m_Inputs[0]);
+    m_InputDecoder = MakeDecoder<float>(inputInfo);
-    switch(m_TargetType)
-    {
-        case DataType::QAsymmU8:
-        {
-            QuantizeImpl<uint8_t>(input, output, m_NumElements, m_Scale, m_Offset);
-            break;
-        }
-        case DataType::QSymmS8:
-        {
-            QuantizeImpl<int8_t>(input, output, m_NumElements, m_Scale, 0);
-            break;
-        }
-        case DataType::QSymmS16:
-        {
-            QuantizeImpl<int16_t>(input, output, m_NumElements, m_Scale, 0);
-            break;
-        }
-        default:
-        {
-            BOOST_ASSERT_MSG(false, "RefQuantizeWorkload: Non quantized output type encountered");
-        }
-    }
+    const TensorInfo& outputInfo = armnn::GetTensorInfo(m_Data.m_Outputs[0]);
+    m_OutputEncoder = MakeEncoder<float>(outputInfo);
+}
+
+void RefQuantizeWorkload::Execute() const
+{
+    m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
+    m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
-    m_Data.m_Inputs[0]->Unmap();
-    m_Data.m_Outputs[0]->Unmap();
+    QuantizeImpl(*m_InputDecoder, *m_OutputEncoder, m_NumElements);
}
} //namespace armnn
\ No newline at end of file
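
A note on the new QuantizeImpl body: the statements in[i]; and out[i]; are not dead code. In ArmNN's Decoder/Encoder API, operator[] repositions the underlying iterator at element i, and Get()/Set() then read or write the value at that position, with the quantization scale and offset baked into the encoder that MakeEncoder<float> builds from the output TensorInfo. Below is a minimal standalone sketch of that pattern; FloatDecoder, QAsymmU8Encoder and QuantizeSketch are illustrative stand-ins rather than ArmNN types, and the round-then-clamp maths is assumed typical QAsymmU8 behaviour, not necessarily ArmNN's exact implementation.

// Toy sketch of the seek-then-Get/Set iterator pattern used by the new QuantizeImpl.
// FloatDecoder and QAsymmU8Encoder are illustrative stand-ins, not ArmNN classes.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct FloatDecoder
{
    const float* data;
    size_t pos = 0;
    FloatDecoder& operator[](size_t index) { pos = index; return *this; } // seek, like Decoder<float>
    float Get() const { return data[pos]; }
};

struct QAsymmU8Encoder
{
    uint8_t* data;
    float scale;
    int offset;
    size_t pos = 0;
    QAsymmU8Encoder& operator[](size_t index) { pos = index; return *this; } // seek, like Encoder<float>
    void Set(float value)
    {
        // Assumed asymmetric quantization: round(value / scale) + offset, clamped to [0, 255].
        int q = static_cast<int>(std::round(value / scale)) + offset;
        data[pos] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
    }
};

// Same shape as the new QuantizeImpl: seek both iterators, then copy one element.
void QuantizeSketch(FloatDecoder& in, QAsymmU8Encoder& out, size_t numValues)
{
    for (size_t i = 0; i < numValues; ++i)
    {
        in[i];              // position the decoder at element i
        out[i];             // position the encoder at element i
        out.Set(in.Get());  // read a float, quantize it, store the result
    }
}

int main()
{
    std::vector<float> input = { -1.0f, 0.0f, 0.5f, 1.0f };
    std::vector<uint8_t> output(input.size());

    FloatDecoder in{ input.data() };
    QAsymmU8Encoder out{ output.data(), /*scale=*/1.0f / 128.0f, /*offset=*/128 };

    QuantizeSketch(in, out, input.size());

    for (uint8_t v : output)
    {
        std::printf("%d ", static_cast<int>(v)); // expected with this scale/offset: 0 128 192 255
    }
    std::printf("\n");
    return 0;
}

The point of the indirection is visible in the diff above: Execute() no longer needs a per-DataType switch, because the encoder selected in PostAllocationConfigure() already carries the target type, scale and offset taken from the output tensor's info.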