-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp | 161
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.hpp |  51
2 files changed, 109 insertions(+), 103 deletions(-)
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index c9b93c8524..356d7a0c16 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -7,69 +7,118 @@
#include "ElementwiseFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
+#include "StringMapping.hpp"
+#include "TypeUtils.hpp"
#include <vector>
namespace armnn
{
-template <typename ParentDescriptor, typename Functor>
-void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char * debugString) const
+template <typename Functor,
+ typename armnn::DataType DataType,
+ typename ParentDescriptor,
+ typename armnn::StringMapping::Id DebugString>
+void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
- auto data = Float32Workload<ParentDescriptor>::GetData();
- const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
- const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
- const TensorShape& outShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
-
- const float* inData0 = GetInputTensorDataFloat(0, data);
- const float* inData1 = GetInputTensorDataFloat(1, data);
- float* outData = GetOutputTensorDataFloat(0, data);
-
- ElementwiseFunction<Functor, float, float>(inShape0, inShape1, outShape, inData0, inData1, outData);
-}
-
-template <typename ParentDescriptor, typename Functor>
-void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char * debugString) const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
- auto data = Uint8Workload<ParentDescriptor>::GetData();
- const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
- const TensorInfo& inputInfo1 = GetTensorInfo(data.m_Inputs[1]);
- const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
-
- auto dequant0 = Dequantize(GetInputTensorDataU8(0, data), inputInfo0);
- auto dequant1 = Dequantize(GetInputTensorDataU8(1, data), inputInfo1);
-
- std::vector<float> results(outputInfo.GetNumElements());
-
- ElementwiseFunction<Functor, float, float>(inputInfo0.GetShape(),
- inputInfo1.GetShape(),
- outputInfo.GetShape(),
- dequant0.data(),
- dequant1.data(),
- results.data());
-
- Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ switch(DataType)
+ {
+ case armnn::DataType::QuantisedAsymm8:
+ {
+ std::vector<float> results(outputInfo.GetNumElements());
+ ElementwiseFunction<Functor, float, float>(inShape0,
+ inShape1,
+ outShape,
+ Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(),
+ Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(),
+ results.data());
+ Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+ break;
+ }
+ case armnn::DataType::Float32:
+ {
+ ElementwiseFunction<Functor, float, float>(inShape0,
+ inShape1,
+ outShape,
+ GetInputTensorDataFloat(0, m_Data),
+ GetInputTensorDataFloat(1, m_Data),
+ GetOutputTensorDataFloat(0, m_Data));
+ break;
+ }
+ default:
+ BOOST_ASSERT_MSG(false, "Unknown Data Type!");
+ break;
+ }
}
}
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
+template class armnn::RefElementwiseWorkload<std::plus<float>,
+ armnn::DataType::Float32,
+ armnn::AdditionQueueDescriptor,
+ armnn::StringMapping::RefAdditionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::plus<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::AdditionQueueDescriptor,
+ armnn::StringMapping::RefAdditionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::minus<float>,
+ armnn::DataType::Float32,
+ armnn::SubtractionQueueDescriptor,
+ armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::minus<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::SubtractionQueueDescriptor,
+ armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::multiplies<float>,
+ armnn::DataType::Float32,
+ armnn::MultiplicationQueueDescriptor,
+ armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::multiplies<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::MultiplicationQueueDescriptor,
+ armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::divides<float>,
+ armnn::DataType::Float32,
+ armnn::DivisionQueueDescriptor,
+ armnn::StringMapping::RefDivisionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::divides<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::DivisionQueueDescriptor,
+ armnn::StringMapping::RefDivisionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
+ armnn::DataType::Float32,
+ armnn::MaximumQueueDescriptor,
+ armnn::StringMapping::RefMaximumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::MaximumQueueDescriptor,
+ armnn::StringMapping::RefMaximumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
+ armnn::DataType::Float32,
+ armnn::MinimumQueueDescriptor,
+ armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
+ armnn::DataType::QuantisedAsymm8,
+ armnn::MinimumQueueDescriptor,
+ armnn::StringMapping::RefMinimumWorkload_Execute>;
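The rewritten Execute() above folds the old Float32 and Uint8 paths into one function: the non-type DataType template parameter selects a branch of the switch, and the QuantisedAsymm8 case dequantizes to float, applies the same float functor, then requantizes the results. Below is a minimal, self-contained sketch of that pattern; DataType, QuantParams, Dequantize and Quantize here are hypothetical stand-ins for the real ArmNN helpers, and broadcasting plus the ITensorHandle/TensorInfo plumbing are omitted.

#include <cstdint>
#include <functional>
#include <vector>

namespace sketch
{

enum class DataType { Float32, QuantisedAsymm8 };

struct QuantParams
{
    float scale = 1.0f;
    std::int32_t offset = 0;
};

// Hypothetical stand-in for the Dequantize helper: QAsymm8 -> float.
inline std::vector<float> Dequantize(const std::uint8_t* in, std::size_t n, QuantParams q)
{
    std::vector<float> out(n);
    for (std::size_t i = 0; i < n; ++i)
    {
        out[i] = q.scale * static_cast<float>(static_cast<std::int32_t>(in[i]) - q.offset);
    }
    return out;
}

// Hypothetical stand-in for the Quantize helper: float -> QAsymm8, clamped.
inline void Quantize(std::uint8_t* out, const float* in, std::size_t n, QuantParams q)
{
    for (std::size_t i = 0; i < n; ++i)
    {
        float v = in[i] / q.scale + static_cast<float>(q.offset);
        v = v < 0.0f ? 0.0f : (v > 255.0f ? 255.0f : v);
        out[i] = static_cast<std::uint8_t>(v + 0.5f);
    }
}

// One class covers both data types; the non-type parameter picks the branch.
template <typename Functor, DataType Type>
struct ElementwiseWorkload
{
    // Raw pointers plus an element count stand in for ITensorHandle/TensorInfo.
    static void Execute(const void* in0, const void* in1, void* out,
                        std::size_t n, QuantParams q = {})
    {
        Functor f;
        switch (Type)
        {
            case DataType::QuantisedAsymm8:
            {
                // Dequantize, run the float functor, requantize.
                auto d0 = Dequantize(static_cast<const std::uint8_t*>(in0), n, q);
                auto d1 = Dequantize(static_cast<const std::uint8_t*>(in1), n, q);
                std::vector<float> results(n);
                for (std::size_t i = 0; i < n; ++i)
                {
                    results[i] = f(d0[i], d1[i]);
                }
                Quantize(static_cast<std::uint8_t*>(out), results.data(), n, q);
                break;
            }
            case DataType::Float32:
            {
                // Float data is processed in place, no conversion needed.
                const float* f0 = static_cast<const float*>(in0);
                const float* f1 = static_cast<const float*>(in1);
                float* fo = static_cast<float*>(out);
                for (std::size_t i = 0; i < n; ++i)
                {
                    fo[i] = f(f0[i], f1[i]);
                }
                break;
            }
        }
    }
};

} // namespace sketch

Since Type is a compile-time constant the optimizer drops the dead branch, but both cases must still type-check for every instantiation, which is why the sketch passes the data through type-erased pointers, just as the real workload fetches it through GetInputTensorDataU8/GetInputTensorDataFloat in both branches.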
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index a5ff376673..371904977a 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -20,56 +20,14 @@ template <typename Functor,
          typename armnn::DataType DataType,
typename ParentDescriptor,
typename armnn::StringMapping::Id DebugString>
class RefElementwiseWorkload
-{
- // Needs specialization. The default is empty on purpose.
-};
-
-template <typename ParentDescriptor, typename Functor>
-class BaseFloat32ElementwiseWorkload : public Float32Workload<ParentDescriptor>
-{
-public:
- using Float32Workload<ParentDescriptor>::Float32Workload;
- void ExecuteImpl(const char * debugString) const;
-};
-
-template <typename Functor,
- typename ParentDescriptor,
- typename armnn::StringMapping::Id DebugString>
-class RefElementwiseWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString>
- : public BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>
-{
-public:
- using BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::BaseFloat32ElementwiseWorkload;
-
- virtual void Execute() const override
- {
- using Parent = BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>;
- Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
- }
-};
-
-template <typename ParentDescriptor, typename Functor>
-class BaseUint8ElementwiseWorkload : public Uint8Workload<ParentDescriptor>
+ : public TypedWorkload<ParentDescriptor, DataType>
{
public:
- using Uint8Workload<ParentDescriptor>::Uint8Workload;
- void ExecuteImpl(const char * debugString) const;
-};
-template <typename Functor,
- typename ParentDescriptor,
- typename armnn::StringMapping::Id DebugString>
-class RefElementwiseWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString>
- : public BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>
-{
-public:
- using BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::BaseUint8ElementwiseWorkload;
+ using TypedWorkload<ParentDescriptor, DataType>::m_Data;
+ using TypedWorkload<ParentDescriptor, DataType>::TypedWorkload;
- virtual void Execute() const override
- {
- using Parent = BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>;
- Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
- }
+ void Execute() const override;
};
using RefAdditionFloat32Workload =
@@ -120,7 +78,6 @@ using RefDivisionUint8Workload =
DivisionQueueDescriptor,
StringMapping::RefDivisionWorkload_Execute>;
-
using RefMaximumFloat32Workload =
RefElementwiseWorkload<armnn::maximum<float>,
DataType::Float32,