aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-09-11 16:37:14 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-01 14:56:47 +0100
commit591cdb7ad9176d43163b340542b0fff470a198ea (patch)
tree67b3a910f85a4ae3850fea34c893525c100feae1
parent549ae37b51facbdf940bff62b45b3b74c1bc63c9 (diff)
downloadarmnn-591cdb7ad9176d43163b340542b0fff470a198ea.tar.gz
IVGCVSW-1843 : replacing trivial arithmetic helpers
Change-Id: Iddf637694f1a3a7ef00f006a41b8044a35c7e73c
-rw-r--r--Android.mk5
-rw-r--r--CMakeLists.txt10
-rw-r--r--src/armnn/backends/RefWorkloads.hpp3
-rw-r--r--src/armnn/backends/RefWorkloads/Addition.cpp44
-rw-r--r--src/armnn/backends/RefWorkloads/Addition.hpp20
-rw-r--r--src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp29
-rw-r--r--src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp24
-rw-r--r--src/armnn/backends/RefWorkloads/Division.cpp89
-rw-r--r--src/armnn/backends/RefWorkloads/Division.hpp20
-rw-r--r--src/armnn/backends/RefWorkloads/Multiplication.cpp52
-rw-r--r--src/armnn/backends/RefWorkloads/Multiplication.hpp20
-rw-r--r--src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp4
-rw-r--r--src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp14
-rw-r--r--src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp4
-rw-r--r--src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp12
-rw-r--r--src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp4
-rw-r--r--src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp12
-rw-r--r--src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp4
-rw-r--r--src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp14
-rw-r--r--src/armnn/backends/RefWorkloads/Subtraction.cpp44
-rw-r--r--src/armnn/backends/RefWorkloads/Subtraction.hpp20
21 files changed, 95 insertions, 353 deletions
diff --git a/Android.mk b/Android.mk
index 9c2373678d..9c4db74d1a 100644
--- a/Android.mk
+++ b/Android.mk
@@ -128,16 +128,14 @@ LOCAL_SRC_FILES := \
src/armnn/backends/RefWorkloads/RefSoftmaxFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefActivationFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp \
- src/armnn/backends/RefWorkloads/Multiplication.cpp \
src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp \
src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/Broadcast.cpp \
- src/armnn/backends/RefWorkloads/Addition.cpp \
+ src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp \
src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
- src/armnn/backends/RefWorkloads/Subtraction.cpp \
src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp \
src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp \
@@ -170,7 +168,6 @@ LOCAL_SRC_FILES := \
src/armnn/backends/RefWorkloads/RefPermuteWorkload.cpp \
src/armnn/backends/RefWorkloads/RefConvertFp16ToFp32Workload.cpp \
src/armnn/backends/RefWorkloads/RefConvertFp32ToFp16Workload.cpp \
- src/armnn/backends/RefWorkloads/Division.cpp \
src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp \
src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp \
src/armnn/backends/MemCopyWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 777c3153e6..9c2685c96d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -186,14 +186,12 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/Broadcast.cpp
src/armnn/backends/RefWorkloads/RefMergerUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefConstantUint8Workload.hpp
- src/armnn/backends/RefWorkloads/Addition.cpp
- src/armnn/backends/RefWorkloads/Addition.hpp
+ src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp
+ src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp
src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
- src/armnn/backends/RefWorkloads/Subtraction.cpp
- src/armnn/backends/RefWorkloads/Subtraction.hpp
src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
@@ -210,12 +208,8 @@ list(APPEND armnn_sources
src/armnn/backends/RefWorkloads/RefActivationFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefResizeBilinearUint8Workload.hpp
- src/armnn/backends/RefWorkloads/Multiplication.cpp
- src/armnn/backends/RefWorkloads/Division.cpp
- src/armnn/backends/RefWorkloads/Division.hpp
src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefL2NormalizationFloat32Workload.hpp
- src/armnn/backends/RefWorkloads/Multiplication.hpp
src/armnn/backends/RefWorkloads/RefActivationUint8Workload.hpp
src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp
src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp
diff --git a/src/armnn/backends/RefWorkloads.hpp b/src/armnn/backends/RefWorkloads.hpp
index 910610c72e..e58d4accbb 100644
--- a/src/armnn/backends/RefWorkloads.hpp
+++ b/src/armnn/backends/RefWorkloads.hpp
@@ -6,7 +6,7 @@
#pragma once
#include "backends/RefWorkloads/RefConstantUint8Workload.hpp"
-#include "backends/RefWorkloads/Addition.hpp"
+#include "backends/RefWorkloads/ArithmeticFunction.hpp"
#include "backends/RefWorkloads/ConvImpl.hpp"
#include "backends/RefWorkloads/RefMultiplicationUint8Workload.hpp"
#include "backends/RefWorkloads/RefBaseConstantWorkload.hpp"
@@ -14,7 +14,6 @@
#include "backends/RefWorkloads/RefSplitterUint8Workload.hpp"
#include "backends/RefWorkloads/RefResizeBilinearUint8Workload.hpp"
#include "backends/RefWorkloads/RefL2NormalizationFloat32Workload.hpp"
-#include "backends/RefWorkloads/Multiplication.hpp"
#include "backends/RefWorkloads/RefActivationUint8Workload.hpp"
#include "backends/RefWorkloads/RefPooling2dFloat32Workload.hpp"
#include "backends/RefWorkloads/RefWorkloadUtils.hpp"
diff --git a/src/armnn/backends/RefWorkloads/Addition.cpp b/src/armnn/backends/RefWorkloads/Addition.cpp
deleted file mode 100644
index 33d5bd538f..0000000000
--- a/src/armnn/backends/RefWorkloads/Addition.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Addition.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseAddition(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
-{
- for (unsigned int i = 0; i < numElements; ++i)
- {
- outData[i] = inData0[i] + inData1[i];
- }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Addition(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- if (inShape0 == inShape1)
- {
- ElementwiseAddition(inShape0.GetNumElements(), inData0, inData1, outData);
- }
- else
- {
- BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::plus<float>(), 0, inData0, inData1, outData);
- }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Addition.hpp b/src/armnn/backends/RefWorkloads/Addition.hpp
deleted file mode 100644
index dcbd499eeb..0000000000
--- a/src/armnn/backends/RefWorkloads/Addition.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Addition(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp b/src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp
new file mode 100644
index 0000000000..fede138253
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ArithmeticFunction.hpp"
+#include "Broadcast.hpp"
+#include <functional>
+
+namespace armnn
+{
+
+template <typename Functor>
+ArithmeticFunction<Functor>::ArithmeticFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
+{
+ BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
+}
+
+} //namespace armnn
+
+template struct armnn::ArithmeticFunction<std::plus<float>>;
+template struct armnn::ArithmeticFunction<std::minus<float>>;
+template struct armnn::ArithmeticFunction<std::multiplies<float>>;
+template struct armnn::ArithmeticFunction<std::divides<float>>;
diff --git a/src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp b/src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp
new file mode 100644
index 0000000000..eafb6444f6
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp
@@ -0,0 +1,24 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+namespace armnn
+{
+
+template <typename Functor>
+struct ArithmeticFunction
+{
+ ArithmeticFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData);
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Division.cpp b/src/armnn/backends/RefWorkloads/Division.cpp
deleted file mode 100644
index cc7f7c9fe4..0000000000
--- a/src/armnn/backends/RefWorkloads/Division.cpp
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Division.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-#include <cmath>
-
-namespace
-{
-
-void ElementwiseDivision(unsigned int numElements,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- for (unsigned int i = 0; i < numElements; ++i)
- {
- if (inData1[i] != 0.0f)
- {
- outData[i] = inData0[i] / inData1[i];
- }
- else if (inData0[i] == 0.0f)
- {
- if (!std::signbit(inData1[i]))
- {
- outData[i]= NAN;
- }
- else
- {
- outData[i]= -NAN;
- }
- }
- else if (inData0[i] < 0.0f)
- {
- if (!std::signbit(inData1[i]))
- {
- outData[i] = -INFINITY;
- }
- else
- {
- outData[i] = INFINITY;
- }
- }
- else
- {
- if (!std::signbit(inData1[i]))
- {
- outData[i] = INFINITY;
- }
- else
- {
- outData[i] = -INFINITY;
- }
- }
- }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Division(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- if (inShape0 == inShape1)
- {
- ElementwiseDivision(inShape0.GetNumElements(), inData0, inData1, outData);
- }
- else
- {
- BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::divides<float>(),
- 0,
- inData0,
- inData1,
- outData);
- }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Division.hpp b/src/armnn/backends/RefWorkloads/Division.hpp
deleted file mode 100644
index b83c77f796..0000000000
--- a/src/armnn/backends/RefWorkloads/Division.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
- void Division(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.cpp b/src/armnn/backends/RefWorkloads/Multiplication.cpp
deleted file mode 100644
index ae6446af97..0000000000
--- a/src/armnn/backends/RefWorkloads/Multiplication.cpp
+++ /dev/null
@@ -1,52 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Multiplication.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseMultiplication(unsigned int numElements,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- for (unsigned int i = 0; i < numElements; ++i)
- {
- outData[i] = inData0[i] * inData1[i];
- }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Multiplication(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- if (inShape0 == inShape1)
- {
- ElementwiseMultiplication(inShape0.GetNumElements(), inData0, inData1, outData);
- }
- else
- {
- BroadcastLoop(inShape0, inShape1, outShape).Unroll(
- std::multiplies<float>(),
- 0,
- inData0,
- inData1,
- outData);
- }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.hpp b/src/armnn/backends/RefWorkloads/Multiplication.hpp
deleted file mode 100644
index 58ad7b4cad..0000000000
--- a/src/armnn/backends/RefWorkloads/Multiplication.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Multiplication(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
index c2a5b5fcbd..21c7533c0f 100644
--- a/src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
@@ -5,7 +5,7 @@
#include "RefAdditionFloat32Workload.hpp"
-#include "Addition.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefAdditionFloat32Workload::Execute() const
const float* inData1 = GetInputTensorDataFloat(1, m_Data);
float* outData = GetOutputTensorDataFloat(0, m_Data);
- Addition(inShape0, inShape1, outShape, inData0, inData1, outData);
+ ArithmeticFunction<std::plus<float>>(inShape0, inShape1, outShape, inData0, inData1, outData);
}
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
index 2999be9240..116a5f14cb 100644
--- a/src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
@@ -5,7 +5,7 @@
#include "RefAdditionUint8Workload.hpp"
-#include "Addition.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -28,12 +28,12 @@ void RefAdditionUint8Workload::Execute() const
std::vector<float> results(outputInfo.GetNumElements());
- Addition(inputInfo0.GetShape(),
- inputInfo1.GetShape(),
- outputInfo.GetShape(),
- dequant0.data(),
- dequant1.data(),
- results.data());
+ ArithmeticFunction<std::plus<float>>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}
diff --git a/src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp
index 81f4645cbc..28c90610de 100644
--- a/src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp
@@ -5,7 +5,7 @@
#include "RefDivisionFloat32Workload.hpp"
-#include "Division.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefDivisionFloat32Workload::Execute() const
const float* inputData0 = GetInputTensorDataFloat(0, m_Data);
const float* inputData1 = GetInputTensorDataFloat(1, m_Data);
- Division(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
+ ArithmeticFunction<std::divides<float>>(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
}
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp
index a6ed770c40..d10d874137 100644
--- a/src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp
@@ -5,7 +5,7 @@
#include "RefDivisionUint8Workload.hpp"
-#include "Division.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -27,9 +27,13 @@ void RefDivisionUint8Workload::Execute() const
auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
std::vector<float> results(outputInfo.GetNumElements());
- Division(
- inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
- dequant0.data(), dequant1.data(),results.data());
+
+ ArithmeticFunction<std::divides<float>>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
index 022cca70e7..0b36f0ff00 100644
--- a/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
@@ -5,7 +5,7 @@
#include "RefMultiplicationFloat32Workload.hpp"
-#include "Multiplication.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefMultiplicationFloat32Workload::Execute() const
const float* inputData0 = GetInputTensorDataFloat(0, m_Data);
const float* inputData1 = GetInputTensorDataFloat(1, m_Data);
- Multiplication(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
+ ArithmeticFunction<std::multiplies<float>>(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
}
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
index 8e0a617bf5..b929a53808 100644
--- a/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
@@ -5,7 +5,7 @@
#include "RefMultiplicationUint8Workload.hpp"
-#include "Multiplication.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -27,9 +27,13 @@ void RefMultiplicationUint8Workload::Execute() const
auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
std::vector<float> results(outputInfo.GetNumElements());
- Multiplication(
- inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
- dequant0.data(), dequant1.data(),results.data());
+
+ ArithmeticFunction<std::multiplies<float>>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
index 4440eedab7..f1840c347b 100644
--- a/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
@@ -5,7 +5,7 @@
#include "RefSubtractionFloat32Workload.hpp"
-#include "Subtraction.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefSubtractionFloat32Workload::Execute() const
const float* inData1 = GetInputTensorDataFloat(1, m_Data);
float* outData = GetOutputTensorDataFloat(0, m_Data);
- Subtraction(inShape0, inShape1, outShape, inData0, inData1, outData);
+ ArithmeticFunction<std::minus<float>>(inShape0, inShape1, outShape, inData0, inData1, outData);
}
} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
index 8066762e48..1affbdd8b1 100644
--- a/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
@@ -5,7 +5,7 @@
#include "RefSubtractionUint8Workload.hpp"
-#include "Subtraction.hpp"
+#include "ArithmeticFunction.hpp"
#include "RefWorkloadUtils.hpp"
#include "Profiling.hpp"
@@ -28,12 +28,12 @@ void RefSubtractionUint8Workload::Execute() const
std::vector<float> results(outputInfo.GetNumElements());
- Subtraction(inputInfo0.GetShape(),
- inputInfo1.GetShape(),
- outputInfo.GetShape(),
- dequant0.data(),
- dequant1.data(),
- results.data());
+ ArithmeticFunction<std::minus<float>>(inputInfo0.GetShape(),
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
}
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.cpp b/src/armnn/backends/RefWorkloads/Subtraction.cpp
deleted file mode 100644
index f25c8adb1c..0000000000
--- a/src/armnn/backends/RefWorkloads/Subtraction.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Subtraction.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseSubtraction(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
-{
- for (unsigned int i = 0; i < numElements; ++i)
- {
- outData[i] = inData0[i] - inData1[i];
- }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Subtraction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
-{
- if (inShape0 == inShape1)
- {
- ElementwiseSubtraction(inShape0.GetNumElements(), inData0, inData1, outData);
- }
- else
- {
- BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::minus<float>(), 0, inData0, inData1, outData);
- }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.hpp b/src/armnn/backends/RefWorkloads/Subtraction.hpp
deleted file mode 100644
index 3956797185..0000000000
--- a/src/armnn/backends/RefWorkloads/Subtraction.hpp
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Subtraction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
-
-} //namespace armnn