aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/FullyConnected.cpp
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2019-05-27 12:14:10 +0100
committerFrancis Murtagh <francis.murtagh@arm.com>2019-05-27 12:14:10 +0100
commit43aec5886449c1b024b740fd6f4500e827bde221 (patch)
treec12a128dcc6895a0663a4e4dd27c4110e492c6dd /src/backends/reference/workloads/FullyConnected.cpp
parent7f2c35a82ec11be50b3478bd15207320bbf3bd57 (diff)
downloadarmnn-43aec5886449c1b024b740fd6f4500e827bde221.tar.gz
IVGCVSW-3134 Refactor FullyConnected workloads into single workload
* Refactor FullyConnected workloads into single workload. * Refactor FullyConnected ref implementation to use Encoders and Decoders to support all DataTypes. * Deleted RefFullyConnectedFloat32Workload and RefFullyConnected2dUint8Workload. Change-Id: Iad30fb0287ab7491e1297997e7d61f1d00785541 Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r--src/backends/reference/workloads/FullyConnected.cpp50
1 files changed, 26 insertions, 24 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index bf5814d2ad..02d9b060ef 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -5,32 +5,29 @@
#include "FullyConnected.hpp"
+#include "RefWorkloadUtils.hpp"
+
#include <boost/assert.hpp>
namespace armnn
{
-void FullyConnected(const float* inputData,
- float* outputData,
- const TensorInfo& inputTensorInfo,
- const TensorInfo& outputTensorInfo,
- const float* weightData,
- const float* biasData,
- bool transposeWeights)
+void FullyConnected(const TensorShape& rInputShape,
+ Decoder<float>& rInputDecoder,
+ const TensorShape& rOutputShape,
+ Encoder<float>& rOutputEncoder,
+ Decoder<float>& rWeightDecoder,
+ Decoder<float>& rBiasDecoder,
+ const bool biasEnabled,
+ const unsigned int K,
+ const bool transposeWeights)
{
- unsigned int N = outputTensorInfo.GetShape()[1]; // Outputs Vector Size.
-
- BOOST_ASSERT(inputTensorInfo.GetNumDimensions() > 1); // Needs some data.
-
- unsigned int K = 1; // Total number of activations in the input.
- for (unsigned int i = 1; i < inputTensorInfo.GetNumDimensions(); i++)
- {
- K *= inputTensorInfo.GetShape()[i];
- }
+ // Perform FullyConnected implementation
+ unsigned int outputSize = rOutputShape[1];
- for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
+ for (unsigned int n = 0; n < rInputShape[0]; n++)
{
- for (unsigned int channelOutput = 0; channelOutput < N; channelOutput++)
+ for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
{
float outval = 0.f;
@@ -39,22 +36,27 @@ void FullyConnected(const float* inputData,
float weight;
if (transposeWeights)
{
- weight = weightData[channelOutput * K + channelInput];
+ rWeightDecoder[channelOutput * K + channelInput];
+ weight = rWeightDecoder.Get();
}
else
{
- weight = weightData[channelInput * N + channelOutput];
+ rWeightDecoder[channelInput * outputSize + channelOutput];
+ weight = rWeightDecoder.Get();
}
- outval += weight * inputData[n * K + channelInput];
+ rInputDecoder[n * K + channelInput];
+ outval += weight * rInputDecoder.Get();
}
- if (biasData)
+ if (biasEnabled)
{
- outval += biasData[channelOutput];
+ rBiasDecoder[channelOutput];
+ outval += rBiasDecoder.Get();
}
- outputData[n * N + channelOutput] = outval;
+ rOutputEncoder[n * outputSize + channelOutput];
+ rOutputEncoder.Set(outval);
}
}
}