diff options
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/ClTensorHandle.hpp | 20 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerSupportTests.cpp | 11 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClDequantizeWorkload.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClQuantizeWorkload.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloadUtils.hpp | 1 |
6 files changed, 35 insertions, 3 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index cf2b44ac55..1830d186b6 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -100,6 +100,11 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<uint8_t*>(memory)); break; + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int8_t*>(memory)); + break; case arm_compute::DataType::F16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<armnn::Half*>(memory)); @@ -141,6 +146,11 @@ private: this->GetTensor()); break; case arm_compute::DataType::S16: + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), + this->GetTensor()); + break; case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), this->GetTensor()); @@ -224,6 +234,11 @@ private: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), static_cast<armnn::Half*>(memory)); break; + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<int8_t*>(memory)); + break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), @@ -260,6 +275,11 @@ private: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), this->GetTensor()); break; + case arm_compute::DataType::QSYMM8_PER_CHANNEL: + case arm_compute::DataType::QASYMM8_SIGNED: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), + this->GetTensor()); + break; case arm_compute::DataType::S16: case arm_compute::DataType::QSYMM16: armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 0440aac022..4bb2e2a8ce 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -438,7 +438,7 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePrelu(const PreluQueueDescri std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClQuantizeWorkload, NullWorkload>(descriptor, info); + return MakeWorkload<ClQuantizeWorkload>(descriptor, info); } std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, diff --git a/src/backends/cl/test/ClLayerSupportTests.cpp b/src/backends/cl/test/ClLayerSupportTests.cpp index 8d10375778..33a2912b79 100644 --- a/src/backends/cl/test/ClLayerSupportTests.cpp +++ b/src/backends/cl/test/ClLayerSupportTests.cpp @@ -36,14 +36,21 @@ BOOST_FIXTURE_TEST_CASE(IsLayerSupportedFloat32Cl, ClContextControlFixture) IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::Float32>(&factory); } -BOOST_FIXTURE_TEST_CASE(IsLayerSupportedUint8Cl, ClContextControlFixture) +BOOST_FIXTURE_TEST_CASE(IsLayerSupportedQAsymmU8Cl, ClContextControlFixture) { armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::QAsymmU8>(&factory); } -BOOST_FIXTURE_TEST_CASE(IsLayerSupportedInt8Cl, ClContextControlFixture) +BOOST_FIXTURE_TEST_CASE(IsLayerSupportedQAsymmS8Cl, ClContextControlFixture) +{ + armnn::ClWorkloadFactory factory = + ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); + IsLayerSupportedTests<armnn::ClWorkloadFactory, armnn::DataType::QAsymmS8>(&factory); +} + +BOOST_FIXTURE_TEST_CASE(IsLayerSupportedQSymmS8Cl, ClContextControlFixture) { armnn::ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); diff --git a/src/backends/cl/workloads/ClDequantizeWorkload.cpp b/src/backends/cl/workloads/ClDequantizeWorkload.cpp index 67a555a020..eca795de7e 100644 --- a/src/backends/cl/workloads/ClDequantizeWorkload.cpp +++ b/src/backends/cl/workloads/ClDequantizeWorkload.cpp @@ -32,6 +32,8 @@ ClDequantizeWorkload::ClDequantizeWorkload(const DequantizeQueueDescriptor& desc const WorkloadInfo& workloadInfo) : BaseWorkload<DequantizeQueueDescriptor>(descriptor, workloadInfo) { + m_Data.ValidateInputsOutputs("ClDequantizeWorkload", 1, 1); + arm_compute::ICLTensor& input = boost::polymorphic_pointer_downcast<IClTensorHandle>( m_Data.m_Inputs[0])->GetTensor(); diff --git a/src/backends/cl/workloads/ClQuantizeWorkload.cpp b/src/backends/cl/workloads/ClQuantizeWorkload.cpp index 230e346a00..263065a5a4 100644 --- a/src/backends/cl/workloads/ClQuantizeWorkload.cpp +++ b/src/backends/cl/workloads/ClQuantizeWorkload.cpp @@ -32,6 +32,8 @@ arm_compute::Status ClQuantizeWorkloadValidate(const TensorInfo& input, ClQuantizeWorkload::ClQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload<QuantizeQueueDescriptor>(descriptor, info) { + m_Data.ValidateInputsOutputs("ClQuantizeWorkload", 1, 1); + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp index d3c6df50ed..b4bcc1c017 100644 --- a/src/backends/cl/workloads/ClWorkloadUtils.hpp +++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp @@ -101,6 +101,7 @@ inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor, case DataType::Float32: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<float>()); break; + case DataType::QAsymmS8: case DataType::QAsymmU8: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>()); break; |