aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 1c5b4fc9f0..f44a747815 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -808,6 +808,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParserImpl::ParseReshape;
m_ParserFunctions[tflite::BuiltinOperator_RESIZE_BILINEAR] = &TfLiteParserImpl::ParseResizeBilinear;
m_ParserFunctions[tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] = &TfLiteParserImpl::ParseResizeNearestNeighbor;
+ m_ParserFunctions[tflite::BuiltinOperator_REVERSE_V2] = &TfLiteParserImpl::ParseReverseV2;
m_ParserFunctions[tflite::BuiltinOperator_RSQRT] = &TfLiteParserImpl::ParseRsqrt;
m_ParserFunctions[tflite::BuiltinOperator_SQRT] = &TfLiteParserImpl::ParseSqrt;
m_ParserFunctions[tflite::BuiltinOperator_SHAPE] = &TfLiteParserImpl::ParseShape;
@@ -3276,6 +3277,67 @@ void TfLiteParserImpl::ParseResize(size_t subgraphIndex, size_t operatorIndex, R
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParserImpl::ParseReverseV2(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = fmt::format("ReverseV2:{}:{}", subgraphIndex, operatorIndex);
+
+ TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
+ TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+
+ std::vector<int32_t> axisTensorData(axisTensorInfo.GetNumElements());
+
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+ ::memcpy(axisTensorData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
+
+ ReverseV2Descriptor descriptor(axisTensorData);
+
+ auto inputRank = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ std::vector<bool> dimFlag(inputRank, false);
+
+ for (auto axis : axisTensorData)
+ {
+ if (axis < -inputRank || axis >= inputRank)
+ {
+ throw ParseException(
+ fmt::format("Operation has invalid axis: {} It is out of bounds [ -{}, {} ) {}",
+ axis,
+ inputRank, inputRank,
+ CHECK_LOCATION().AsString()));
+ }
+
+ auto posAxis = axis < 0 ? axis + inputRank : axis;
+
+ if (dimFlag[posAxis])
+ {
+ throw ParseException(
+ fmt::format("Operation has repeated axis: {} {}",
+ axis,
+ CHECK_LOCATION().AsString()));
+ }
+ dimFlag[posAxis] = true;
+ }
+
+ IConnectableLayer* layer = m_Network->AddReverseV2Layer(descriptor, layerName.c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
void TfLiteParserImpl::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);