aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp5
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp9
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp5
3 files changed, 13 insertions, 6 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 9a691a6fa7..a9cddfd1bb 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1021,11 +1021,12 @@ bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
{
ignore_unused(descriptor);
// Define supported output types.
- std::array<DataType,3> supportedOutputTypes =
+ std::array<DataType,4> supportedOutputTypes =
{
DataType::Float32,
DataType::Float16,
- DataType::QuantisedAsymm8
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
};
return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
"Reference reshape: input type not supported.");
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 2222a22cb3..fef2567d07 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -661,16 +661,21 @@ static void RefCreateReshapeWorkloadTest()
TensorInfo({ 1, 4 }, DataType));
}
-BOOST_AUTO_TEST_CASE(CreateReshapeFloat32Workload)
+BOOST_AUTO_TEST_CASE(CreateReshapeWorkloadFloat32)
{
RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::Float32>();
}
-BOOST_AUTO_TEST_CASE(CreateReshapeUint8Workload)
+BOOST_AUTO_TEST_CASE(CreateReshapeWorkloadQuantisedAsymm8)
{
RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QuantisedAsymm8>();
}
+BOOST_AUTO_TEST_CASE(CreateReshapeWorkloadQuantisedSymm16)
+{
+ RefCreateReshapeWorkloadTest<RefReshapeWorkload, armnn::DataType::QuantisedSymm16>();
+}
+
template <typename ConcatWorkloadType, armnn::DataType DataType>
static void RefCreateConcatWorkloadTest(const armnn::TensorShape& outputShape,
unsigned int concatAxis)
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 1207c1d648..a78b5aae5d 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -449,8 +449,9 @@ ARMNN_AUTO_TEST_CASE(Concatenation4dDiffShapeDim3Uint8, Concatenation4dDiffShape
ARMNN_AUTO_TEST_CASE(SimpleFloor, SimpleFloorTest)
// Reshape
-ARMNN_AUTO_TEST_CASE(SimpleReshapeFloat32, SimpleReshapeFloat32Test)
-ARMNN_AUTO_TEST_CASE(SimpleReshapeUint8, SimpleReshapeUint8Test)
+ARMNN_AUTO_TEST_CASE(SimpleReshapeFloat32, SimpleReshapeTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(SimpleReshapeQuantisedAsymm8, SimpleReshapeTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(SimpleReshapeQuantisedSymm16, SimpleReshapeTest<armnn::DataType::QuantisedSymm16>)
// Rsqrt
ARMNN_AUTO_TEST_CASE(Rsqrt2d, Rsqrt2dTest)