aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/FullyConnected.cpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-09-17 15:58:31 +0100
committerfinn.williams <finn.williams@arm.com>2020-09-28 09:01:58 +0000
commitb9dcfe63b87f024c6f8c5f4b68447de04119dc19 (patch)
tree0c58376c59190ecbc8df0dd2abedbf85983d5256 /src/backends/reference/workloads/FullyConnected.cpp
parentbe727becad9fe048480ab53a0281b46594f95ca7 (diff)
downloadarmnn-b9dcfe63b87f024c6f8c5f4b68447de04119dc19.tar.gz
IVGCVSW-5325 Speed up the reference backend
Change-Id: Id8bd0a0418be31d975b944b54bbacb25051ffb2e Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r--src/backends/reference/workloads/FullyConnected.cpp19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index 8016c1b628..61c8e88bce 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -14,6 +14,7 @@ void FullyConnected(const TensorShape& rInputShape,
Decoder<float>& rInputDecoder,
const TensorShape& rOutputShape,
Encoder<float>& rOutputEncoder,
+ const TensorShape& rWeightsShape,
Decoder<float>& rWeightDecoder,
Decoder<float>& rBiasDecoder,
const bool biasEnabled,
@@ -23,6 +24,12 @@ void FullyConnected(const TensorShape& rInputShape,
// Perform FullyConnected implementation
unsigned int outputSize = rOutputShape[1];
+ const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape.GetNumElements());
+ const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape.GetNumElements());
+ const std::vector<float> decodedBiases = biasEnabled ?
+ rBiasDecoder.DecodeTensor(outputSize) : std::vector<float>();
+
+
for (unsigned int n = 0; n < rInputShape[0]; n++)
{
for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
@@ -34,23 +41,19 @@ void FullyConnected(const TensorShape& rInputShape,
float weight;
if (transposeWeights)
{
- rWeightDecoder[channelOutput * K + channelInput];
- weight = rWeightDecoder.Get();
+ weight = decodedWeights[channelOutput * K + channelInput];
}
else
{
- rWeightDecoder[channelInput * outputSize + channelOutput];
- weight = rWeightDecoder.Get();
+ weight = decodedWeights[channelInput * outputSize + channelOutput];
}
- rInputDecoder[n * K + channelInput];
- outval += weight * rInputDecoder.Get();
+ outval += weight * decodedInputs[n * K + channelInput];
}
if (biasEnabled)
{
- rBiasDecoder[channelOutput];
- outval += rBiasDecoder.Get();
+ outval += decodedBiases[channelOutput];
}
rOutputEncoder[n * outputSize + channelOutput];