aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2018-11-19 13:19:28 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-11-19 16:14:07 +0000
commit8800c00770ed14eb48045cfcf033d6b67595a126 (patch)
tree3bdbd3a97bfc21276a98a14aeae3e878c96c3121 /src/backends/backendsCommon/WorkloadData.cpp
parent5cdda351b4e12c5299173ec6b0fc75a948bdcda0 (diff)
downloadarmnn-8800c00770ed14eb48045cfcf033d6b67595a126.tar.gz
IVGCVSW-2169 Remove DataLayoutIndexed from public API
Change-Id: If8d8087d9d365e467d3ca9bf9c40d7219cb75cfd
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 18ab4a8709..d5e3638a06 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -6,6 +6,8 @@
#include "CpuTensorHandle.hpp"
+#include <backendsCommon/DataLayoutIndexed.hpp>
+
#include <algorithm>
#include <iomanip>
#include <string>
@@ -675,10 +677,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
}
{
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
const unsigned int inputChannelCount =
- workloadInfo.m_InputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()];
+ workloadInfo.m_InputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()];
const unsigned int outputChannelCount =
- workloadInfo.m_OutputTensorInfos[0].GetShape()[this->m_Parameters.m_DataLayout.GetChannelsIndex()];
+ workloadInfo.m_OutputTensorInfos[0].GetShape()[dimensionIndices.GetChannelsIndex()];
if (inputChannelCount != outputChannelCount)
{
throw InvalidArgumentException(
@@ -774,14 +777,15 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
- unsigned int inputHeight = inputShape[m_Parameters.m_DataLayout.GetHeightIndex()]
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
+ unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()]
+ heightPad.first + heightPad.second;
- unsigned int inputWidth = inputShape[m_Parameters.m_DataLayout.GetWidthIndex()]
+ unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()]
+ widthPad.first + widthPad.second;
unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth
- * inputShape[m_Parameters.m_DataLayout.GetChannelsIndex()];
+ * inputShape[dimensionIndices.GetChannelsIndex()];
if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements)
{