diff options
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.cpp | 108 |
1 files changed, 41 insertions, 67 deletions
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 1a30e7c9fb..535adca0d7 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -4,17 +4,41 @@ // #include "RefElementwiseWorkload.hpp" + +#include "Decoders.hpp" #include "ElementwiseFunction.hpp" -#include "RefWorkloadUtils.hpp" +#include "Encoders.hpp" #include "Profiling.hpp" +#include "RefWorkloadUtils.hpp" #include "StringMapping.hpp" #include "TypeUtils.hpp" + #include <vector> namespace armnn { template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWorkload( + const ParentDescriptor& desc, + const WorkloadInfo& info) + : BaseWorkload<ParentDescriptor>(desc, info) +{ +} + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::PostAllocationConfigure() +{ + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + m_Input0 = MakeDecoder<InType>(inputInfo0, m_Data.m_Inputs[0]->Map()); + m_Input1 = MakeDecoder<InType>(inputInfo1, m_Data.m_Inputs[1]->Map()); + m_Output = MakeEncoder<OutType>(outputInfo, m_Data.m_Outputs[0]->Map()); +} + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); @@ -26,73 +50,15 @@ void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() c const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - switch(inputInfo0.GetDataType()) - { - case armnn::DataType::QuantisedAsymm8: - { - QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), - inputInfo0.GetQuantizationScale(), - inputInfo0.GetQuantizationOffset()); - - QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), - inputInfo1.GetQuantizationScale(), - inputInfo1.GetQuantizationOffset()); - - QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data), - outputInfo.GetQuantizationScale(), - outputInfo.GetQuantizationOffset()); - - ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, - inShape1, - outShape, - decodeIterator0, - decodeIterator1, - encodeIterator0); - break; - } - case armnn::DataType::Float32: - { - FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); - FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); - FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data)); - - ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, - inShape1, - outShape, - decodeIterator0, - decodeIterator1, - encodeIterator0); - break; - } - case armnn::DataType::QuantisedSymm16: - { - QSymm16Decoder decodeIterator0(GetInputTensorData<int16_t>(0, m_Data), - inputInfo0.GetQuantizationScale(), - inputInfo0.GetQuantizationOffset()); - - QSymm16Decoder decodeIterator1(GetInputTensorData<int16_t>(1, m_Data), - inputInfo1.GetQuantizationScale(), - inputInfo1.GetQuantizationOffset()); - - QSymm16Encoder encodeIterator0(GetOutputTensorData<int16_t>(0, m_Data), - outputInfo.GetQuantizationScale(), - outputInfo.GetQuantizationOffset()); - - ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, - inShape1, - outShape, - decodeIterator0, - decodeIterator1, - encodeIterator0); - break; - } - default: - BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!"); - break; - } + ElementwiseFunction<Functor>(inShape0, + inShape1, + outShape, + *m_Input0, + *m_Input1, + *m_Output); } -} +} //namespace armnn template class armnn::RefElementwiseWorkload<std::plus<float>, armnn::AdditionQueueDescriptor, @@ -116,4 +82,12 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>, template class armnn::RefElementwiseWorkload<armnn::minimum<float>, armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>;
\ No newline at end of file + armnn::StringMapping::RefMinimumWorkload_Execute>; + +template class armnn::RefElementwiseWorkload<std::equal_to<float>, + armnn::EqualQueueDescriptor, + armnn::StringMapping::RefEqualWorkload_Execute>; + +template class armnn::RefElementwiseWorkload<std::greater<float>, + armnn::GreaterQueueDescriptor, + armnn::StringMapping::RefGreaterWorkload_Execute>; |