diff options
author | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
---|---|---|
committer | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
commit | 2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch) | |
tree | 4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/armnn/LayerSupportCommon.hpp | |
parent | 94412aff782472be54dce4328e2ecee0225b3e97 (diff) | |
download | armnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz |
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload
* Create RefComparisonWorkload and add Equal and Greater
* Update ElementwiseFunction for different input/output types
* Update TfParser to create Equal/Greater with Boolean output
* Update relevant tests to check for Boolean comparison
Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/armnn/LayerSupportCommon.hpp')
-rw-r--r-- | src/armnn/LayerSupportCommon.hpp | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/src/armnn/LayerSupportCommon.hpp b/src/armnn/LayerSupportCommon.hpp index 109728cd81..70b5f182f4 100644 --- a/src/armnn/LayerSupportCommon.hpp +++ b/src/armnn/LayerSupportCommon.hpp @@ -12,13 +12,15 @@ namespace armnn { -template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename ... Params> +template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename BooleanFunc, + typename ... Params> bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported, DataType dataType, Float16Func float16FuncPtr, Float32Func float32FuncPtr, Uint8Func uint8FuncPtr, Int32Func int32FuncPtr, + BooleanFunc booleanFuncPtr, Params&&... params) { switch(dataType) @@ -31,6 +33,8 @@ bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported, return uint8FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); case DataType::Signed32: return int32FuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); + case DataType::Boolean: + return booleanFuncPtr(reasonIfUnsupported, std::forward<Params>(params)...); default: return false; } |