diff options
Diffstat (limited to 'src/backends/reference/workloads/ConvImpl.hpp')
-rw-r--r-- | src/backends/reference/workloads/ConvImpl.hpp | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp index 4c9ab2a644..60a3622c55 100644 --- a/src/backends/reference/workloads/ConvImpl.hpp +++ b/src/backends/reference/workloads/ConvImpl.hpp @@ -63,21 +63,26 @@ static void ConvImpl(ConvData data, throw InvalidArgumentException("Bias is enabled but the bias data is invalid"); } - const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]); + const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout); + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex(); + const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex(); + unsigned int depthMult = depthwise ? filterInfo.GetShape()[0] : 1; - unsigned int channelsInput = filterInfo.GetShape()[1]; + unsigned int channelsInput = filterInfo.GetShape()[channelsIndex]; unsigned int channelsOutput = depthwise ? channelsInput * depthMult : filterInfo.GetShape()[0]; unsigned int batchSize = outputInfo0.GetShape()[0]; - unsigned int heightOutput = outputInfo0.GetShape()[2]; - unsigned int widthOutput = outputInfo0.GetShape()[3]; - unsigned int heightInput = inputInfo0.GetShape()[2]; - unsigned int widthInput = inputInfo0.GetShape()[3]; + unsigned int heightOutput = outputInfo0.GetShape()[heightIndex]; + unsigned int widthOutput = outputInfo0.GetShape()[widthIndex]; + unsigned int heightInput = inputInfo0.GetShape()[heightIndex]; + unsigned int widthInput = inputInfo0.GetShape()[widthIndex]; - unsigned int heightFilter = filterInfo.GetShape()[2]; - unsigned int widthFilter = filterInfo.GetShape()[3]; + unsigned int heightFilter = filterInfo.GetShape()[heightIndex]; + unsigned int widthFilter = filterInfo.GetShape()[widthIndex]; unsigned int paddingTop = data.m_Parameters.m_PadTop; unsigned int paddingLeft = data.m_Parameters.m_PadLeft; |