diff options
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonLayerSupport.cpp    | 16 |
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp    |  4 |
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp |  3 |
3 files changed, 17 insertions, 6 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 2f83c8f82a..9db7354e9e 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -72,6 +72,7 @@ bool IsSupportedForDataTypeNeon(Optional<std::string&> reasonIfUnsupported,
                                     floatFuncPtr,
                                     uint8FuncPtr,
                                     &FalseFunc<>,
+                                    &FalseFunc<>,
                                     std::forward<Params>(params)...);
 }
@@ -214,7 +215,8 @@ bool NeonLayerSupport::IsFloorSupported(const TensorInfo& input,
                                       &FalseFuncF16<>,
                                       &TrueFunc<>,
                                       &FalseFuncU8<>,
-                                      &FalseFuncI32<>);
+                                      &FalseFuncI32<>,
+                                      &FalseFuncU8<>);
 }
 
 bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
@@ -344,10 +346,14 @@ bool NeonLayerSupport::IsNormalizationSupported(const TensorInfo& input,
 bool NeonLayerSupport::IsOutputSupported(const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
 {
-    return IsSupportedForDataTypeNeon(reasonIfUnsupported,
-                                      output.GetDataType(),
-                                      &TrueFunc<>,
-                                      &TrueFunc<>);
+    return IsNeonBackendSupported(reasonIfUnsupported) &&
+           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                         output.GetDataType(),
+                                         &TrueFunc<>,
+                                         &TrueFunc<>,
+                                         &TrueFunc<>,
+                                         &FalseFuncI32<>,
+                                         &TrueFunc<>);
 }
 
 bool NeonLayerSupport::IsPermuteSupported(const TensorInfo& input,
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp
index 7206b6fc5a..b972043827 100644
--- a/src/backends/neon/NeonTensorHandle.hpp
+++ b/src/backends/neon/NeonTensorHandle.hpp
@@ -94,6 +94,7 @@ private:
                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<float*>(memory));
                 break;
+            case arm_compute::DataType::U8:
             case arm_compute::DataType::QASYMM8:
                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<uint8_t*>(memory));
@@ -114,6 +115,7 @@ private:
                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), this->GetTensor());
                 break;
+            case arm_compute::DataType::U8:
             case arm_compute::DataType::QASYMM8:
                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), this->GetTensor());
@@ -181,6 +183,7 @@ private:
                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<float*>(memory));
                 break;
+            case arm_compute::DataType::U8:
             case arm_compute::DataType::QASYMM8:
                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<uint8_t*>(memory));
@@ -201,6 +204,7 @@ private:
                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), this->GetTensor());
                 break;
+            case arm_compute::DataType::U8:
             case arm_compute::DataType::QASYMM8:
                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), this->GetTensor());
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index e7fac97c2c..e8a00d6b14 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -91,7 +91,8 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateInput(const InputQueueDesc
 std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
 {
-    return MakeWorkloadHelper<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
+    return MakeWorkloadHelper<CopyMemGenericWorkload, CopyMemGenericWorkload,
+        CopyMemGenericWorkload, NullWorkload, CopyMemGenericWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,