about | summary | refs | log | tree | commit | diff
path: root/src/backends/reference/workloads/TransposeConvolution2d.cpp
diff options
context:
space:
mode:
author Finn Williams <Finn.Williams@arm.com> 2020-09-17 15:58:31 +0100
committer finn.williams <finn.williams@arm.com> 2020-09-28 09:01:58 +0000
commit b9dcfe63b87f024c6f8c5f4b68447de04119dc19 (patch)
tree 0c58376c59190ecbc8df0dd2abedbf85983d5256 /src/backends/reference/workloads/TransposeConvolution2d.cpp
parent be727becad9fe048480ab53a0281b46594f95ca7 (diff)
download armnn-b9dcfe63b87f024c6f8c5f4b68447de04119dc19.tar.gz
IVGCVSW-5325 Speed up the reference backend
Change-Id: Id8bd0a0418be31d975b944b54bbacb25051ffb2e
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/TransposeConvolution2d.cpp')
-rw-r--r-- src/backends/reference/workloads/TransposeConvolution2d.cpp 90
1 files changed, 62 insertions, 28 deletions
diff --git a/src/backends/reference/workloads/TransposeConvolution2d.cpp b/src/backends/reference/workloads/TransposeConvolution2d.cpp
index 5698014181..c34a309806 100644
--- a/src/backends/reference/workloads/TransposeConvolution2d.cpp
+++ b/src/backends/reference/workloads/TransposeConvolution2d.cpp
@@ -30,27 +30,35 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
- unsigned int numBatches = inputShape[0];
+ const unsigned int numBatches = inputShape[0];
- unsigned int inputWidth = inputShape[widthIndex];
- unsigned int inputHeight = inputShape[heightIndex];
- unsigned int inputDepth = inputShape[channelsIndex];
+ const unsigned int inputWidth = inputShape[widthIndex];
+ const unsigned int inputHeight = inputShape[heightIndex];
+ const unsigned int inputDepth = inputShape[channelsIndex];
- unsigned int weightsHeight = weightsShape[heightIndex];
- unsigned int weightsWidth = weightsShape[widthIndex];
+ const unsigned int weightsHeight = weightsShape[heightIndex];
+ const unsigned int weightsWidth = weightsShape[widthIndex];
+ const unsigned int weightsDepth = weightsShape[channelsIndex];
- unsigned int outputHeight = outputShape[heightIndex];
- unsigned int outputWidth = outputShape[widthIndex];
- unsigned int outputDepth = outputShape[channelsIndex];
+ const unsigned int outputHeight = outputShape[heightIndex];
+ const unsigned int outputWidth = outputShape[widthIndex];
+ const unsigned int outputDepth = outputShape[channelsIndex];
- unsigned int paddingLeft = descriptor.m_PadLeft;
- unsigned int paddingTop = descriptor.m_PadTop;
+ const unsigned int paddingLeft = descriptor.m_PadLeft;
+ const unsigned int paddingTop = descriptor.m_PadTop;
- unsigned int strideX = descriptor.m_StrideX;
- unsigned int strideY = descriptor.m_StrideY;
+ const unsigned int strideX = descriptor.m_StrideX;
+ const unsigned int strideY = descriptor.m_StrideY;
std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
+ const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape.GetNumElements());
+
+ const unsigned channelStep = weightsWidth * weightsHeight * weightsDepth;
+
+ const std::vector<float> filterVec =
+ weightsDecoder.DecodeTensor(weightsShape.GetNumElements(), channelStep);
+
for (unsigned int batch = 0u; batch < numBatches; ++batch)
{
for (unsigned int yInput = 0u; yInput < inputHeight; ++yInput)
@@ -73,25 +81,51 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
{
for (unsigned int dInput = 0u; dInput < inputDepth; dInput++)
{
- const unsigned int inputIndex =
- dataLayoutIndexed.GetIndex(inputShape, batch, dInput, yInput, xInput);
- inputDecoder[inputIndex];
-
- const unsigned int weightsIndex =
- dataLayoutIndexed.GetIndex(weightsShape, dOutput, dInput, yWeights, xWeights);
- weightsDecoder.SetIndex(weightsIndex, dOutput);
-
- const unsigned int outputIndex =
- dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
- outputEncoder[outputIndex];
-
- float output = outputBuffer[outputIndex];
- output += inputDecoder.Get() * weightsDecoder.Get();
- outputBuffer[outputIndex] = output;
+ unsigned int inputIndex;
+ unsigned int outputIndex;
+ unsigned int weightsIndex;
+
+ if(descriptor.m_DataLayout == armnn::DataLayout::NHWC)
+ {
+ inputIndex = batch * inputHeight * inputWidth * inputDepth +
+ yInput * inputWidth * inputDepth +
+ xInput * inputDepth +
+ dInput;
+
+ weightsIndex = dOutput * weightsHeight * weightsWidth * weightsDepth +
+ yWeights * weightsWidth * weightsDepth +
+ xWeights * weightsDepth +
+ dInput;
+
+ outputIndex = batch * outputHeight * outputWidth * outputDepth +
+ yOutput * outputWidth * outputDepth +
+ xOutput * outputDepth +
+ dOutput;
+ }
+ else
+ {
+ inputIndex = batch * inputDepth * inputHeight * inputWidth +
+ dInput * inputHeight * inputWidth +
+ yInput * inputWidth +
+ xInput;
+
+ weightsIndex = dOutput * weightsDepth * weightsHeight * weightsWidth +
+ dInput * weightsHeight * weightsWidth +
+ yWeights * weightsWidth +
+ xWeights;
+
+ outputIndex = batch * outputDepth * outputHeight * outputWidth +
+ dOutput * outputHeight * outputWidth +
+ yOutput * outputWidth +
+ xOutput;
+ }
+
+ outputBuffer[outputIndex] += inputVec[inputIndex] * filterVec[weightsIndex];
}
}
}
}
+
}
}
}