diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 12 | ||||
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 9 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/Conv3D.cpp | 3 |
3 files changed, 20 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 104a55e675..81d491a1a1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -32,8 +32,6 @@ #include <fmt/format.h> -#include <tensorflow/lite/version.h> - #include <algorithm> #include <fstream> #include <iostream> @@ -642,7 +640,10 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_CAST] = &TfLiteParserImpl::ParseCast; m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParserImpl::ParseConcatenation; m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParserImpl::ParseConv2D; + // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. + #if defined(ARMNN_POST_TFLITE_2_3) m_ParserFunctions[tflite::BuiltinOperator_CONV_3D] = &TfLiteParserImpl::ParseConv3D; + #endif m_ParserFunctions[tflite::BuiltinOperator_CUSTOM] = &TfLiteParserImpl::ParseCustomOperator; m_ParserFunctions[tflite::BuiltinOperator_DEPTH_TO_SPACE] = &TfLiteParserImpl::ParseDepthToSpace; m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParserImpl::ParseDepthwiseConv2D; @@ -772,7 +773,7 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel() auto const& opCodePtr = m_Model->operator_codes[op->opcode_index]; // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner -#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) +#if defined(ARMNN_POST_TFLITE_2_3) auto builtinCode = std::max(opCodePtr->builtin_code, static_cast<tflite::BuiltinOperator>(opCodePtr->deprecated_builtin_code)); #else @@ -899,7 +900,7 @@ void TfLiteParserImpl::ParseUnsupportedOperator(size_t subgraphIndex, size_t ope auto opcodeIndex = operatorPtr->opcode_index; // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner -#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) +#if defined(ARMNN_POST_TFLITE_2_3) auto opcode = std::max(m_Model->operator_codes[opcodeIndex]->builtin_code, static_cast<tflite::BuiltinOperator>(m_Model->operator_codes[opcodeIndex]->deprecated_builtin_code)); #else @@ -1049,6 +1050,8 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +// Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. +#if defined(ARMNN_POST_TFLITE_2_3) void TfLiteParserImpl::ParseConv3D(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); @@ -1132,6 +1135,7 @@ void TfLiteParserImpl::ParseConv3D(size_t subgraphIndex, size_t operatorIndex) auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +#endif void TfLiteParserImpl::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex) { diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 8eb529963b..3d4fd6504f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -13,6 +13,12 @@ #include <unordered_map> #include <vector> +#include <tensorflow/lite/version.h> + +#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) +#define ARMNN_POST_TFLITE_2_3 +#endif + namespace armnnTfLiteParser { @@ -113,7 +119,10 @@ private: void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation); void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex); void ParseConv2D(size_t subgraphIndex, size_t operatorIndex); + // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. + #if defined(ARMNN_POST_TFLITE_2_3) void ParseConv3D(size_t subgraphIndex, size_t operatorIndex); + #endif void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex); void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex); void ParseDequantize(size_t subgraphIndex, size_t operatorIndex); diff --git a/src/armnnTfLiteParser/test/Conv3D.cpp b/src/armnnTfLiteParser/test/Conv3D.cpp index 32cd6fe3f4..dd55aea211 100644 --- a/src/armnnTfLiteParser/test/Conv3D.cpp +++ b/src/armnnTfLiteParser/test/Conv3D.cpp @@ -6,6 +6,8 @@ #include "ParserFlatbuffersFixture.hpp" #include <sstream> +// Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. +#if defined(ARMNN_POST_TFLITE_2_3) TEST_SUITE("TensorflowLiteParser_Conv3D") { struct SimpleConv3DFixture : public ParserFlatbuffersFixture @@ -284,3 +286,4 @@ TEST_CASE_FIXTURE(Relu6Conv3DWithBiasesFixture, "ParseConv3DAndRelu6WithBias") } } +#endif |