Diffstat (limited to 'src/armnn/backends/RefWorkloads')
 src/armnn/backends/RefWorkloads/Addition.cpp                         |  6
 src/armnn/backends/RefWorkloads/Merger.hpp                           |  1
 src/armnn/backends/RefWorkloads/Multiplication.cpp                   | 42
 src/armnn/backends/RefWorkloads/Multiplication.hpp                   | 12
 src/armnn/backends/RefWorkloads/Pooling2d.cpp                        |  4
 src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp |  7
 src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp   |  7
 src/armnn/backends/RefWorkloads/Splitter.hpp                         |  1
 8 files changed, 59 insertions(+), 21 deletions(-)
diff --git a/src/armnn/backends/RefWorkloads/Addition.cpp b/src/armnn/backends/RefWorkloads/Addition.cpp
index c26f82ecc2..6d53a702e4 100644
--- a/src/armnn/backends/RefWorkloads/Addition.cpp
+++ b/src/armnn/backends/RefWorkloads/Addition.cpp
@@ -8,9 +8,6 @@
#include <functional>
-namespace armnn
-{
-
namespace
{
@@ -24,6 +21,9 @@ void ElementwiseAddition(unsigned int numElements, const float* inData0, const f
} // namespace
+namespace armnn
+{
+
void Addition(const TensorShape& inShape0,
const TensorShape& inShape1,
const TensorShape& outShape,
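
The hunks above only move code: the anonymous namespace holding the ElementwiseAddition helper is hoisted out of namespace armnn, so the helper keeps internal linkage at global scope and only the public Addition entry point remains inside armnn. A minimal, self-contained sketch of the pattern (placeholder signatures, not the actual ArmNN code):

    namespace // anonymous namespace: the helper is private to this translation unit
    {
    void ElementwiseAdd(unsigned int numElements, const float* a, const float* b, float* out)
    {
        for (unsigned int i = 0; i < numElements; ++i)
        {
            out[i] = a[i] + b[i]; // plain element-wise sum
        }
    }
    } // namespace

    namespace armnn // only the public API lives in the library namespace
    {
    void Addition(const float* a, const float* b, unsigned int n, float* out)
    {
        ElementwiseAdd(n, a, b, out);
    }
    } // namespace armnn
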
diff --git a/src/armnn/backends/RefWorkloads/Merger.hpp b/src/armnn/backends/RefWorkloads/Merger.hpp
index 9695e457e2..476ced76be 100644
--- a/src/armnn/backends/RefWorkloads/Merger.hpp
+++ b/src/armnn/backends/RefWorkloads/Merger.hpp
@@ -39,6 +39,7 @@ void Merger(const MergerQueueDescriptor& data)
//split view extents are defined by the size of (the corresponding) input tensor
const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[viewIdx]);
+ BOOST_ASSERT(inputInfo.GetNumDimensions() == outputInfo0.GetNumDimensions());
// check all dimensions to see if this element is inside the given input view
bool insideView = true;
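
The assertion added above makes an implicit contract explicit: the containment check that follows walks every dimension with a single index into both the output element's coordinates and the input view's origin and extents, so the two ranks must match. The symmetric assert added to Splitter.hpp at the end of this diff guards the same invariant in the opposite direction. A simplified sketch of the guarded pattern (illustrative only, not the Merger code):

    #include <cassert>

    // Is coordinate 'coord' inside the view that starts at 'origin' with the
    // given 'extent'? All three arrays are indexed by the same dimension, so
    // they must share 'rank' -- the invariant the new BOOST_ASSERT pins down.
    bool InsideView(unsigned int rank, const unsigned int* coord,
                    const unsigned int* origin, const unsigned int* extent)
    {
        for (unsigned int d = 0; d < rank; ++d)
        {
            if (coord[d] < origin[d] || coord[d] >= origin[d] + extent[d])
            {
                return false;
            }
        }
        return true;
    }
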
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.cpp b/src/armnn/backends/RefWorkloads/Multiplication.cpp
index 7f558d83c5..47c0f1cef1 100644
--- a/src/armnn/backends/RefWorkloads/Multiplication.cpp
+++ b/src/armnn/backends/RefWorkloads/Multiplication.cpp
@@ -4,18 +4,48 @@
//
#include "Multiplication.hpp"
+#include "Broadcast.hpp"
-namespace armnn
+#include <functional>
+
+namespace
{
-void Multiplication(const float* in0,
- const float* in1,
- unsigned int numElements,
- float* out)
+void ElementwiseMultiplication(unsigned int numElements,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
{
for (unsigned int i = 0; i < numElements; ++i)
{
- out[i] = in0[i] * in1[i];
+ outData[i] = inData0[i] * inData1[i];
+ }
+}
+
+} // namespace
+
+namespace armnn
+{
+
+void Multiplication(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
+{
+ if (inShape0 == inShape1)
+ {
+ ElementwiseMultiplication(inShape0.GetNumElements(), inData0, inData1, outData);
+ }
+ else
+ {
+ BroadcastLoop(inShape0, inShape1, outShape).Unroll(
+ std::multiplies<float>(),
+ 0,
+ inData0,
+ inData1,
+ outData);
}
}
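
With this change Multiplication matches the structure Addition already has: shapes that are exactly equal take the tight element-wise loop, and anything else falls through to BroadcastLoop::Unroll with std::multiplies<float>() as the per-element operation. A rough stand-in for the broadcast semantics, shown here for rank-2 shapes only (simplified; ArmNN's BroadcastLoop generalizes this over arbitrary ranks):

    #include <cstddef>

    // NumPy-style broadcasting for a rank-2 multiply: an input axis of extent 1
    // repeats its single value across the corresponding output axis.
    void BroadcastMul2d(const float* in0, std::size_t h0, std::size_t w0,
                        const float* in1, std::size_t h1, std::size_t w1,
                        float* out, std::size_t outH, std::size_t outW)
    {
        for (std::size_t y = 0; y < outH; ++y)
        {
            for (std::size_t x = 0; x < outW; ++x)
            {
                const float a = in0[(h0 == 1 ? 0 : y) * w0 + (w0 == 1 ? 0 : x)];
                const float b = in1[(h1 == 1 ? 0 : y) * w1 + (w1 == 1 ? 0 : x)];
                out[y * outW + x] = a * b;
            }
        }
    }

So a [2,3] tensor multiplied by a [1,3] tensor yields a [2,3] result in which the second input's single row is applied to both rows of the first.
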
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.hpp b/src/armnn/backends/RefWorkloads/Multiplication.hpp
index d0b033e7ec..54fcac51c1 100644
--- a/src/armnn/backends/RefWorkloads/Multiplication.hpp
+++ b/src/armnn/backends/RefWorkloads/Multiplication.hpp
@@ -5,12 +5,16 @@
#pragma once
+#include <armnn/Tensor.hpp>
+
namespace armnn
{
-void Multiplication(const float* in0,
- const float* in1,
- unsigned int numElements,
- float* out);
+void Multiplication(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData);
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Pooling2d.cpp b/src/armnn/backends/RefWorkloads/Pooling2d.cpp
index 6d15d8a436..a643e67690 100644
--- a/src/armnn/backends/RefWorkloads/Pooling2d.cpp
+++ b/src/armnn/backends/RefWorkloads/Pooling2d.cpp
@@ -186,8 +186,8 @@ void Pooling2d(const float* in,
// Clamp the pooling region inside the valid input area (which includes the padding).
// This is necessary because the final pooling in a row may overlap beyond the padding.
- hend = std::min(hend, heightInput + padRight);
- wend = std::min(wend, widthInput + padBottom);
+ hend = std::min(hend, heightInput + padBottom);
+ wend = std::min(wend, widthInput + padRight);
float result = defaultInitializer;
float poolAreaSize = boost::numeric_cast<float>((hend - hstart) * (wend - wstart));
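
This hunk fixes a transposed pair of padding terms: hend is the end of the pooling window along the height axis and must be clamped against the bottom padding, while wend runs along the width axis and must be clamped against the right padding; the old code paired height with padRight and width with padBottom. A small sketch of the corrected clamp (the helper name is illustrative, not from the source):

    #include <algorithm>

    // Clamp the trailing edge of a pooling window along one axis. The window
    // may extend past the input because the last pool in a row or column can
    // overlap the padding, but only the padding on the *same* axis applies.
    inline int ClampPoolEnd(int end, int inputExtent, int trailingPad)
    {
        return std::min(end, inputExtent + trailingPad);
    }

    // Matching the corrected hunk:
    //   hend = ClampPoolEnd(hend, heightInput, padBottom); // height axis -> bottom pad
    //   wend = ClampPoolEnd(wend, widthInput,  padRight);  // width axis  -> right pad
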
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
index ed68b1f6db..d7c54d9aad 100644
--- a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
@@ -17,12 +17,15 @@ void RefMultiplicationFloat32Workload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefMultiplicationFloat32Workload_Execute");
- const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorShape& inShape0 = GetTensorInfo(m_Data.m_Inputs[0]).GetShape();
+ const TensorShape& inShape1 = GetTensorInfo(m_Data.m_Inputs[1]).GetShape();
+ const TensorShape& outShape = GetTensorInfo(m_Data.m_Outputs[0]).GetShape();
float* outputData = GetOutputTensorDataFloat(0, m_Data);
const float* inputData0 = GetInputTensorDataFloat(0, m_Data);
const float* inputData1 = GetInputTensorDataFloat(1, m_Data);
- Multiplication(inputData0, inputData1, inputInfo0.GetNumElements(), outputData);
+
+ Multiplication(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
}
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
index 2e6f0e6c8b..d5c4afd87c 100644
--- a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
@@ -27,10 +27,9 @@ void RefMultiplicationUint8Workload::Execute() const
auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
std::vector<float> results(outputInfo.GetNumElements());
- Multiplication(dequant0.data(),
- dequant1.data(),
- inputInfo0.GetNumElements(),
- results.data());
+ Multiplication(
+ inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
+ dequant0.data(), dequant1.data(), results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}
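
The uint8 workload reuses the float kernel: both inputs are dequantized, multiplied in float (now with the broadcast-aware shape arguments), and the result is requantized into the output buffer. A minimal sketch of the affine quantization scheme this relies on, assuming the usual real = scale * (quantized - zeroPoint) mapping (helper names are illustrative, not ArmNN's API):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Affine dequantization: recover the real value from a quantized one.
    inline float DequantizeValue(std::uint8_t q, float scale, std::int32_t zeroPoint)
    {
        return scale * (static_cast<std::int32_t>(q) - zeroPoint);
    }

    // Affine quantization with saturation to the uint8 range.
    inline std::uint8_t QuantizeValue(float r, float scale, std::int32_t zeroPoint)
    {
        const std::int32_t q =
            static_cast<std::int32_t>(std::round(r / scale)) + zeroPoint;
        return static_cast<std::uint8_t>(
            std::min<std::int32_t>(255, std::max<std::int32_t>(0, q)));
    }
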
diff --git a/src/armnn/backends/RefWorkloads/Splitter.hpp b/src/armnn/backends/RefWorkloads/Splitter.hpp
index 67f6c100f9..74c4cb4e18 100644
--- a/src/armnn/backends/RefWorkloads/Splitter.hpp
+++ b/src/armnn/backends/RefWorkloads/Splitter.hpp
@@ -41,6 +41,7 @@ void Splitter(const SplitterQueueDescriptor& data)
//split view extents are defined by the size of (the corresponding) input tensor
const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[viewIdx]);
+ BOOST_ASSERT(outputInfo.GetNumDimensions() == inputInfo0.GetNumDimensions());
// check all dimensions to see if this element is inside the given input view
bool insideView = true;