aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/FullyConnected.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/FullyConnected.cpp')
-rw-r--r--src/backends/reference/workloads/FullyConnected.cpp19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/FullyConnected.cpp b/src/backends/reference/workloads/FullyConnected.cpp
index 8016c1b628..61c8e88bce 100644
--- a/src/backends/reference/workloads/FullyConnected.cpp
+++ b/src/backends/reference/workloads/FullyConnected.cpp
@@ -14,6 +14,7 @@ void FullyConnected(const TensorShape& rInputShape,
Decoder<float>& rInputDecoder,
const TensorShape& rOutputShape,
Encoder<float>& rOutputEncoder,
+ const TensorShape& rWeightsShape,
Decoder<float>& rWeightDecoder,
Decoder<float>& rBiasDecoder,
const bool biasEnabled,
@@ -23,6 +24,12 @@ void FullyConnected(const TensorShape& rInputShape,
// Perform FullyConnected implementation
unsigned int outputSize = rOutputShape[1];
+ const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape.GetNumElements());
+ const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape.GetNumElements());
+ const std::vector<float> decodedBiases = biasEnabled ?
+ rBiasDecoder.DecodeTensor(outputSize) : std::vector<float>();
+
+
for (unsigned int n = 0; n < rInputShape[0]; n++)
{
for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
@@ -34,23 +41,19 @@ void FullyConnected(const TensorShape& rInputShape,
float weight;
if (transposeWeights)
{
- rWeightDecoder[channelOutput * K + channelInput];
- weight = rWeightDecoder.Get();
+ weight = decodedWeights[channelOutput * K + channelInput];
}
else
{
- rWeightDecoder[channelInput * outputSize + channelOutput];
- weight = rWeightDecoder.Get();
+ weight = decodedWeights[channelInput * outputSize + channelOutput];
}
- rInputDecoder[n * K + channelInput];
- outval += weight * rInputDecoder.Get();
+ outval += weight * decodedInputs[n * K + channelInput];
}
if (biasEnabled)
{
- rBiasDecoder[channelOutput];
- outval += rBiasDecoder.Get();
+ outval += decodedBiases[channelOutput];
}
rOutputEncoder[n * outputSize + channelOutput];