aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-04-09 10:25:02 +0100
committerDerek Lamberti <derek.lamberti@arm.com>2019-04-10 15:13:41 +0100
commitf30f7d32b22020f80b21da7b008d8302cee9d395 (patch)
tree2e213da4704c46b40f20629223365d1ddbf8d8cd
parent82fbe7c0b82f7adadd5120ac4b4f779d0da7c0d5 (diff)
downloadarmnn-f30f7d32b22020f80b21da7b008d8302cee9d395.tar.gz
IVGCVSW-2946 RefElementwiseWorkload configures prior to first execute
+ Added PostAllocationConfigure() method to workload interface
+ Elementwise function now deduces types based on Functor
- Replaced RefComparisonWorkload with RefElementwiseWorkload specialization
+ Fixed up unit tests and minor formatting

Change-Id: I33d08797767bba01cf4efb2904920ce0f950a4fe
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r--src/armnn/LoadedNetwork.cpp6
-rw-r--r--src/backends/backendsCommon/Workload.hpp3
-rw-r--r--src/backends/backendsCommon/test/LayerTests.cpp99
-rw-r--r--src/backends/reference/backend.mk1
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp89
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/Decoders.hpp48
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp31
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp11
-rw-r--r--src/backends/reference/workloads/Encoders.hpp66
-rw-r--r--src/backends/reference/workloads/Maximum.hpp2
-rw-r--r--src/backends/reference/workloads/Minimum.hpp2
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp76
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.hpp36
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp108
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp22
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
17 files changed, 324 insertions, 281 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 9263f1a6e9..7f00dbee87 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -136,6 +136,12 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
// Set up memory.
m_OptimizedNetwork->GetGraph().AllocateDynamicBuffers();
+
+ // Now that the intermediate tensor memory has been set-up, do any post allocation configuration for each workload.
+ for (auto& workload : m_WorkloadQueue)
+ {
+ workload->PostAllocationConfigure();
+ }
}
TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 447ec1b4d6..3efd7dbfd4 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -20,6 +20,7 @@ class IWorkload
public:
virtual ~IWorkload() {}
+ virtual void PostAllocationConfigure() = 0;
virtual void Execute() const = 0;
virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
@@ -44,6 +45,8 @@ public:
m_Data.Validate(info);
}
+ void PostAllocationConfigure() override {}
+
const QueueDescriptor& GetData() const { return m_Data; }
protected:
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index 74f3997133..cba4d3d93a 100644
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -1177,6 +1177,7 @@ LayerTestResult<float,3> MergerTest(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0], outputHandle.get());
@@ -1268,6 +1269,7 @@ LayerTestResult<float,4> AdditionTest(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -1346,6 +1348,7 @@ LayerTestResult<T, 4> AdditionBroadcastTestImpl(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -1419,6 +1422,7 @@ LayerTestResult<T, 4> AdditionBroadcast1ElementTestImpl(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -1533,7 +1537,9 @@ LayerTestResult<float,4> CompareAdditionTest(
CopyDataToITensorHandle(inputHandle1Ref.get(), &input1[0][0][0][0]);
CopyDataToITensorHandle(inputHandle2Ref.get(), &input2[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
+ workloadRef->PostAllocationConfigure();
workloadRef->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -1598,6 +1604,7 @@ LayerTestResult<T, 4> DivisionTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -1952,6 +1959,7 @@ LayerTestResult<TOutput, 4> ElementwiseTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
ExecuteWorkload(*workload, memoryManager);
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -2770,6 +2778,7 @@ LayerTestResult<float,4> MultiplicationTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -2923,9 +2932,10 @@ LayerTestResult<float,4> CompareMultiplicationTest(
CopyDataToITensorHandle(inputHandle0Ref.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1Ref.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
+ workloadRef->PostAllocationConfigure();
workloadRef->Execute();
-
CopyDataFromITensorHandle(&comparisonResult.output[0][0][0][0], outputHandle.get());
CopyDataFromITensorHandle(&comparisonResult.outputExpected[0][0][0][0], outputHandleRef.get());
@@ -3004,7 +3014,9 @@ LayerTestResult<float,4> CompareBatchNormTest(
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
CopyDataToITensorHandle(inputHandleRef.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
+ workloadRef->PostAllocationConfigure();
workloadRef->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -3049,6 +3061,7 @@ void PermuteTensorData(
CopyDataToITensorHandle(inputHandle.get(), inputData);
+ workload->PostAllocationConfigure();
workload->Execute();
outputData.resize(outputTensorInfo.GetNumElements());
@@ -3381,6 +3394,7 @@ void Concatenate(
++nextInputId;
}
+ workload->PostAllocationConfigure();
workload->Execute();
if (needPermuteForConcat)
@@ -5002,6 +5016,7 @@ LayerTestResult<float, 4> ResizeBilinearNopTest(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5071,6 +5086,7 @@ LayerTestResult<float, 4> SimpleResizeBilinearTest(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5140,6 +5156,7 @@ LayerTestResult<float, 4> ResizeBilinearSqMinTest(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5207,6 +5224,7 @@ LayerTestResult<float, 4> ResizeBilinearMinTest(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5276,6 +5294,7 @@ LayerTestResult<float, 4> ResizeBilinearMagTest(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5326,6 +5345,7 @@ LayerTestResult<float, 2> FakeQuantizationTest(
CopyDataToITensorHandle(inputHandle.get(), &input[0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0], outputHandle.get());
@@ -5392,6 +5412,7 @@ LayerTestResult<float, 4> L2NormalizationTestImpl(
CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0][0][0]);
+ workload->PostAllocationConfigure();
ExecuteWorkload(*workload, memoryManager);
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -5415,13 +5436,13 @@ LayerTestResult<T, 2> Pad2dTestCommon(
float qScale,
int32_t qOffset)
{
- const armnn::TensorShape inputShape{ 3, 3 };
- const armnn::TensorShape outputShape{ 7, 7 };
+ const armnn::TensorShape inputShape{ 3, 3 };
+ const armnn::TensorShape outputShape{ 7, 7 };
- const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
- const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+ const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
+ const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
- std::vector<T> inputValues(
+ std::vector<T> inputValues(
QuantizedVector<T>(qScale, qOffset,
{
// Height (3) x Width (3)
@@ -5430,8 +5451,8 @@ LayerTestResult<T, 2> Pad2dTestCommon(
3, 2, 4
}));
- std::vector<T> expectedOutputValues(
- QuantizedVector<T>(qScale, qOffset,
+ std::vector<T> expectedOutputValues(
+ QuantizedVector<T>(qScale, qOffset,
{
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
@@ -5442,38 +5463,39 @@ LayerTestResult<T, 2> Pad2dTestCommon(
0, 0, 0, 0, 0, 0, 0
}));
- auto inputTensor = MakeTensor<T, 2>(inputTensorInfo, std::vector<T>(inputValues));
+ auto inputTensor = MakeTensor<T, 2>(inputTensorInfo, std::vector<T>(inputValues));
- LayerTestResult<T, 2> result(outputTensorInfo);
- result.outputExpected = MakeTensor<T, 2>(outputTensorInfo, std::vector<T>(expectedOutputValues));
+ LayerTestResult<T, 2> result(outputTensorInfo);
+ result.outputExpected = MakeTensor<T, 2>(outputTensorInfo, std::vector<T>(expectedOutputValues));
- std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
- std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+ std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
+ std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
- armnn::PadQueueDescriptor descriptor;
+ armnn::PadQueueDescriptor descriptor;
- std::vector<std::pair<unsigned int, unsigned int>> PadList;
- PadList.push_back(std::pair<unsigned int, unsigned int>(2,2));
- PadList.push_back(std::pair<unsigned int, unsigned int>(2,2));
+ std::vector<std::pair<unsigned int, unsigned int>> PadList;
+ PadList.push_back(std::pair<unsigned int, unsigned int>(2,2));
+ PadList.push_back(std::pair<unsigned int, unsigned int>(2,2));
- descriptor.m_Parameters.m_PadList = PadList;
- armnn::WorkloadInfo info;
+ descriptor.m_Parameters.m_PadList = PadList;
+ armnn::WorkloadInfo info;
- AddInputToWorkload(descriptor, info, inputTensorInfo, inputHandle.get());
- AddOutputToWorkload(descriptor, info, outputTensorInfo, outputHandle.get());
+ AddInputToWorkload(descriptor, info, inputTensorInfo, inputHandle.get());
+ AddOutputToWorkload(descriptor, info, outputTensorInfo, outputHandle.get());
- std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePad(descriptor, info);
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreatePad(descriptor, info);
- inputHandle->Allocate();
- outputHandle->Allocate();
+ inputHandle->Allocate();
+ outputHandle->Allocate();
- CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0]);
+ CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0]);
- workload->Execute();
+ workload->PostAllocationConfigure();
+ workload->Execute();
- CopyDataFromITensorHandle(&result.output[0][0], outputHandle.get());
+ CopyDataFromITensorHandle(&result.output[0][0], outputHandle.get());
- return result;
+ return result;
}
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
@@ -5553,6 +5575,7 @@ LayerTestResult<T, 3> Pad3dTestCommon(
CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0], outputHandle.get());
@@ -5790,6 +5813,7 @@ LayerTestResult<T, 4> Pad4dTestCommon(
CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -6251,6 +6275,7 @@ LayerTestResult<T, 4> ConstantTestImpl(
outputHandle->Allocate();
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -6406,6 +6431,7 @@ LayerTestResult<uint8_t, 3> MergerUint8DifferentQParamsTest(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0], outputHandle.get());
@@ -6541,6 +6567,7 @@ LayerTestResult<uint8_t, 3> MergerUint8Test(
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0]);
CopyDataToITensorHandle(inputHandle2.get(), &input2[0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&ret.output[0][0][0], outputHandle.get());
@@ -6610,6 +6637,7 @@ LayerTestResult<T, 4> AdditionQuantizeTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -6739,6 +6767,7 @@ LayerTestResult<T, 4> MultiplicationQuantizeTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7034,6 +7063,7 @@ LayerTestResult<T, 4> SubtractionTestHelper(
CopyDataToITensorHandle(inputHandle0.get(), &input0[0][0][0][0]);
CopyDataToITensorHandle(inputHandle1.get(), &input1[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7252,6 +7282,7 @@ LayerTestResult<uint8_t, 4> ResizeBilinearNopUint8Test(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7311,6 +7342,7 @@ LayerTestResult<uint8_t, 4> SimpleResizeBilinearUint8Test(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7368,6 +7400,7 @@ LayerTestResult<uint8_t, 4> ResizeBilinearSqMinUint8Test(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7423,6 +7456,7 @@ LayerTestResult<uint8_t, 4> ResizeBilinearMinUint8Test(
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7480,6 +7514,7 @@ LayerTestResult<uint8_t, 4> ResizeBilinearMagUint8Test(
outputHandle->Allocate();
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
@@ -7516,6 +7551,7 @@ LayerTestResult<float, 2> Rsqrt2dTestCommon(
CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0], outputHandle.get());
@@ -7593,6 +7629,7 @@ LayerTestResult<float, 3> Rsqrt3dTest(
CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0], outputHandle.get());
@@ -8384,6 +8421,7 @@ LayerTestResult<T, OutputDim> MeanTestHelper(
CopyDataToITensorHandle(inputHandle.get(), input.origin());
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
@@ -8660,7 +8698,9 @@ LayerTestResult<float, 4> AdditionAfterMaxPoolTest(
CopyDataToITensorHandle(poolingOutputHandle.get(), &resultMaxPool[0][0][0][0]);
CopyDataToITensorHandle(addInputHandle.get(), &addInput[0][0][0][0]);
+ workload->PostAllocationConfigure();
workload->Execute();
+ addWorkload->PostAllocationConfigure();
addWorkload->Execute();
CopyDataFromITensorHandle(&addRet.output[0][0][0][0], addOutputHandle.get());
@@ -8795,7 +8835,7 @@ LayerTestResult<T, OutputDim> BatchToSpaceNdHelper(
const std::vector<T> &outputData,
float scale = 1.0f,
int32_t offset = 0)
- {
+{
auto dataType = (std::is_same<T, uint8_t>::value ? armnn::DataType::QuantisedAsymm8 : armnn::DataType::Float32);
armnn::TensorInfo inputTensorInfo(InputDim, inputShape, dataType);
@@ -8830,6 +8870,7 @@ LayerTestResult<T, OutputDim> BatchToSpaceNdHelper(
CopyDataToITensorHandle(inputHandle.get(), input.origin());
+ workload->PostAllocationConfigure();
workload->Execute();
CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index f2b1153a71..e74e85378f 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -30,7 +30,6 @@ BACKEND_SOURCES := \
workloads/RefBatchNormalizationUint8Workload.cpp \
workloads/RefBatchToSpaceNdFloat32Workload.cpp \
workloads/RefBatchToSpaceNdUint8Workload.cpp \
- workloads/RefComparisonWorkload.cpp \
workloads/RefConstantWorkload.cpp \
workloads/RefConvertFp16ToFp32Workload.cpp \
workloads/RefConvertFp32ToFp16Workload.cpp \
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 95c75a576a..955d7f2185 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -25,34 +25,30 @@ public:
virtual BaseIterator& operator-=(const unsigned int increment) = 0;
};
+template<typename IType>
class Decoder : public BaseIterator
{
public:
- Decoder() : BaseIterator() {}
+ using InterfaceType = IType;
+
+ Decoder() {}
virtual ~Decoder() {}
- virtual float Get() const = 0;
+ virtual IType Get() const = 0;
};
+template<typename IType>
class Encoder : public BaseIterator
{
public:
- Encoder() : BaseIterator() {}
-
- virtual ~Encoder() {}
+ using InterfaceType = IType;
- virtual void Set(const float& right) = 0;
-};
-
-class ComparisonEncoder : public BaseIterator
-{
-public:
- ComparisonEncoder() : BaseIterator() {}
+ Encoder() {}
- virtual ~ComparisonEncoder() {}
+ virtual ~Encoder() {}
- virtual void Set(bool right) = 0;
+ virtual void Set(IType right) = 0;
};
template<typename T, typename Base>
@@ -84,7 +80,7 @@ public:
T* m_Iterator;
};
-class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
+class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
{
public:
QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
@@ -100,19 +96,7 @@ private:
const int32_t m_Offset;
};
-class FloatDecoder : public TypedIterator<const float, Decoder>
-{
-public:
- FloatDecoder(const float* data)
- : TypedIterator(data) {}
-
- float Get() const override
- {
- return *m_Iterator;
- }
-};
-
-class QSymm16Decoder : public TypedIterator<const int16_t, Decoder>
+class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
{
public:
QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
@@ -128,25 +112,25 @@ private:
const int32_t m_Offset;
};
-class FloatEncoder : public TypedIterator<float, Encoder>
+class FloatDecoder : public TypedIterator<const float, Decoder<float>>
{
public:
- FloatEncoder(float* data)
+ FloatDecoder(const float* data)
: TypedIterator(data) {}
- void Set(const float& right) override
+ float Get() const override
{
- *m_Iterator = right;
+ return *m_Iterator;
}
};
-class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
+class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
{
public:
QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
: TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
- void Set(const float& right) override
+ void Set(float right) override
{
*m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
}
@@ -156,32 +140,45 @@ private:
const int32_t m_Offset;
};
-class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
+class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
{
public:
- BooleanEncoder(uint8_t* data)
+ QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ void Set(float right) override
+ {
+ *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
+class FloatEncoder : public TypedIterator<float, Encoder<float>>
+{
+public:
+ FloatEncoder(float* data)
: TypedIterator(data) {}
- void Set(bool right) override
+ void Set(float right) override
{
*m_Iterator = right;
}
};
-class QSymm16Encoder : public TypedIterator<int16_t, Encoder>
+class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
{
public:
- QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
- : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+ BooleanEncoder(uint8_t* data)
+ : TypedIterator(data) {}
- void Set(const float& right) override
+ void Set(bool right) override
{
- *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
+ *m_Iterator = right;
}
-
-private:
- const float m_Scale;
- const int32_t m_Offset;
};
+
} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 4ff2466e87..e94b031060 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -16,10 +16,12 @@ list(APPEND armnnRefBackendWorkloads_sources
ConvImpl.hpp
Debug.cpp
Debug.hpp
+ Decoders.hpp
DetectionPostProcess.cpp
DetectionPostProcess.hpp
ElementwiseFunction.cpp
ElementwiseFunction.hpp
+ Encoders.hpp
FullyConnected.cpp
FullyConnected.hpp
Gather.cpp
@@ -44,8 +46,6 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchToSpaceNdFloat32Workload.hpp
RefBatchToSpaceNdUint8Workload.cpp
RefBatchToSpaceNdUint8Workload.hpp
- RefComparisonWorkload.cpp
- RefComparisonWorkload.hpp
RefConstantWorkload.cpp
RefConstantWorkload.hpp
RefConvertFp16ToFp32Workload.cpp
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
new file mode 100644
index 0000000000..4112e7d454
--- /dev/null
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -0,0 +1,48 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+namespace armnn
+{
+
+template<typename T>
+std::unique_ptr<Decoder<T>> MakeDecoder(const TensorInfo& info, const void* data);
+
+template<>
+std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const void* data)
+{
+ switch(info.GetDataType())
+ {
+ case armnn::DataType::QuantisedAsymm8:
+ {
+ return std::make_unique<QASymm8Decoder>(
+ static_cast<const uint8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
+ case armnn::DataType::QuantisedSymm16:
+ {
+ return std::make_unique<QSymm16Decoder>(
+ static_cast<const int16_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
+ case armnn::DataType::Float32:
+ {
+ return std::make_unique<FloatDecoder>(static_cast<const float*>(data));
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Not supported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 934a86217a..7a5c071f70 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -13,26 +13,25 @@
namespace armnn
{
-template <typename Functor, typename DecoderOp, typename EncoderOp>
-ElementwiseFunction<Functor, DecoderOp, EncoderOp>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- DecoderOp& inData0,
- DecoderOp& inData1,
- EncoderOp& outData)
+template <typename Functor>
+ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ armnn::Decoder<InType>& inData0,
+ armnn::Decoder<InType>& inData1,
+ armnn::Encoder<OutType>& outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
} //namespace armnn
-template struct armnn::ElementwiseFunction<std::plus<float>, armnn::Decoder, armnn::Encoder>;
-template struct armnn::ElementwiseFunction<std::minus<float>, armnn::Decoder, armnn::Encoder>;
-template struct armnn::ElementwiseFunction<std::multiplies<float>, armnn::Decoder, armnn::Encoder>;
-template struct armnn::ElementwiseFunction<std::divides<float>, armnn::Decoder, armnn::Encoder>;
-template struct armnn::ElementwiseFunction<armnn::maximum<float>, armnn::Decoder, armnn::Encoder>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>, armnn::Decoder, armnn::Encoder>;
-
-template struct armnn::ElementwiseFunction<std::equal_to<float>, armnn::Decoder, armnn::ComparisonEncoder>;
-template struct armnn::ElementwiseFunction<std::greater<float>, armnn::Decoder, armnn::ComparisonEncoder>;
+template struct armnn::ElementwiseFunction<std::plus<float>>;
+template struct armnn::ElementwiseFunction<std::minus<float>>;
+template struct armnn::ElementwiseFunction<std::multiplies<float>>;
+template struct armnn::ElementwiseFunction<std::divides<float>>;
+template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
+template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseFunction<std::equal_to<float>>;
+template struct armnn::ElementwiseFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index 9eb003d5f9..fd1fab0690 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -11,15 +11,18 @@
namespace armnn
{
-template <typename Functor, typename DecoderOp, typename EncoderOp>
+template <typename Functor>
struct ElementwiseFunction
{
+ using OutType = typename Functor::result_type;
+ using InType = typename Functor::first_argument_type;
+
ElementwiseFunction(const TensorShape& inShape0,
const TensorShape& inShape1,
const TensorShape& outShape,
- DecoderOp& inData0,
- DecoderOp& inData1,
- EncoderOp& outData);
+ armnn::Decoder<InType>& inData0,
+ armnn::Decoder<InType>& inData1,
+ armnn::Encoder<OutType>& outData);
};
} //namespace armnn
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
new file mode 100644
index 0000000000..90300aa0f7
--- /dev/null
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -0,0 +1,66 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+namespace armnn
+{
+
+template<typename T>
+std::unique_ptr<Encoder<T>> MakeEncoder(const TensorInfo& info, void* data);
+
+template<>
+std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* data)
+{
+ switch(info.GetDataType())
+ {
+ case armnn::DataType::QuantisedAsymm8:
+ {
+ return std::make_unique<QASymm8Encoder>(
+ static_cast<uint8_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
+ case armnn::DataType::QuantisedSymm16:
+ {
+ return std::make_unique<QSymm16Encoder>(
+ static_cast<int16_t*>(data),
+ info.GetQuantizationScale(),
+ info.GetQuantizationOffset());
+ }
+ case armnn::DataType::Float32:
+ {
+ return std::make_unique<FloatEncoder>(static_cast<float*>(data));
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Cannot encode from float. Not supported target Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
+template<>
+std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void* data)
+{
+ switch(info.GetDataType())
+ {
+ case armnn::DataType::Boolean:
+ {
+ return std::make_unique<BooleanEncoder>(static_cast<uint8_t*>(data));
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Cannot encode from boolean. Not supported target Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/Maximum.hpp b/src/backends/reference/workloads/Maximum.hpp
index 524afffc44..97df19d3c6 100644
--- a/src/backends/reference/workloads/Maximum.hpp
+++ b/src/backends/reference/workloads/Maximum.hpp
@@ -10,7 +10,7 @@
namespace armnn
{
template<typename T>
- struct maximum
+struct maximum : public std::binary_function<T, T, T>
{
T
operator () (const T& inputData0, const T& inputData1) const
diff --git a/src/backends/reference/workloads/Minimum.hpp b/src/backends/reference/workloads/Minimum.hpp
index 2f3bdc1c02..0c053981a0 100644
--- a/src/backends/reference/workloads/Minimum.hpp
+++ b/src/backends/reference/workloads/Minimum.hpp
@@ -9,7 +9,7 @@ namespace armnn
{
template<typename T>
-struct minimum
+struct minimum : public std::binary_function<T, T, T>
{
T
operator()(const T& input1, const T& input2) const
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
deleted file mode 100644
index bb8bb04ad3..0000000000
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ /dev/null
@@ -1,76 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "RefComparisonWorkload.hpp"
-#include "ElementwiseFunction.hpp"
-#include "RefWorkloadUtils.hpp"
-#include "Profiling.hpp"
-#include <vector>
-
-namespace armnn {
-
-template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
-void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
- const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
- const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
-
- const TensorShape& inShape0 = inputInfo0.GetShape();
- const TensorShape& inShape1 = inputInfo1.GetShape();
- const TensorShape& outShape = outputInfo.GetShape();
-
- switch(inputInfo0.GetDataType())
- {
- case armnn::DataType::QuantisedAsymm8:
- {
- QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
- inputInfo0.GetQuantizationScale(),
- inputInfo0.GetQuantizationOffset());
-
- QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
- inputInfo1.GetQuantizationScale(),
- inputInfo1.GetQuantizationOffset());
-
- BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
-
- ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
- inShape1,
- outShape,
- decodeIterator0,
- decodeIterator1,
- encodeIterator0);
- break;
- }
- case armnn::DataType::Float32:
- {
- FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
- FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
- BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
-
- ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
- inShape1,
- outShape,
- decodeIterator0,
- decodeIterator1,
- encodeIterator0);
- break;
- }
- default:
- BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!");
- break;
- }
-}
-
-}
-
-template class armnn::RefComparisonWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-template class armnn::RefComparisonWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
deleted file mode 100644
index cfc2dcf2aa..0000000000
--- a/src/backends/reference/workloads/RefComparisonWorkload.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Types.hpp>
-#include <backendsCommon/Workload.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-#include "StringMapping.hpp"
-
-namespace armnn
-{
-
-template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
-class RefComparisonWorkload : public BaseWorkload<ParentDescriptor>
-{
-public:
- using BaseWorkload<ParentDescriptor>::m_Data;
- using BaseWorkload<ParentDescriptor>::BaseWorkload;
-
- void Execute() const override;
-};
-
-using RefEqualWorkload =
- RefComparisonWorkload<std::equal_to<float>,
- EqualQueueDescriptor,
- StringMapping::RefEqualWorkload_Execute>;
-
-
-using RefGreaterWorkload =
- RefComparisonWorkload<std::greater<float>,
- GreaterQueueDescriptor,
- StringMapping::RefGreaterWorkload_Execute>;
-} // armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 1a30e7c9fb..535adca0d7 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -4,17 +4,41 @@
//
#include "RefElementwiseWorkload.hpp"
+
+#include "Decoders.hpp"
#include "ElementwiseFunction.hpp"
-#include "RefWorkloadUtils.hpp"
+#include "Encoders.hpp"
#include "Profiling.hpp"
+#include "RefWorkloadUtils.hpp"
#include "StringMapping.hpp"
#include "TypeUtils.hpp"
+
#include <vector>
namespace armnn
{
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWorkload(
+ const ParentDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ParentDescriptor>(desc, info)
+{
+}
+
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input0 = MakeDecoder<InType>(inputInfo0, m_Data.m_Inputs[0]->Map());
+ m_Input1 = MakeDecoder<InType>(inputInfo1, m_Data.m_Inputs[1]->Map());
+ m_Output = MakeEncoder<OutType>(outputInfo, m_Data.m_Outputs[0]->Map());
+}
+
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
@@ -26,73 +50,15 @@ void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() c
const TensorShape& inShape1 = inputInfo1.GetShape();
const TensorShape& outShape = outputInfo.GetShape();
- switch(inputInfo0.GetDataType())
- {
- case armnn::DataType::QuantisedAsymm8:
- {
- QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
- inputInfo0.GetQuantizationScale(),
- inputInfo0.GetQuantizationOffset());
-
- QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
- inputInfo1.GetQuantizationScale(),
- inputInfo1.GetQuantizationOffset());
-
- QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data),
- outputInfo.GetQuantizationScale(),
- outputInfo.GetQuantizationOffset());
-
- ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
- inShape1,
- outShape,
- decodeIterator0,
- decodeIterator1,
- encodeIterator0);
- break;
- }
- case armnn::DataType::Float32:
- {
- FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
- FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
- FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data));
-
- ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
- inShape1,
- outShape,
- decodeIterator0,
- decodeIterator1,
- encodeIterator0);
- break;
- }
- case armnn::DataType::QuantisedSymm16:
- {
- QSymm16Decoder decodeIterator0(GetInputTensorData<int16_t>(0, m_Data),
- inputInfo0.GetQuantizationScale(),
- inputInfo0.GetQuantizationOffset());
-
- QSymm16Decoder decodeIterator1(GetInputTensorData<int16_t>(1, m_Data),
- inputInfo1.GetQuantizationScale(),
- inputInfo1.GetQuantizationOffset());
-
- QSymm16Encoder encodeIterator0(GetOutputTensorData<int16_t>(0, m_Data),
- outputInfo.GetQuantizationScale(),
- outputInfo.GetQuantizationOffset());
-
- ElementwiseFunction<Functor, Decoder, Encoder>(inShape0,
- inShape1,
- outShape,
- decodeIterator0,
- decodeIterator1,
- encodeIterator0);
- break;
- }
- default:
- BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!");
- break;
- }
+ ElementwiseFunction<Functor>(inShape0,
+ inShape1,
+ outShape,
+ *m_Input0,
+ *m_Input1,
+ *m_Output);
}
-}
+} //namespace armnn
template class armnn::RefElementwiseWorkload<std::plus<float>,
armnn::AdditionQueueDescriptor,
@@ -116,4 +82,12 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
- armnn::StringMapping::RefMinimumWorkload_Execute>; \ No newline at end of file
+ armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::equal_to<float>,
+ armnn::EqualQueueDescriptor,
+ armnn::StringMapping::RefEqualWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::greater<float>,
+ armnn::GreaterQueueDescriptor,
+ armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 81af19627e..651942e9e5 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -8,6 +8,8 @@
#include <armnn/Types.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>
+#include "BaseIterator.hpp"
+#include "ElementwiseFunction.hpp"
#include "Maximum.hpp"
#include "Minimum.hpp"
#include "StringMapping.hpp"
@@ -19,10 +21,18 @@ template <typename Functor, typename ParentDescriptor, typename armnn::StringMap
class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
{
public:
+ using InType = typename ElementwiseFunction<Functor>::InType;
+ using OutType = typename ElementwiseFunction<Functor>::OutType;
using BaseWorkload<ParentDescriptor>::m_Data;
- using BaseWorkload<ParentDescriptor>::BaseWorkload;
+ RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
void Execute() const override;
+
+private:
+ std::unique_ptr<Decoder<InType>> m_Input0;
+ std::unique_ptr<Decoder<InType>> m_Input1;
+ std::unique_ptr<Encoder<OutType>> m_Output;
};
using RefAdditionWorkload =
@@ -54,4 +64,14 @@ using RefMinimumWorkload =
RefElementwiseWorkload<armnn::minimum<float>,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
+
+using RefEqualWorkload =
+ RefElementwiseWorkload<std::equal_to<float>,
+ armnn::EqualQueueDescriptor,
+ armnn::StringMapping::RefEqualWorkload_Execute>;
+
+using RefGreaterWorkload =
+ RefElementwiseWorkload<std::greater<float>,
+ armnn::GreaterQueueDescriptor,
+ armnn::StringMapping::RefGreaterWorkload_Execute>;
} // armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 77aa56fcc6..a1b584759b 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -62,7 +62,6 @@
#include "RefBatchToSpaceNdFloat32Workload.hpp"
#include "RefDebugWorkload.hpp"
#include "RefRsqrtFloat32Workload.hpp"
-#include "RefComparisonWorkload.hpp"
#include "RefDequantizeWorkload.hpp"
#include "RefQuantizeWorkload.hpp" \ No newline at end of file