diff options
author | Ryan OShea <Ryan.OShea2@arm.com> | 2020-05-26 11:41:04 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-06-02 16:34:30 +0000 |
commit | 86704734edfd7f57a4339d4afcff58ad31e8ac35 (patch) | |
tree | cf45f9eec235759f8ff5341b012aa11a6f6cc7db /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | bc873d2dac4666a86e4844985199dfd90e67be5b (diff) | |
download | armnn-86704734edfd7f57a4339d4afcff58ad31e8ac35.tar.gz |
IVGCVSW-4190 Add SplitV to Tflite Parser
* Refactored SplitV
* Added unit tests
* Updated Documentation
Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Change-Id: If1dfa5a8780ddf3fe8788ed7bf7fa5fa8dfd14ec
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 93 |
1 files changed, 47 insertions, 46 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 53b49f48d0..c695caa280 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2683,7 +2683,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; - + const auto * options = operatorPtr->builtin_options.AsSplitVOptions(); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 3); @@ -2717,66 +2717,67 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes()); const unsigned int splitDim = ComputeWrappedIndex(axisData[0], inputTensorInfo.GetNumDimensions()); - // Set split sizes - const auto * options = operatorPtr->builtin_options.AsSplitOptions(); CHECK_VALID_SIZE(splitsInfo.GetNumDimensions(), 1); - unsigned int numSplits = 0; std::vector<int> splitsData(0); - if (options) + unsigned int numSplits{0}; + + if(options) { numSplits = CHECKED_NON_NEGATIVE(options->num_splits); - splitsData.resize(numSplits); - - if (inputTensorInfo.GetShape()[splitDim] % numSplits != 0) - { - throw ParseException("Number of splits must evenly divide the split axis"); - } - unsigned int splitSize = inputTensorInfo.GetShape()[splitDim] / numSplits; - for (auto& split : splitsData) - { - split = numeric_cast<int>(splitSize); - } } else { - numSplits = splitsInfo.GetShape()[0]; - splitsData.resize(numSplits); + numSplits = splitsInfo.GetNumElements(); + } + + if (numSplits <=0) + { + throw ParseException("SplitV has invalid number of splits"); + } - BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer); - ::memcpy(splitsData.data(), splitsBufferPtr->data.data(), splitsInfo.GetNumBytes()); + splitsData.resize(numSplits); + BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer); + unsigned int idx{0}; - int numInferred = 0; - int specifiedSizes = 0; - unsigned int inferIdx = 0; - unsigned int idx = 0; - for (auto split : splitsData) + for(auto& split: splitsData) + { + split = splitsBufferPtr->data[idx]; + idx++; + } + + idx = 0; + int numInferred{0}; + unsigned int inferIdx{0}; + int splitSum{0}; + for (auto split : splitsData) + { + if (split < 0) { - if (split < 0) - { - numInferred++; - inferIdx = idx; - } - else - { - specifiedSizes += split; - } - idx++; + numInferred++; + inferIdx = idx; } - - if (numInferred > 0) + else { - if (numInferred > 1) - { - throw ParseException("Cannot infer split size for more than one split"); - } - splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - specifiedSizes; + splitSum += split; } + idx++; } - - if (numSplits <=0) + // Check for inferred Axis + if (numInferred == 0) { - throw ParseException("SplitV has invalid number of splits"); + if (splitSum != numeric_cast<int>(inputTensorInfo.GetShape()[splitDim])) + { + throw ParseException("SplitV split_sizes does not sum to the dimension of value along split_dim."); + } + } + else if (numInferred == 1) + { + splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum; + } + else + { + throw ParseException("Cannot infer split size for more than one split"); } //Ouput size validation @@ -2805,7 +2806,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) accumSplit += splitSize; } - auto layerName = boost::str(boost::format("Split:%1%:%2%") % subgraphIndex % operatorIndex); + auto layerName = boost::str(boost::format("SplitV:%1%:%2%") % subgraphIndex % operatorIndex); IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str()); auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); |