diff options
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r-- | src/backends/reference/workloads/FullyConnected.cpp | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp index 8016c1b628..61c8e88bce 100644 --- a/src/backends/reference/workloads/FullyConnected.cpp +++ b/src/backends/reference/workloads/FullyConnected.cpp @@ -14,6 +14,7 @@ void FullyConnected(const TensorShape& rInputShape, Decoder<float>& rInputDecoder, const TensorShape& rOutputShape, Encoder<float>& rOutputEncoder, + const TensorShape& rWeightsShape, Decoder<float>& rWeightDecoder, Decoder<float>& rBiasDecoder, const bool biasEnabled, @@ -23,6 +24,12 @@ void FullyConnected(const TensorShape& rInputShape, // Perform FullyConnected implementation unsigned int outputSize = rOutputShape[1]; + const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape.GetNumElements()); + const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape.GetNumElements()); + const std::vector<float> decodedBiases = biasEnabled ? + rBiasDecoder.DecodeTensor(outputSize) : std::vector<float>(); + + for (unsigned int n = 0; n < rInputShape[0]; n++) { for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++) @@ -34,23 +41,19 @@ void FullyConnected(const TensorShape& rInputShape, float weight; if (transposeWeights) { - rWeightDecoder[channelOutput * K + channelInput]; - weight = rWeightDecoder.Get(); + weight = decodedWeights[channelOutput * K + channelInput]; } else { - rWeightDecoder[channelInput * outputSize + channelOutput]; - weight = rWeightDecoder.Get(); + weight = decodedWeights[channelInput * outputSize + channelOutput]; } - rInputDecoder[n * K + channelInput]; - outval += weight * rInputDecoder.Get(); + outval += weight * decodedInputs[n * K + channelInput]; } if (biasEnabled) { - rBiasDecoder[channelOutput]; - outval += rBiasDecoder.Get(); + outval += decodedBiases[channelOutput]; } rOutputEncoder[n * outputSize + channelOutput]; |