aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp')
-rw-r--r--src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp79
1 files changed, 11 insertions, 68 deletions
diff --git a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
index a1bab83e72..0d0cd6eefc 100644
--- a/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
+++ b/src/backends/tosaReference/test/TosaRefLayerSupportTests.cpp
@@ -20,7 +20,7 @@ TEST_SUITE("TosaRefLayerSupported")
TEST_CASE("IsLayerSupportedTosaReferenceAddition")
{
TensorShape shape0 = {1,1,3,4};
- TensorShape shape1 = {4};
+ TensorShape shape1 = {1,1,3,4};
TensorShape outShape = {1,1,3,4};
TensorInfo in0(shape0, DataType::Float32);
TensorInfo in1(shape1, DataType::Float32);
@@ -59,14 +59,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceAdditionUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_ADD for input: input0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_ADD for input: input1_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_ADD for output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported data type: DType_UNKNOWN") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceConstant")
@@ -99,10 +91,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceConstantUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_CONST for output: constant_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported data type: DType_UNKNOWN") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceConv2d")
@@ -148,14 +136,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceConv2dUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_CONV2D for input 0: input0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "input 1: input1_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "and output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported input data type combination.") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceMaxPooling2d")
@@ -166,6 +146,12 @@ TEST_CASE("IsLayerSupportedTosaReferenceMaxPooling2d")
TensorInfo out(outShape, DataType::Float32);
Pooling2dDescriptor desc;
+ desc.m_PoolHeight = 1;
+ desc.m_PoolWidth = 1;
+ desc.m_StrideX = 1;
+ desc.m_StrideY = 1;
+ desc.m_PoolType = armnn::PoolingAlgorithm::Max;
+
TosaRefLayerSupport supportChecker;
std::string reasonIfNotSupported;
auto supported = supportChecker.IsLayerSupported(LayerType::Pooling2d,
@@ -186,29 +172,10 @@ TEST_CASE("IsLayerSupportedTosaReferenceAvgPooling2d_IgnoreValue")
TensorInfo out(outShape, DataType::Float32);
Pooling2dDescriptor desc;
- desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
- desc.m_PoolType = PoolingAlgorithm::Average;
-
- TosaRefLayerSupport supportChecker;
- std::string reasonIfNotSupported;
- auto supported = supportChecker.IsLayerSupported(LayerType::Pooling2d,
- {in, out},
- desc,
- EmptyOptional(),
- EmptyOptional(),
- reasonIfNotSupported);
-
- CHECK(supported);
-}
-
-TEST_CASE("IsLayerSupportedTosaReferenceAvgPooling2d_InputOutputDatatypeDifferent")
-{
- TensorShape inShape = {1,1,3,4};
- TensorShape outShape = {1,1,3,4};
- TensorInfo in(inShape, DataType::QAsymmS8);
- TensorInfo out(outShape, DataType::Signed32);
-
- Pooling2dDescriptor desc;
+ desc.m_PoolHeight = 1;
+ desc.m_PoolWidth = 1;
+ desc.m_StrideX = 1;
+ desc.m_StrideY = 1;
desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
desc.m_PoolType = PoolingAlgorithm::Average;
@@ -242,12 +209,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceMaxPooling2dUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_MAX_POOL2D for input: input0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_MAX_POOL2D for output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported data type: DType_UNKNOWN") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceAvgPooling2dUnsupported_InputOutputDatatypeDifferent")
@@ -271,12 +232,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceAvgPooling2dUnsupported_InputOutputDatat
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_AVG_POOL2D for input: intermediate0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- " and output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- " has an unsupported input data type: DType_FP32 to output data type: DType_FP16") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceReshape")
@@ -321,12 +276,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceReshapeUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_RESHAPE for input: input0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_RESHAPE for output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported data type: DType_UNKNOWN") != std::string::npos);
}
TEST_CASE("IsLayerSupportedTosaReferenceSlice")
@@ -373,12 +322,6 @@ TEST_CASE("IsLayerSupportedTosaReferenceSliceUnsupported")
reasonIfNotSupported);
CHECK(!supported);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_SLICE for input: input0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "TOSA Reference Operator: Op_SLICE for output: output0_") != std::string::npos);
- REQUIRE(reasonIfNotSupported.find(
- "has an unsupported data type: DType_UNKNOWN") != std::string::npos);
}
}