diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-05-11 16:10:38 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-05-13 18:34:12 +0000 |
commit | c1f6b09cc22b6e2de3d9eb88aec1778d0308a2b3 (patch) | |
tree | 14a70415b6d17b42246798a293dfdd13cbe4c123 /src/backends/cl/test | |
parent | 0c32ccfcd0b770cdd4eeb9d778b7f72a233e229a (diff) | |
download | armnn-c1f6b09cc22b6e2de3d9eb88aec1778d0308a2b3.tar.gz |
IVGCVSW-4753 Refactor CL Softmax workload generalizing for different datatype
* Change ComputeSoftmaxAclAxis to work with int and uint axis
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ibbfa9ec7e2f0416e6885673212a767419c871cca
Diffstat (limited to 'src/backends/cl/test')
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index b09b26f9b3..b7522547d4 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -569,19 +569,41 @@ static void ClSoftmaxWorkloadTest() auto inputHandle = PolymorphicDowncast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); auto outputHandle = PolymorphicDowncast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); + armnn::TensorInfo tensorInfo({4, 1}, DataType); + if (DataType == armnn::DataType::QAsymmU8) + { + tensorInfo.SetQuantizationOffset(0); + tensorInfo.SetQuantizationScale(1.f / 256); + } + else if (DataType == armnn::DataType::QAsymmS8) + { + tensorInfo.SetQuantizationOffset(-128); + tensorInfo.SetQuantizationScale(1.f / 256); + } + BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {4, 1})); BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {4, 1})); } -BOOST_AUTO_TEST_CASE(CreateSoftmaxFloatWorkloadTest) +BOOST_AUTO_TEST_CASE(CreateSoftmaxFloat32WorkloadTest) { - ClSoftmaxWorkloadTest<ClSoftmaxFloatWorkload, armnn::DataType::Float32>(); + ClSoftmaxWorkloadTest<ClSoftmaxWorkload, armnn::DataType::Float32>(); } BOOST_AUTO_TEST_CASE(CreateSoftmaxFloat16WorkloadTest) { - ClSoftmaxWorkloadTest<ClSoftmaxFloatWorkload, armnn::DataType::Float16>(); + ClSoftmaxWorkloadTest<ClSoftmaxWorkload, armnn::DataType::Float16>(); +} + +BOOST_AUTO_TEST_CASE(CreateSoftmaxQAsymmU8Workload) +{ + ClSoftmaxWorkloadTest<ClSoftmaxWorkload, armnn::DataType::QAsymmU8>(); +} + +BOOST_AUTO_TEST_CASE(CreateSoftmaxQAsymmS8Workload) +{ + ClSoftmaxWorkloadTest<ClSoftmaxWorkload, armnn::DataType::QAsymmS8>(); } template <typename armnn::DataType DataType> |