diff options
Diffstat (limited to 'src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp')
-rw-r--r-- | src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp index 3c3abc2af3..051965f541 100644 --- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp +++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp @@ -61,6 +61,58 @@ TEST_CASE("IsLayerSupportedTosaReferenceAdditionUnsupported") CHECK(!supported); } +TEST_CASE("IsLayerSupportedTosaReferenceConcat") +{ + TensorShape input0Shape = { 2, 3, 2, 2 }; + TensorShape input1Shape = { 2, 3, 2, 2 }; + TensorShape outputShape = { 2, 6, 2, 2 }; + TensorInfo input0Info(input0Shape, DataType::Float32); + TensorInfo input1Info(input1Shape, DataType::Float32); + TensorInfo outputInfo(outputShape, DataType::Float32); + + OriginsDescriptor descriptor; + std::vector<TensorShape> shapes = {input0Shape, input1Shape} ; + unsigned int concatAxis = 1; + descriptor = CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), concatAxis); + + TosaRefLayerSupport supportChecker; + std::string reasonIfNotSupported; + auto supported = supportChecker.IsLayerSupported(LayerType::Concat, + {input0Info, input1Info, outputInfo}, + descriptor, + EmptyOptional(), + EmptyOptional(), + reasonIfNotSupported); + + CHECK(supported); +} + +TEST_CASE("IsLayerSupportedTosaReferenceConcatUnsupported") +{ + TensorShape input0Shape = { 2, 3, 2, 2 }; + TensorShape input1Shape = { 2, 3, 2, 2 }; + TensorShape outputShape = { 2, 6, 2, 2 }; + TensorInfo input0Info(input0Shape, armnn::DataType::QAsymmU8); + TensorInfo input1Info(input1Shape, armnn::DataType::QAsymmU8); + TensorInfo outputInfo(outputShape, armnn::DataType::QAsymmU8); + + OriginsDescriptor descriptor; + std::vector<armnn::TensorShape> shapes = {input0Shape, input1Shape} ; + unsigned int concatAxis = 1; + descriptor = armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), concatAxis); + + TosaRefLayerSupport supportChecker; + std::string reasonIfNotSupported; + auto supported = supportChecker.IsLayerSupported(LayerType::Concat, + {input0Info, input1Info, outputInfo}, + descriptor, + EmptyOptional(), + EmptyOptional(), + reasonIfNotSupported); + + CHECK(!supported); +} + TEST_CASE("IsLayerSupportedTosaReferenceConstant") { TensorInfo outputInfo({1,1,3,4}, DataType::Float32); |