aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp7
-rw-r--r--src/backends/reference/workloads/Reduce.cpp43
2 files changed, 33 insertions, 17 deletions
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index e906b2962c..18490e29c7 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -2321,6 +2321,13 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceSumSingleAxisFloat32_2, ReduceSumSingleAxisT
ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceSumSingleAxisFloat32_3, ReduceSumSingleAxisTest3<DataType::Float32>)
ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceSumMultipleAxisFloat32, ReduceSumMultipleAxisTest<DataType::Float32>)
+// ReduceProd
+ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceProdFloat32, ReduceProdSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceProdSingleAxisFloat32_1, ReduceProdSingleAxisTest1<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceProdSingleAxisFloat32_2, ReduceProdSingleAxisTest2<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceProdSingleAxisFloat32_3, ReduceProdSingleAxisTest3<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceProdMultipleAxisFloat32, ReduceProdMultipleAxisTest<DataType::Float32>)
+
// ReduceMax
ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMaxFloat32, ReduceMaxSimpleTest<DataType::Float32>)
ARMNN_AUTO_TEST_CASE_WITH_THF(ReduceMaxNegativeAxisFloat32, ReduceMaxNegativeAxisTest<DataType::Float32>)
diff --git a/src/backends/reference/workloads/Reduce.cpp b/src/backends/reference/workloads/Reduce.cpp
index 8bf422aea3..3f929c43bc 100644
--- a/src/backends/reference/workloads/Reduce.cpp
+++ b/src/backends/reference/workloads/Reduce.cpp
@@ -9,7 +9,6 @@
#include <backendsCommon/WorkloadData.hpp>
-#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
@@ -87,6 +86,9 @@ void Reduce(const TensorInfo& inputInfo,
case ReduceOperation::Sum:
std::fill(tempOut.begin(), tempOut.end(), 0.0f);
break;
+ case ReduceOperation::Prod:
+ std::fill(tempOut.begin(), tempOut.end(), 1.0f);
+ break;
case ReduceOperation::Max:
std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
break;
@@ -119,23 +121,30 @@ void Reduce(const TensorInfo& inputInfo,
numResolvedAxis, resolvedAxis);
input[inputOffset];
auto inputValue = input.Get();
- if (reduceOperation == ReduceOperation::Max)
- {
- if (inputValue > tempOut[outputOffset])
- {
- tempOut[outputOffset] = inputValue;
- }
- }
- else if (reduceOperation == ReduceOperation::Min)
- {
- if (inputValue < tempOut[outputOffset])
- {
- tempOut[outputOffset] = inputValue;
- }
- }
- else
+ switch(reduceOperation)
{
- tempOut[outputOffset] += inputValue;
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ tempOut[outputOffset] += inputValue;
+ break;
+ case ReduceOperation::Prod:
+ tempOut[outputOffset] *= inputValue;
+ break;
+ case ReduceOperation::Max:
+ if (inputValue > tempOut[outputOffset])
+ {
+ tempOut[outputOffset] = inputValue;
+ }
+ break;
+ case ReduceOperation::Min:
+ if (inputValue < tempOut[outputOffset])
+ {
+ tempOut[outputOffset] = inputValue;
+ }
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
}