aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp57
1 files changed, 56 insertions, 1 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b5a421145a..124c5fdcc7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -378,7 +378,8 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Pad", &TfParser::ParsePad },
{ "Sub", &TfParser::ParseSub },
{ "Pack" , &TfParser::ParseStack },
- { "Stack", &TfParser::ParseStack }
+ { "Stack", &TfParser::ParseStack },
+ { "Transpose", &TfParser::ParseTranspose },
};
const std::list<std::string> TfParser::m_ControlInputs = {
@@ -2054,6 +2055,60 @@ ParsedTfOperationPtr TfParser::ParseStack(const tensorflow::NodeDef& nodeDef, co
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ auto inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ const auto inputCount = inputs.size();
+
+ if (inputCount != 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "The number of given input is %1%. It should be two for Transpose op."
+ "Node %2% %3%")
+ % inputCount
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ auto* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+
+ const auto constInput = inputs[GetConstInputIndex(inputs)];
+ auto* permuteVectorInput =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(constInput.m_IndexedValue);
+ const auto& permuteVectorInfo = permuteVectorInput->GetTensorInfo();
+
+ std::vector<int32_t> permuteVectorData;
+ permuteVectorInput->GetConstTensor(permuteVectorData);
+
+ std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.size());
+ std::vector<int32_t>::iterator it;
+
+ for (unsigned int i = 0u; i < permuteVectorData.size(); ++i)
+ {
+ it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i);
+ armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it));
+ }
+
+ const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements());
+ const auto desc = PermuteDescriptor(permutationVector);
+
+ auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str());
+ BOOST_ASSERT(layer);
+
+ input0Slot->Connect(layer->GetInputSlot(0));
+
+ const auto& input0Info = input0Slot->GetTensorInfo();
+ armnn::TensorInfo outputInfo {input0Info};
+ outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings));
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
unsigned int CheckPaddingTensor(const ConstTensor& paddingTensor,
const TensorInfo& inputTensorInfo,
const std::string& nodeName)