aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/TransposeConvolution2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/TransposeConvolution2d.cpp')
-rw-r--r--src/backends/reference/workloads/TransposeConvolution2d.cpp90
1 files changed, 62 insertions, 28 deletions
diff --git a/src/backends/reference/workloads/TransposeConvolution2d.cpp b/src/backends/reference/workloads/TransposeConvolution2d.cpp
index 5698014181..c34a309806 100644
--- a/src/backends/reference/workloads/TransposeConvolution2d.cpp
+++ b/src/backends/reference/workloads/TransposeConvolution2d.cpp
@@ -30,27 +30,35 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
- unsigned int numBatches = inputShape[0];
+ const unsigned int numBatches = inputShape[0];
- unsigned int inputWidth = inputShape[widthIndex];
- unsigned int inputHeight = inputShape[heightIndex];
- unsigned int inputDepth = inputShape[channelsIndex];
+ const unsigned int inputWidth = inputShape[widthIndex];
+ const unsigned int inputHeight = inputShape[heightIndex];
+ const unsigned int inputDepth = inputShape[channelsIndex];
- unsigned int weightsHeight = weightsShape[heightIndex];
- unsigned int weightsWidth = weightsShape[widthIndex];
+ const unsigned int weightsHeight = weightsShape[heightIndex];
+ const unsigned int weightsWidth = weightsShape[widthIndex];
+ const unsigned int weightsDepth = weightsShape[channelsIndex];
- unsigned int outputHeight = outputShape[heightIndex];
- unsigned int outputWidth = outputShape[widthIndex];
- unsigned int outputDepth = outputShape[channelsIndex];
+ const unsigned int outputHeight = outputShape[heightIndex];
+ const unsigned int outputWidth = outputShape[widthIndex];
+ const unsigned int outputDepth = outputShape[channelsIndex];
- unsigned int paddingLeft = descriptor.m_PadLeft;
- unsigned int paddingTop = descriptor.m_PadTop;
+ const unsigned int paddingLeft = descriptor.m_PadLeft;
+ const unsigned int paddingTop = descriptor.m_PadTop;
- unsigned int strideX = descriptor.m_StrideX;
- unsigned int strideY = descriptor.m_StrideY;
+ const unsigned int strideX = descriptor.m_StrideX;
+ const unsigned int strideY = descriptor.m_StrideY;
std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
+ const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape.GetNumElements());
+
+ const unsigned channelStep = weightsWidth * weightsHeight * weightsDepth;
+
+ const std::vector<float> filterVec =
+ weightsDecoder.DecodeTensor(weightsShape.GetNumElements(), channelStep);
+
for (unsigned int batch = 0u; batch < numBatches; ++batch)
{
for (unsigned int yInput = 0u; yInput < inputHeight; ++yInput)
@@ -73,25 +81,51 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
{
for (unsigned int dInput = 0u; dInput < inputDepth; dInput++)
{
- const unsigned int inputIndex =
- dataLayoutIndexed.GetIndex(inputShape, batch, dInput, yInput, xInput);
- inputDecoder[inputIndex];
-
- const unsigned int weightsIndex =
- dataLayoutIndexed.GetIndex(weightsShape, dOutput, dInput, yWeights, xWeights);
- weightsDecoder.SetIndex(weightsIndex, dOutput);
-
- const unsigned int outputIndex =
- dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
- outputEncoder[outputIndex];
-
- float output = outputBuffer[outputIndex];
- output += inputDecoder.Get() * weightsDecoder.Get();
- outputBuffer[outputIndex] = output;
+ unsigned int inputIndex;
+ unsigned int outputIndex;
+ unsigned int weightsIndex;
+
+ if(descriptor.m_DataLayout == armnn::DataLayout::NHWC)
+ {
+ inputIndex = batch * inputHeight * inputWidth * inputDepth +
+ yInput * inputWidth * inputDepth +
+ xInput * inputDepth +
+ dInput;
+
+ weightsIndex = dOutput * weightsHeight * weightsWidth * weightsDepth +
+ yWeights * weightsWidth * weightsDepth +
+ xWeights * weightsDepth +
+ dInput;
+
+ outputIndex = batch * outputHeight * outputWidth * outputDepth +
+ yOutput * outputWidth * outputDepth +
+ xOutput * outputDepth +
+ dOutput;
+ }
+ else
+ {
+ inputIndex = batch * inputDepth * inputHeight * inputWidth +
+ dInput * inputHeight * inputWidth +
+ yInput * inputWidth +
+ xInput;
+
+ weightsIndex = dOutput * weightsDepth * weightsHeight * weightsWidth +
+ dInput * weightsHeight * weightsWidth +
+ yWeights * weightsWidth +
+ xWeights;
+
+ outputIndex = batch * outputDepth * outputHeight * outputWidth +
+ dOutput * outputHeight * outputWidth +
+ yOutput * outputWidth +
+ xOutput;
+ }
+
+ outputBuffer[outputIndex] += inputVec[inputIndex] * filterVec[weightsIndex];
}
}
}
}
+
}
}
}