aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/EndToEndTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/EndToEndTestImpl.hpp36
1 files changed, 36 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
index 358f4e3fc2..4221f626da 100644
--- a/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/EndToEndTestImpl.hpp
@@ -766,4 +766,40 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend
BOOST_TEST(found != std::string::npos);
}
+inline void StridedSliceInvalidSliceEndToEndTest(std::vector<BackendId> backends)
+{
+ using namespace armnn;
+
+ // Create runtime in which test will run
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+ // build up the structure of the network
+ INetworkPtr net(INetwork::Create());
+
+ IConnectableLayer* input = net->AddInputLayer(0);
+
+ // Configure a strided slice with a stride the same size as the input but with a ShrinkAxisMask on the first
+ // dim of the output to make it too small to hold the specified slice.
+ StridedSliceDescriptor descriptor;
+ descriptor.m_Begin = {0, 0};
+ descriptor.m_End = {2, 3};
+ descriptor.m_Stride = {1, 1};
+ descriptor.m_BeginMask = 0;
+ descriptor.m_EndMask = 0;
+ descriptor.m_ShrinkAxisMask = 1;
+ IConnectableLayer* stridedSlice = net->AddStridedSliceLayer(descriptor);
+
+ IConnectableLayer* output0 = net->AddOutputLayer(0);
+
+ input->GetOutputSlot(0).Connect(stridedSlice->GetInputSlot(0));
+ stridedSlice->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
+
+ input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 2, 3 }, DataType::Float32));
+ stridedSlice->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 3 }, DataType::Float32));
+
+ // Attempt to optimize the network and check that the correct exception is thrown
+ BOOST_CHECK_THROW(Optimize(*net, backends, runtime->GetDeviceSpec()), armnn::LayerValidationException);
+}
+
} // anonymous namespace