diff options
Diffstat (limited to 'src/armnn/LayerSupportCommon.hpp')
-rw-r--r-- | src/armnn/LayerSupportCommon.hpp | 59 |
1 files changed, 56 insertions, 3 deletions
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp index 5b7feac387..63065c0565 100644 --- a/src/armnn/LayerSupportCommon.hpp +++ b/src/armnn/LayerSupportCommon.hpp @@ -11,17 +11,20 @@ namespace armnn { -template<typename Float32Func, typename Uint8Func, typename ... Params> +template<typename Float16Func, typename Float32Func, typename Uint8Func, typename ... Params> bool IsSupportedForDataTypeGeneric(std::string* reasonIfUnsupported, DataType dataType, - Float32Func floatFuncPtr, + Float16Func float16FuncPtr, + Float32Func float32FuncPtr, Uint8Func uint8FuncPtr, Params&&... params) { switch(dataType) { + case DataType::Float16: + return float16FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); case DataType::Float32: - return floatFuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); + return float32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); case DataType::QuantisedAsymm8: return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); default: @@ -42,6 +45,16 @@ bool FalseFunc(std::string* reasonIfUnsupported, Params&&... params) } template<typename ... Params> +bool FalseFuncF16(std::string* reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Layer is not supported with float16 data type"; + } + return false; +} + +template<typename ... Params> bool FalseFuncF32(std::string* reasonIfUnsupported, Params&&... params) { if (reasonIfUnsupported) @@ -61,4 +74,44 @@ bool FalseFuncU8(std::string* reasonIfUnsupported, Params&&... params) return false; } +template<typename ... Params> +bool FalseInputFuncF32(std::string* reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Layer is not supported with float32 data type input"; + } + return false; +} + +template<typename ... Params> +bool FalseInputFuncF16(std::string* reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Layer is not supported with float16 data type input"; + } + return false; +} + +template<typename ... Params> +bool FalseOutputFuncF32(std::string* reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Layer is not supported with float32 data type output"; + } + return false; +} + +template<typename ... Params> +bool FalseOutputFuncF16(std::string* reasonIfUnsupported, Params&&... params) +{ + if (reasonIfUnsupported) + { + *reasonIfUnsupported = "Layer is not supported with float16 data type output"; + } + return false; +} + } |