aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/reference/workloads/ConvImpl.cpp78
1 files changed, 19 insertions, 59 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp
index 6a5ac535e4..801a29af1a 100644
--- a/src/backends/reference/workloads/ConvImpl.cpp
+++ b/src/backends/reference/workloads/ConvImpl.cpp
@@ -68,26 +68,6 @@ int32_t QuantizedMultiplierSmallerThanOne::RoundingDivideByPOT(int32_t x, int ex
return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
-inline unsigned int GetOffset(DataLayout& dataLayout, const TensorShape& shape, unsigned int b, unsigned int c,
- unsigned int h, unsigned int w)
-{
- switch (dataLayout)
- {
- case DataLayout::NHWC:
- b *= shape[1] * shape[2] * shape[3];
- h *= shape[2] * shape[3];
- w *= shape[3];
- break;
- case DataLayout::NCHW:
- default:
- b *= shape[1] * shape[2] * shape[3];
- c *= shape[2] * shape[3];
- h *= shape[3];
- break;
- }
- return b + c + h + w;
-}
-
void Convolve(const TensorShape& rInputShape,
Decoder<float>& rInputDecoder,
const TensorShape& rOutputShape,
@@ -167,24 +147,15 @@ void Convolve(const TensorShape& rInputShape,
}
else
{
- if (dataLayout == DataLayout::NHWC)
- {
- filterIndex = cOutput * filterHeight * filterWidth * inputChannels +
- yFilter * filterWidth * inputChannels +
- xFilter * inputChannels +
- cInput;
- }
- else
- {
- filterIndex = cOutput * filterWidth * filterHeight * inputChannels +
- cInput * filterWidth * filterHeight +
- yFilter * filterWidth +
- xFilter;
- }
+ filterIndex = dataLayoutIndexed.GetIndex(rFilterShape,
+ cOutput,
+ cInput,
+ yFilter,
+ xFilter);
}
- rFilterDecoder += filterIndex;
+
+ rFilterDecoder[filterIndex];
float filterValue = rFilterDecoder.Get();
- rFilterDecoder -= filterIndex;
unsigned int yInput = yOutput * yStride + yFilter * yDilation;
unsigned int xInput = xOutput * xStride + xFilter * xDilation;
@@ -199,26 +170,16 @@ void Convolve(const TensorShape& rInputShape,
}
else
{
- unsigned int inputIndex;
-
- if (dataLayout == DataLayout::NHWC)
- {
- inputIndex = batchIdx * inputHeight * inputWidth * inputChannels +
- (yInput - paddingTop) * inputWidth * inputChannels +
- (xInput - paddingLeft) * inputChannels +
- cInput;
- }
- else
- {
- inputIndex = batchIdx * inputWidth * inputHeight * inputChannels +
- inputWidth * inputHeight * cInput +
- inputWidth * (yInput - paddingTop) +
- xInput - paddingLeft;
- }
- rInputDecoder += inputIndex;
+ unsigned int inputIndex = dataLayoutIndexed.GetIndex(rInputShape,
+ batchIdx,
+ cInput,
+ yInput - paddingTop,
+ xInput - paddingLeft);
+
+ rInputDecoder[inputIndex];
inputValue = rInputDecoder.Get();
- rInputDecoder -= inputIndex;
}
+
sum += filterValue * inputValue;
}
}
@@ -226,15 +187,14 @@ void Convolve(const TensorShape& rInputShape,
if (biasEnabled)
{
- *pBiasDecoder += cOutput;
+ (*pBiasDecoder)[cOutput];
sum += pBiasDecoder->Get();
- *pBiasDecoder -= cOutput;
}
- unsigned int outIdx = GetOffset(dataLayout, rOutputShape, batchIdx, cOutput, yOutput, xOutput);
- rOutputEncoder += outIdx;
+ unsigned int outIdx = dataLayoutIndexed.GetIndex(rOutputShape, batchIdx, cOutput, yOutput, xOutput);
+
+ rOutputEncoder[outIdx];
rOutputEncoder.Set(sum);
- rOutputEncoder -= outIdx;
}
}
}