diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 18ab4a8709..d5e3638a06 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -6,6 +6,8 @@ #include "CpuTensorHandle.hpp" +#include <backendsCommon/DataLayoutIndexed.hpp> + #include <algorithm> #include <iomanip> #include <string> @@ -675,10 +677,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c } { + DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout); const unsigned int inputChannelCount = - workloadInfo.m_InputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()]; + workloadInfo.m_InputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()]; const unsigned int outputChannelCount = - workloadInfo.m_OutputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()]; + workloadInfo.m_OutputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()]; if (inputChannelCount != outputChannelCount) { throw InvalidArgumentException( @@ -774,14 +777,15 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0]; std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1]; - unsigned int inputHeight = inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout); + unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] + heightPad.first + heightPad.second; - unsigned int inputWidth = inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] + widthPad.first + widthPad.second; unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth - * inputShape[m_Parameters.m_DataLayout.GetChannelsIndex()]; + * inputShape[dimensionIndices.GetChannelsIndex()]; if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements) { |