aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/FullyConnected.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r--src/backends/reference/workloads/FullyConnected.cpp50
1 files changed, 26 insertions, 24 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index bf5814d2ad..02d9b060ef 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -5,32 +5,29 @@
#include "FullyConnected.hpp"
+#include "RefWorkloadUtils.hpp"
+
#include <boost/assert.hpp>
namespace armnn
{
-void FullyConnected(const float* inputData,
- float* outputData,
- const TensorInfo& inputTensorInfo,
- const TensorInfo& outputTensorInfo,
- const float* weightData,
- const float* biasData,
- bool transposeWeights)
+void FullyConnected(const TensorShape& rInputShape,
+ Decoder<float>& rInputDecoder,
+ const TensorShape& rOutputShape,
+ Encoder<float>& rOutputEncoder,
+ Decoder<float>& rWeightDecoder,
+ Decoder<float>& rBiasDecoder,
+ const bool biasEnabled,
+ const unsigned int K,
+ const bool transposeWeights)
{
- unsigned int N = outputTensorInfo.GetShape()[1]; // Outputs Vector Size.
-
- BOOST_ASSERT(inputTensorInfo.GetNumDimensions() > 1); // Needs some data.
-
- unsigned int K = 1; // Total number of activations in the input.
- for (unsigned int i = 1; i < inputTensorInfo.GetNumDimensions(); i++)
- {
- K *= inputTensorInfo.GetShape()[i];
- }
+ // Perform FullyConnected implementation
+ unsigned int outputSize = rOutputShape[1];
- for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
+ for (unsigned int n = 0; n < rInputShape[0]; n++)
{
- for (unsigned int channelOutput = 0; channelOutput < N; channelOutput++)
+ for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
{
float outval = 0.f;
@@ -39,22 +36,27 @@ void FullyConnected(const float* inputData,
float weight;
if (transposeWeights)
{
- weight = weightData[channelOutput * K + channelInput];
+ rWeightDecoder[channelOutput * K + channelInput];
+ weight = rWeightDecoder.Get();
}
else
{
- weight = weightData[channelInput * N + channelOutput];
+ rWeightDecoder[channelInput * outputSize + channelOutput];
+ weight = rWeightDecoder.Get();
}
- outval += weight * inputData[n * K + channelInput];
+ rInputDecoder[n * K + channelInput];
+ outval += weight * rInputDecoder.Get();
}
- if (biasData)
+ if (biasEnabled)
{
- outval += biasData[channelOutput];
+ rBiasDecoder[channelOutput];
+ outval += rBiasDecoder.Get();
}
- outputData[n * N + channelOutput] = outval;
+ rOutputEncoder[n * outputSize + channelOutput];
+ rOutputEncoder.Set(outval);
}
}
}