aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 18ab4a8709..d5e3638a06 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -6,6 +6,8 @@
#include "CpuTensorHandle.hpp"
+#include <backendsCommon/DataLayoutIndexed.hpp>
+
#include <algorithm>
#include <iomanip>
#include <string>
@@ -675,10 +677,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
}
{
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
const unsigned int inputChannelCount =
- workloadInfo.m_InputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()];
+ workloadInfo.m_InputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()];
const unsigned int outputChannelCount =
- workloadInfo.m_OutputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()];
+ workloadInfo.m_OutputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()];
if (inputChannelCount != outputChannelCount)
{
throw InvalidArgumentException(
@@ -774,14 +777,15 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
- unsigned int inputHeight = inputShape[m_Parameters.m_DataLayout.GetHeightIndex()]
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
+ unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()]
+ heightPad.first + heightPad.second;
- unsigned int inputWidth = inputShape[m_Parameters.m_DataLayout.GetWidthIndex()]
+ unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()]
+ widthPad.first + widthPad.second;
unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth
- * inputShape[m_Parameters.m_DataLayout.GetChannelsIndex()];
+ * inputShape[dimensionIndices.GetChannelsIndex()];
if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements)
{