aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp21
-rw-r--r--src/armnnTfLiteParser/test/Reduce.cpp4
-rw-r--r--src/armnnTfLiteParser/test/Sum.cpp2
3 files changed, 15 insertions, 12 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8ce1667557..ab32ef7822 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -3091,19 +3091,24 @@ void TfLiteParserImpl::ParseReduce(size_t subgraphIndex, size_t operatorIndex, R
armnn::TensorInfo inputTensorInfo0 = ToTensorInfo(inputs[0]);
armnn::TensorInfo inputTensorInfo1 = ToTensorInfo(inputs[1]);
- TensorShape input0Shape = inputTensorInfo0.GetShape();
ReduceDescriptor desc;
-
BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
// Get const axis value from model and set it to descriptor.
if (axisBufferPtr != nullptr)
{
- for (uint32_t i = 0; i < inputTensorInfo1.GetNumElements(); ++i)
- {
- desc.m_vAxis.push_back(armnnUtils::GetUnsignedAxis(inputTensorInfo0.GetNumDimensions(),
- axisBufferPtr->data.data()[i]));
- }
+ std::vector<int32_t> axisData(inputTensorInfo1.GetNumElements());
+ ::memcpy(axisData.data(), axisBufferPtr->data.data(), inputTensorInfo1.GetNumBytes());
+
+ // Convert the axis to unsigned int and remove duplicates.
+ auto rank = static_cast<int32_t>(inputTensorInfo0.GetNumDimensions());
+ std::set<unsigned int> uniqueAxis;
+ std::transform(axisData.begin(),
+ axisData.end(),
+ std::inserter(uniqueAxis, uniqueAxis.begin()),
+ [rank](int i)->unsigned int{
+ return static_cast<uint32_t>(((i + rank) % rank)); });
+ desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
}
else
{
@@ -3113,8 +3118,6 @@ void TfLiteParserImpl::ParseReduce(size_t subgraphIndex, size_t operatorIndex, R
}
}
- desc.m_TargetHeight = input0Shape[1];
- desc.m_TargetWidth = input0Shape[2];
desc.m_KeepDims = options->keep_dims;
desc.m_ReduceOperation = reduceOperation;
diff --git a/src/armnnTfLiteParser/test/Reduce.cpp b/src/armnnTfLiteParser/test/Reduce.cpp
index 622d54e8b5..c2a22f0b86 100644
--- a/src/armnnTfLiteParser/test/Reduce.cpp
+++ b/src/armnnTfLiteParser/test/Reduce.cpp
@@ -90,7 +90,7 @@ struct ReduceMaxFixture : public ParserFlatbuffersFixture
struct SimpleReduceMaxFixture : public ReduceMaxFixture
{
- SimpleReduceMaxFixture() : ReduceMaxFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2 ]") {}
+ SimpleReduceMaxFixture() : ReduceMaxFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2,0,0,0 ]") {}
};
BOOST_FIXTURE_TEST_CASE(ParseReduceMax, SimpleReduceMaxFixture)
@@ -179,7 +179,7 @@ struct ReduceMinFixture : public ParserFlatbuffersFixture
struct SimpleReduceMinFixture : public ReduceMinFixture
{
- SimpleReduceMinFixture() : ReduceMinFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2 ]") {}
+ SimpleReduceMinFixture() : ReduceMinFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2, 0, 0, 0 ]") {}
};
BOOST_FIXTURE_TEST_CASE(ParseReduceMin, SimpleReduceMinFixture)
diff --git a/src/armnnTfLiteParser/test/Sum.cpp b/src/armnnTfLiteParser/test/Sum.cpp
index 22b19ae058..177bcd52de 100644
--- a/src/armnnTfLiteParser/test/Sum.cpp
+++ b/src/armnnTfLiteParser/test/Sum.cpp
@@ -90,7 +90,7 @@ struct SumFixture : public ParserFlatbuffersFixture
struct SimpleSumFixture : public SumFixture
{
- SimpleSumFixture() : SumFixture("[ 1, 3, 2, 4 ]", "[ 1, 1, 1, 4 ]", "[ 2 ]", "[ 1, 2 ]") {}
+ SimpleSumFixture() : SumFixture("[ 1, 3, 2, 4 ]", "[ 1, 1, 1, 4 ]", "[ 2 ]", "[ 1, 0, 0, 0, 2, 0, 0, 0 ]") {}
};
BOOST_FIXTURE_TEST_CASE(ParseSum, SimpleSumFixture)