From 2b4d88e34ac1f965417fd236fd4786f26bae2042 Mon Sep 17 00:00:00 2001 From: kevmay01 Date: Thu, 24 Jan 2019 14:05:09 +0000 Subject: IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater * Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32 --- src/backends/neon/NeonLayerSupport.cpp | 16 +++++++++++----- src/backends/neon/NeonTensorHandle.hpp | 4 ++++ src/backends/neon/NeonWorkloadFactory.cpp | 3 ++- 3 files changed, 17 insertions(+), 6 deletions(-) (limited to 'src/backends/neon') diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp index 2f83c8f82a..9db7354e9e 100644 --- a/src/backends/neon/NeonLayerSupport.cpp +++ b/src/backends/neon/NeonLayerSupport.cpp @@ -72,6 +72,7 @@ bool IsSupportedForDataTypeNeon(Optional reasonIfUnsupported, floatFuncPtr, uint8FuncPtr, &FalseFunc<>, + &FalseFunc<>, std::forward(params)...); } @@ -214,7 +215,8 @@ bool NeonLayerSupport::IsFloorSupported(const TensorInfo& input, &FalseFuncF16<>, &TrueFunc<>, &FalseFuncU8<>, - &FalseFuncI32<>); + &FalseFuncI32<>, + &FalseFuncU8<>); } bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input, @@ -344,10 +346,14 @@ bool NeonLayerSupport::IsNormalizationSupported(const TensorInfo& input, bool NeonLayerSupport::IsOutputSupported(const TensorInfo& output, Optional reasonIfUnsupported) const { - return IsSupportedForDataTypeNeon(reasonIfUnsupported, - output.GetDataType(), - &TrueFunc<>, - &TrueFunc<>); + return IsNeonBackendSupported(reasonIfUnsupported) && + IsSupportedForDataTypeGeneric(reasonIfUnsupported, + output.GetDataType(), + &TrueFunc<>, + &TrueFunc<>, + &TrueFunc<>, + &FalseFuncI32<>, + &TrueFunc<>); } bool NeonLayerSupport::IsPermuteSupported(const TensorInfo& input, diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 7206b6fc5a..b972043827 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -94,6 +94,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); @@ -114,6 +115,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); @@ -181,6 +183,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast(memory)); @@ -201,6 +204,7 @@ private: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); break; + case arm_compute::DataType::U8: case arm_compute::DataType::QASYMM8: armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), this->GetTensor()); diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index e7fac97c2c..e8a00d6b14 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -91,7 +91,8 @@ std::unique_ptr NeonWorkloadFactory::CreateInput(const InputQueueDesc std::unique_ptr NeonWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkloadHelper(descriptor, info); + return MakeWorkloadHelper(descriptor, info); } std::unique_ptr NeonWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, -- cgit v1.2.1