diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-09-17 15:58:31 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2020-09-28 09:01:58 +0000 |
commit | b9dcfe63b87f024c6f8c5f4b68447de04119dc19 (patch) | |
tree | 0c58376c59190ecbc8df0dd2abedbf85983d5256 /src/backends/reference/workloads/FullyConnected.cpp | |
parent | be727becad9fe048480ab53a0281b46594f95ca7 (diff) | |
download | armnn-b9dcfe63b87f024c6f8c5f4b68447de04119dc19.tar.gz |
IVGCVSW-5325 Speed up the reference backend
Change-Id: Id8bd0a0418be31d975b944b54bbacb25051ffb2e
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r-- | src/backends/reference/workloads/FullyConnected.cpp | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp index 8016c1b628..61c8e88bce 100644 --- a/src/backends/reference/workloads/FullyConnected.cpp +++ b/src/backends/reference/workloads/FullyConnected.cpp @@ -14,6 +14,7 @@ void FullyConnected(const TensorShape& rInputShape, Decoder<float>& rInputDecoder, const TensorShape& rOutputShape, Encoder<float>& rOutputEncoder, + const TensorShape& rWeightsShape, Decoder<float>& rWeightDecoder, Decoder<float>& rBiasDecoder, const bool biasEnabled, @@ -23,6 +24,12 @@ void FullyConnected(const TensorShape& rInputShape, // Perform FullyConnected implementation unsigned int outputSize = rOutputShape[1]; + const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape.GetNumElements()); + const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape.GetNumElements()); + const std::vector<float> decodedBiases = biasEnabled ? + rBiasDecoder.DecodeTensor(outputSize) : std::vector<float>(); + + for (unsigned int n = 0; n < rInputShape[0]; n++) { for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++) @@ -34,23 +41,19 @@ void FullyConnected(const TensorShape& rInputShape, float weight; if (transposeWeights) { - rWeightDecoder[channelOutput * K + channelInput]; - weight = rWeightDecoder.Get(); + weight = decodedWeights[channelOutput * K + channelInput]; } else { - rWeightDecoder[channelInput * outputSize + channelOutput]; - weight = rWeightDecoder.Get(); + weight = decodedWeights[channelInput * outputSize + channelOutput]; } - rInputDecoder[n * K + channelInput]; - outval += weight * rInputDecoder.Get(); + outval += weight * decodedInputs[n * K + channelInput]; } if (biasEnabled) { - rBiasDecoder[channelOutput]; - outval += rBiasDecoder.Get(); + outval += decodedBiases[channelOutput]; } rOutputEncoder[n * outputSize + channelOutput]; |