diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 21 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/Reduce.cpp | 4 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/Sum.cpp | 2 |
3 files changed, 15 insertions, 12 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 8ce1667557..ab32ef7822 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -3091,19 +3091,24 @@ void TfLiteParserImpl::ParseReduce(size_t subgraphIndex, size_t operatorIndex, R armnn::TensorInfo inputTensorInfo0 = ToTensorInfo(inputs[0]); armnn::TensorInfo inputTensorInfo1 = ToTensorInfo(inputs[1]); - TensorShape input0Shape = inputTensorInfo0.GetShape(); ReduceDescriptor desc; - BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); // Get const axis value from model and set it to descriptor. if (axisBufferPtr != nullptr) { - for (uint32_t i = 0; i < inputTensorInfo1.GetNumElements(); ++i) - { - desc.m_vAxis.push_back(armnnUtils::GetUnsignedAxis(inputTensorInfo0.GetNumDimensions(), - axisBufferPtr->data.data()[i])); - } + std::vector<int32_t> axisData(inputTensorInfo1.GetNumElements()); + ::memcpy(axisData.data(), axisBufferPtr->data.data(), inputTensorInfo1.GetNumBytes()); + + // Convert the axis to unsigned int and remove duplicates. + auto rank = static_cast<int32_t>(inputTensorInfo0.GetNumDimensions()); + std::set<unsigned int> uniqueAxis; + std::transform(axisData.begin(), + axisData.end(), + std::inserter(uniqueAxis, uniqueAxis.begin()), + [rank](int i)->unsigned int{ + return static_cast<uint32_t>(((i + rank) % rank)); }); + desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end()); } else { @@ -3113,8 +3118,6 @@ void TfLiteParserImpl::ParseReduce(size_t subgraphIndex, size_t operatorIndex, R } } - desc.m_TargetHeight = input0Shape[1]; - desc.m_TargetWidth = input0Shape[2]; desc.m_KeepDims = options->keep_dims; desc.m_ReduceOperation = reduceOperation; diff --git a/src/armnnTfLiteParser/test/Reduce.cpp b/src/armnnTfLiteParser/test/Reduce.cpp index 622d54e8b5..c2a22f0b86 100644 --- a/src/armnnTfLiteParser/test/Reduce.cpp +++ b/src/armnnTfLiteParser/test/Reduce.cpp @@ -90,7 +90,7 @@ struct ReduceMaxFixture : public ParserFlatbuffersFixture struct SimpleReduceMaxFixture : public ReduceMaxFixture { - SimpleReduceMaxFixture() : ReduceMaxFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2 ]") {} + SimpleReduceMaxFixture() : ReduceMaxFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2,0,0,0 ]") {} }; BOOST_FIXTURE_TEST_CASE(ParseReduceMax, SimpleReduceMaxFixture) @@ -179,7 +179,7 @@ struct ReduceMinFixture : public ParserFlatbuffersFixture struct SimpleReduceMinFixture : public ReduceMinFixture { - SimpleReduceMinFixture() : ReduceMinFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2 ]") {} + SimpleReduceMinFixture() : ReduceMinFixture("[ 1, 1, 2, 3 ]", "[ 1, 1, 1, 3 ]", "[ 1 ]", "[ 2, 0, 0, 0 ]") {} }; BOOST_FIXTURE_TEST_CASE(ParseReduceMin, SimpleReduceMinFixture) diff --git a/src/armnnTfLiteParser/test/Sum.cpp b/src/armnnTfLiteParser/test/Sum.cpp index 22b19ae058..177bcd52de 100644 --- a/src/armnnTfLiteParser/test/Sum.cpp +++ b/src/armnnTfLiteParser/test/Sum.cpp @@ -90,7 +90,7 @@ struct SumFixture : public ParserFlatbuffersFixture struct SimpleSumFixture : public SumFixture { - SimpleSumFixture() : SumFixture("[ 1, 3, 2, 4 ]", "[ 1, 1, 1, 4 ]", "[ 2 ]", "[ 1, 2 ]") {} + SimpleSumFixture() : SumFixture("[ 1, 3, 2, 4 ]", "[ 1, 1, 1, 4 ]", "[ 2 ]", "[ 1, 0, 0, 0, 2, 0, 0, 0 ]") {} }; BOOST_FIXTURE_TEST_CASE(ParseSum, SimpleSumFixture) |