aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ConvImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ConvImpl.cpp')
-rw-r--r-- src/backends/reference/workloads/ConvImpl.cpp 44
1 files changed, 34 insertions, 10 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp
index 801a29af1a..92e3b2d7dd 100644
--- a/src/backends/reference/workloads/ConvImpl.cpp
+++ b/src/backends/reference/workloads/ConvImpl.cpp
@@ -147,11 +147,22 @@ void Convolve(const TensorShape& rInputShape,
}
else
{
- filterIndex = dataLayoutIndexed.GetIndex(rFilterShape,
- cOutput,
- cInput,
- yFilter,
- xFilter);
+ // Keep this hand-written filter index computation: using DataLayoutIndexed::GetIndex
+ // here causes a significant performance regression.
+ if (dataLayout == DataLayout::NHWC)
+ {
+ filterIndex = cOutput * filterHeight * filterWidth * inputChannels +
+ yFilter * filterWidth * inputChannels +
+ xFilter * inputChannels +
+ cInput;
+ }
+ else
+ {
+ filterIndex = cOutput * filterWidth * filterHeight * inputChannels +
+ cInput * filterWidth * filterHeight +
+ yFilter * filterWidth +
+ xFilter;
+ }
}
rFilterDecoder[filterIndex];
@@ -170,11 +181,24 @@ void Convolve(const TensorShape& rInputShape,
}
else
{
- unsigned int inputIndex = dataLayoutIndexed.GetIndex(rInputShape,
- batchIdx,
- cInput,
- yInput - paddingTop,
- xInput - paddingLeft);
+ unsigned int inputIndex = 0;
+
+ // Keep this hand-written input index computation: using DataLayoutIndexed::GetIndex
+ // here causes a significant performance regression.
+ if (dataLayout == DataLayout::NHWC)
+ {
+ inputIndex = batchIdx * inputHeight * inputWidth * inputChannels +
+ (yInput - paddingTop) * inputWidth * inputChannels +
+ (xInput - paddingLeft) * inputChannels +
+ cInput;
+ }
+ else
+ {
+ inputIndex = batchIdx * inputWidth * inputHeight * inputChannels +
+ inputWidth * inputHeight * cInput +
+ inputWidth * (yInput - paddingTop) +
+ xInput - paddingLeft;
+ }
rInputDecoder[inputIndex];
inputValue = rInputDecoder.Get();