aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LayerSupportCommon.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/LayerSupportCommon.hpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/LayerSupportCommon.hpp')
-rw-r--r--src/armnn/LayerSupportCommon.hpp59
1 files changed, 56 insertions, 3 deletions
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp
index 5b7feac387..63065c0565 100644
--- a/src/armnn/LayerSupportCommon.hpp
+++ b/src/armnn/LayerSupportCommon.hpp
@@ -11,17 +11,20 @@
namespace armnn
{
-template<typename Float32Func, typename Uint8Func, typename ... Params>
+template<typename Float16Func, typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeGeneric(std::string* reasonIfUnsupported,
DataType dataType,
- Float32Func floatFuncPtr,
+ Float16Func float16FuncPtr,
+ Float32Func float32FuncPtr,
Uint8Func uint8FuncPtr,
Params&&... params)
{
switch(dataType)
{
+ case DataType::Float16:
+ return float16FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
case DataType::Float32:
- return floatFuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
+ return float32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
case DataType::QuantisedAsymm8:
return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...);
default:
@@ -42,6 +45,16 @@ bool FalseFunc(std::string* reasonIfUnsupported, Params&&... params)
}
template<typename ... Params>
+bool FalseFuncF16(std::string* reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Layer is not supported with float16 data type";
+ }
+ return false;
+}
+
+template<typename ... Params>
bool FalseFuncF32(std::string* reasonIfUnsupported, Params&&... params)
{
if (reasonIfUnsupported)
@@ -61,4 +74,44 @@ bool FalseFuncU8(std::string* reasonIfUnsupported, Params&&... params)
return false;
}
+template<typename ... Params>
+bool FalseInputFuncF32(std::string* reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Layer is not supported with float32 data type input";
+ }
+ return false;
+}
+
+template<typename ... Params>
+bool FalseInputFuncF16(std::string* reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Layer is not supported with float16 data type input";
+ }
+ return false;
+}
+
+template<typename ... Params>
+bool FalseOutputFuncF32(std::string* reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Layer is not supported with float32 data type output";
+ }
+ return false;
+}
+
+template<typename ... Params>
+bool FalseOutputFuncF16(std::string* reasonIfUnsupported, Params&&... params)
+{
+ if (reasonIfUnsupported)
+ {
+ *reasonIfUnsupported = "Layer is not supported with float16 data type output";
+ }
+ return false;
+}
+
}