aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ConvImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ConvImpl.hpp')
-rw-r--r--src/backends/reference/workloads/ConvImpl.hpp21
1 files changed, 13 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp
index 4c9ab2a644..60a3622c55 100644
--- a/src/backends/reference/workloads/ConvImpl.hpp
+++ b/src/backends/reference/workloads/ConvImpl.hpp
@@ -63,21 +63,26 @@ static void ConvImpl(ConvData data,
throw InvalidArgumentException("Bias is enabled but the bias data is invalid");
}
- const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
+ const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
+ const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
+ const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
+ const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
+ const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
+
unsigned int depthMult = depthwise ? filterInfo.GetShape()[0] : 1;
- unsigned int channelsInput = filterInfo.GetShape()[1];
+ unsigned int channelsInput = filterInfo.GetShape()[channelsIndex];
unsigned int channelsOutput = depthwise ? channelsInput * depthMult : filterInfo.GetShape()[0];
unsigned int batchSize = outputInfo0.GetShape()[0];
- unsigned int heightOutput = outputInfo0.GetShape()[2];
- unsigned int widthOutput = outputInfo0.GetShape()[3];
- unsigned int heightInput = inputInfo0.GetShape()[2];
- unsigned int widthInput = inputInfo0.GetShape()[3];
+ unsigned int heightOutput = outputInfo0.GetShape()[heightIndex];
+ unsigned int widthOutput = outputInfo0.GetShape()[widthIndex];
+ unsigned int heightInput = inputInfo0.GetShape()[heightIndex];
+ unsigned int widthInput = inputInfo0.GetShape()[widthIndex];
- unsigned int heightFilter = filterInfo.GetShape()[2];
- unsigned int widthFilter = filterInfo.GetShape()[3];
+ unsigned int heightFilter = filterInfo.GetShape()[heightIndex];
+ unsigned int widthFilter = filterInfo.GetShape()[widthIndex];
unsigned int paddingTop = data.m_Parameters.m_PadTop;
unsigned int paddingLeft = data.m_Parameters.m_PadLeft;