aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/reference/workloads/BatchNormImpl.cpp18
1 files changed, 1 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/BatchNormImpl.cpp b/src/backends/reference/workloads/BatchNormImpl.cpp
index 36e96d3fec..b80af8c937 100644
--- a/src/backends/reference/workloads/BatchNormImpl.cpp
+++ b/src/backends/reference/workloads/BatchNormImpl.cpp
@@ -53,23 +53,7 @@ void BatchNormImpl(const BatchNormalizationQueueDescriptor& data,
{
for (unsigned int w = 0; w < inputWidth; w++)
{
- unsigned int index = 0;
-
- if (dataLayout == DataLayout::NHWC)
- {
- index = n * inputHeight * inputWidth * inputChannels +
- h * inputWidth * inputChannels +
- w * inputChannels +
- c;
- }
- else // dataLayout == DataLayout::NCHW
- {
- index = n * inputHeight * inputWidth * inputChannels +
- c * inputHeight * inputWidth +
- h * inputWidth +
- w;
- }
-
+ unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
inputDecoder[index];
outputEncoder[index];
outputEncoder.Set(mult * inputDecoder.Get() + add);