aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorRyan OShea <Ryan.OShea2@arm.com>2020-05-26 11:41:04 +0100
committerJim Flynn <jim.flynn@arm.com>2020-06-02 16:34:30 +0000
commit86704734edfd7f57a4339d4afcff58ad31e8ac35 (patch)
treecf45f9eec235759f8ff5341b012aa11a6f6cc7db /src/armnnTfLiteParser/TfLiteParser.cpp
parentbc873d2dac4666a86e4844985199dfd90e67be5b (diff)
downloadarmnn-86704734edfd7f57a4339d4afcff58ad31e8ac35.tar.gz
IVGCVSW-4190 Add SplitV to Tflite Parser
* Refactored SplitV * Added unit tests * Updated Documentation Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com> Change-Id: If1dfa5a8780ddf3fe8788ed7bf7fa5fa8dfd14ec
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp93
1 files changed, 47 insertions, 46 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 53b49f48d0..c695caa280 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -2683,7 +2683,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
-
+ const auto * options = operatorPtr->builtin_options.AsSplitVOptions();
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(inputs.size(), 3);
@@ -2717,66 +2717,67 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
const unsigned int splitDim = ComputeWrappedIndex(axisData[0], inputTensorInfo.GetNumDimensions());
-
// Set split sizes
- const auto * options = operatorPtr->builtin_options.AsSplitOptions();
CHECK_VALID_SIZE(splitsInfo.GetNumDimensions(), 1);
- unsigned int numSplits = 0;
std::vector<int> splitsData(0);
- if (options)
+ unsigned int numSplits{0};
+
+ if(options)
{
numSplits = CHECKED_NON_NEGATIVE(options->num_splits);
- splitsData.resize(numSplits);
-
- if (inputTensorInfo.GetShape()[splitDim] % numSplits != 0)
- {
- throw ParseException("Number of splits must evenly divide the split axis");
- }
- unsigned int splitSize = inputTensorInfo.GetShape()[splitDim] / numSplits;
- for (auto& split : splitsData)
- {
- split = numeric_cast<int>(splitSize);
- }
}
else
{
- numSplits = splitsInfo.GetShape()[0];
- splitsData.resize(numSplits);
+ numSplits = splitsInfo.GetNumElements();
+ }
+
+ if (numSplits <=0)
+ {
+ throw ParseException("SplitV has invalid number of splits");
+ }
- BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer);
- ::memcpy(splitsData.data(), splitsBufferPtr->data.data(), splitsInfo.GetNumBytes());
+ splitsData.resize(numSplits);
+ BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer);
+ unsigned int idx{0};
- int numInferred = 0;
- int specifiedSizes = 0;
- unsigned int inferIdx = 0;
- unsigned int idx = 0;
- for (auto split : splitsData)
+ for(auto& split: splitsData)
+ {
+ split = splitsBufferPtr->data[idx];
+ idx++;
+ }
+
+ idx = 0;
+ int numInferred{0};
+ unsigned int inferIdx{0};
+ int splitSum{0};
+ for (auto split : splitsData)
+ {
+ if (split < 0)
{
- if (split < 0)
- {
- numInferred++;
- inferIdx = idx;
- }
- else
- {
- specifiedSizes += split;
- }
- idx++;
+ numInferred++;
+ inferIdx = idx;
}
-
- if (numInferred > 0)
+ else
{
- if (numInferred > 1)
- {
- throw ParseException("Cannot infer split size for more than one split");
- }
- splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - specifiedSizes;
+ splitSum += split;
}
+ idx++;
}
-
- if (numSplits <=0)
+ // Check for inferred Axis
+ if (numInferred == 0)
{
- throw ParseException("SplitV has invalid number of splits");
+ if (splitSum != numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
+ {
+ throw ParseException("SplitV split_sizes does not sum to the dimension of value along split_dim.");
+ }
+ }
+ else if (numInferred == 1)
+ {
+ splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
+ }
+ else
+ {
+ throw ParseException("Cannot infer split size for more than one split");
}
//Ouput size validation
@@ -2805,7 +2806,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
accumSplit += splitSize;
}
- auto layerName = boost::str(boost::format("Split:%1%:%2%") % subgraphIndex % operatorIndex);
+ auto layerName = boost::str(boost::format("SplitV:%1%:%2%") % subgraphIndex % operatorIndex);
IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));