author    Sadik Armagan <sadik.armagan@arm.com>  2019-04-05 15:25:46 +0100
committer Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>  2019-04-05 17:11:02 +0100
commit    eff363d58992fb6384053259f9e1ee773f8cd4df (patch)
tree      e0bce8c4694ee15e016951f9168afbf9b75a9c79
parent    1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4 (diff)
download  armnn-eff363d58992fb6384053259f9e1ee773f8cd4df.tar.gz
IVGCVSW-2914 Add Switch Layer and no-op factory method
Change-Id: I6a6ece708a49e8a97c83a3e7fec11c88af1e1cfa
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
-rw-r--r--  Android.mk  1
-rw-r--r--  CMakeLists.txt  2
-rw-r--r--  include/armnn/ILayerSupport.hpp  6
-rw-r--r--  include/armnn/ILayerVisitor.hpp  6
-rw-r--r--  include/armnn/INetwork.hpp  5
-rw-r--r--  include/armnn/LayerSupport.hpp  9
-rw-r--r--  include/armnn/LayerVisitorBase.hpp  3
-rw-r--r--  src/armnn/InternalTypes.cpp  1
-rw-r--r--  src/armnn/InternalTypes.hpp  3
-rw-r--r--  src/armnn/LayerSupport.cpp  11
-rw-r--r--  src/armnn/LayersFwd.hpp  2
-rw-r--r--  src/armnn/Network.cpp  5
-rw-r--r--  src/armnn/Network.hpp  2
-rw-r--r--  src/armnn/layers/SwitchLayer.cpp  60
-rw-r--r--  src/armnn/layers/SwitchLayer.hpp  42
-rw-r--r--  src/armnnDeserializer/Deserializer.cpp  26
-rw-r--r--  src/armnnDeserializer/Deserializer.hpp  1
-rw-r--r--  src/armnnDeserializer/DeserializerSupport.md  1
-rw-r--r--  src/armnnSerializer/ArmnnSchema.fbs  10
-rw-r--r--  src/armnnSerializer/Serializer.cpp  8
-rw-r--r--  src/armnnSerializer/Serializer.hpp  3
-rw-r--r--  src/armnnSerializer/SerializerSupport.md  1
-rw-r--r--  src/armnnSerializer/test/SerializerTests.cpp  50
-rw-r--r--  src/backends/backendsCommon/LayerSupportBase.cpp  9
-rw-r--r--  src/backends/backendsCommon/LayerSupportBase.hpp  6
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp  236
-rw-r--r--  src/backends/backendsCommon/WorkloadData.hpp  5
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp  19
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.hpp  3
-rw-r--r--  src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp  2
30 files changed, 430 insertions, 108 deletions
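
For orientation before the per-file diffs, here is a minimal sketch of how the new public API is exercised, mirroring the SerializeSwitch test case added further down. The function name, layer names, the { 1, 4 } shape, and the all-ones condition tensor are illustrative assumptions, not part of the patch.

#include <armnn/ArmNN.hpp>
#include <vector>

// Builds a small network around the new Switch layer: two inputs (data plus a
// constant condition tensor) and two outputs (the "true" and "false" branches).
armnn::INetworkPtr BuildSwitchNetwork()
{
    armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);
    std::vector<float> conditionData(info.GetNumElements(), 1.0f); // illustrative condition values
    armnn::ConstTensor condition(info, conditionData);

    armnn::INetworkPtr network = armnn::INetwork::Create();
    armnn::IConnectableLayer* const input       = network->AddInputLayer(0);
    armnn::IConnectableLayer* const constant    = network->AddConstantLayer(condition, "condition");
    armnn::IConnectableLayer* const switchLayer = network->AddSwitchLayer("switch");
    armnn::IConnectableLayer* const trueOutput  = network->AddOutputLayer(0);
    armnn::IConnectableLayer* const falseOutput = network->AddOutputLayer(1);

    input->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0));       // data
    constant->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1));    // condition
    switchLayer->GetOutputSlot(0).Connect(trueOutput->GetInputSlot(0));  // "true" branch
    switchLayer->GetOutputSlot(1).Connect(falseOutput->GetInputSlot(0)); // "false" branch

    input->GetOutputSlot(0).SetTensorInfo(info);
    constant->GetOutputSlot(0).SetTensorInfo(info);
    switchLayer->GetOutputSlot(0).SetTensorInfo(info);
    switchLayer->GetOutputSlot(1).SetTensorInfo(info);

    return network;
}
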
diff --git a/Android.mk b/Android.mk
index 6d5a0faa66..cd26fa548e 100644
--- a/Android.mk
+++ b/Android.mk
@@ -127,6 +127,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/SplitterLayer.cpp \
src/armnn/layers/StridedSliceLayer.cpp \
src/armnn/layers/SubtractionLayer.cpp \
+ src/armnn/layers/SwitchLayer.cpp \
src/armnn/Descriptors.cpp \
src/armnn/Exceptions.cpp \
src/armnn/Graph.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d1fe635407..b297423904 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -279,6 +279,8 @@ list(APPEND armnn_sources
src/armnn/layers/StridedSliceLayer.hpp
src/armnn/layers/SubtractionLayer.cpp
src/armnn/layers/SubtractionLayer.hpp
+ src/armnn/layers/SwitchLayer.cpp
+ src/armnn/layers/SwitchLayer.hpp
src/armnn/BackendSettings.hpp
src/armnn/CompatibleTypes.hpp
src/armnn/Descriptors.cpp
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 1b75810aca..dc843029c5 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -257,6 +257,12 @@ public:
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
+ virtual bool IsSwitchSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output0,
+ const TensorInfo& output1,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
}; // class ILayerSupport
using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 3a4c39b7c6..eabad58366 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -341,6 +341,12 @@ public:
virtual void VisitSubtractionLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a switch layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param name - Optional name for the layer.
+ virtual void VisitSwitchLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) = 0;
+
};
} // namespace armnn
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 8243b39c36..a15ceb1c15 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -399,6 +399,11 @@ public:
/// @ return - Interface for configuring the layer.
virtual IConnectableLayer* AddGatherLayer(const char* name = nullptr) = 0;
+ /// Adds a switch layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddSwitchLayer(const char* name = nullptr) = 0;
+
virtual void Accept(ILayerVisitor& visitor) const = 0;
protected:
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index e23fdd0a75..c9fc264e0c 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -338,4 +338,13 @@ bool IsSubtractionSupported(const BackendId& backend,
const TensorInfo& output,
char* reasonIfUnsupported = nullptr,
size_t reasonIfUnsupportedMaxLength = 1024);
+
+/// Deprecated in favor of IBackend and ILayerSupport interfaces
+bool IsSwitchSupported(const BackendId& backend,
+ const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output0,
+ const TensorInfo& output1,
+ char* reasonIfUnsupported = nullptr,
+ size_t reasonIfUnsupportedMaxLength = 1024);
}
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index f4e0f438be..12eb225674 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -178,6 +178,9 @@ public:
void VisitGatherLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(); }
+
+ void VisitSwitchLayer(const IConnectableLayer*,
+ const char*) override { DefaultPolicy::Apply(); }
};
} //namespace armnn
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp
index 93a4f94378..a811706dfe 100644
--- a/src/armnn/InternalTypes.cpp
+++ b/src/armnn/InternalTypes.cpp
@@ -57,6 +57,7 @@ char const* GetLayerTypeAsCString(LayerType type)
case LayerType::Splitter: return "Splitter";
case LayerType::StridedSlice: return "StridedSlice";
case LayerType::Subtraction: return "Subtraction";
+ case LayerType::Switch: return "Switch";
default:
BOOST_ASSERT_MSG(false, "Unknown layer type");
return "Unknown";
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 7c7c601d95..5765b5bcf1 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -57,9 +57,10 @@ enum class LayerType
SpaceToBatchNd,
Splitter,
StridedSlice,
+ Subtraction,
// Last layer goes here.
LastLayer,
- Subtraction = LastLayer
+ Switch = LastLayer
};
const char* GetLayerTypeAsCString(LayerType type);
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index bc6eec891b..320d9cef74 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -530,4 +530,15 @@ bool IsSubtractionSupported(const BackendId& backend,
FORWARD_LAYER_SUPPORT_FUNC(backend, IsSubtractionSupported, input0, input1, output);
}
+bool IsSwitchSupported(const BackendId& backend,
+ const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output0,
+ const TensorInfo& output1,
+ char* reasonIfUnsupported,
+ size_t reasonIfUnsupportedMaxLength)
+{
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsSwitchSupported, input0, input1, output0, output1);
+}
+
} // namespace armnn
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 0bd68e04af..31cfa66896 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -50,6 +50,7 @@
#include "layers/SplitterLayer.hpp"
#include "layers/StridedSliceLayer.hpp"
#include "layers/SubtractionLayer.hpp"
+#include "layers/SwitchLayer.hpp"
namespace armnn
{
@@ -122,5 +123,6 @@ DECLARE_LAYER(SpaceToBatchNd)
DECLARE_LAYER(Splitter)
DECLARE_LAYER(StridedSlice)
DECLARE_LAYER(Subtraction)
+DECLARE_LAYER(Switch)
}
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 73db2e88d7..c1462c090d 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -971,6 +971,11 @@ IConnectableLayer* Network::AddMergeLayer(const char* name)
return m_Graph->AddLayer<MergeLayer>(name);
}
+IConnectableLayer* Network::AddSwitchLayer(const char* name)
+{
+ return m_Graph->AddLayer<SwitchLayer>(name);
+}
+
void Network::Accept(ILayerVisitor& visitor) const
{
for (auto layer : GetGraph())
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index bb7b9eb6f4..660ca87d13 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -176,6 +176,8 @@ public:
IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
+
void Accept(ILayerVisitor& visitor) const override;
private:
diff --git a/src/armnn/layers/SwitchLayer.cpp b/src/armnn/layers/SwitchLayer.cpp
new file mode 100644
index 0000000000..eae6e0dfe2
--- /dev/null
+++ b/src/armnn/layers/SwitchLayer.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "SwitchLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+SwitchLayer::SwitchLayer(const char* name)
+ : Layer(2, 2, LayerType::Switch, name)
+{}
+
+std::unique_ptr<IWorkload> SwitchLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ SwitchQueueDescriptor descriptor;
+ return factory.CreateSwitch(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+SwitchLayer* SwitchLayer::Clone(Graph& graph) const
+{
+ return CloneBase<SwitchLayer>(graph, GetName());
+}
+
+void SwitchLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ BOOST_ASSERT_MSG(GetNumOutputSlots() == 2, "SwitchLayer: The layer should return 2 outputs.");
+
+ // Assuming first input is the Input and second input is the Constant
+ std::vector<TensorShape> inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "SwitchLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "SwitchLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
+ GetOutputSlot(1).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+void SwitchLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitSwitchLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/SwitchLayer.hpp b/src/armnn/layers/SwitchLayer.hpp
new file mode 100644
index 0000000000..bfda8c2b1b
--- /dev/null
+++ b/src/armnn/layers/SwitchLayer.hpp
@@ -0,0 +1,42 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Layer.hpp"
+
+namespace armnn
+{
+
+/// This layer calculates both true and false outputs for input.
+class SwitchLayer : public Layer
+{
+public:
+ /// Makes a workload for the Switch type.
+ /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ SwitchLayer* Clone(Graph& graph) const override;
+
+ /// Check if the input tensor shape(s)
+ /// will lead to a valid configuration of @ref SwitchLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a SwitchLayer.
+ /// @param [in] name Optional name for the layer.
+ SwitchLayer(const char* name);
+
+ /// Default destructor
+ ~SwitchLayer() = default;
+};
+
+} // namespace armnn
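
The Accept/VisitSwitchLayer entry points added above plug into the existing ILayerVisitor and LayerVisitorBase machinery. As a hedged illustration (this class is hypothetical and not part of the patch), a visitor that simply counts Switch layers could look like the sketch below; it follows the same pattern as the SwitchLayerVerifier used in the serializer tests later in this change.

#include <armnn/INetwork.hpp>
#include <armnn/LayerVisitorBase.hpp>

// Hypothetical example visitor: counts the Switch layers in a network.
class SwitchLayerCounter : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
{
public:
    void VisitSwitchLayer(const armnn::IConnectableLayer*, const char*) override
    {
        ++m_Count;
    }

    unsigned int GetCount() const { return m_Count; }

private:
    unsigned int m_Count = 0;
};

// Usage (illustrative): SwitchLayerCounter counter; network->Accept(counter);
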
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 09cdd7cad3..076072e888 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_SplitterLayer] = &Deserializer::ParseSplitter;
m_ParserFunctions[Layer_StridedSliceLayer] = &Deserializer::ParseStridedSlice;
m_ParserFunctions[Layer_SubtractionLayer] = &Deserializer::ParseSubtraction;
+ m_ParserFunctions[Layer_SwitchLayer] = &Deserializer::ParseSwitch;
}
Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
@@ -306,6 +307,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
case Layer::Layer_SubtractionLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base();
+ case Layer::Layer_SwitchLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base();
case Layer::Layer_NONE:
default:
throw ParseException(boost::str(
@@ -2108,4 +2111,27 @@ void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+void Deserializer::ParseSwitch(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+ auto inputs = GetInputs(graph, layerIndex);
+ CHECK_LOCATION();
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 2);
+
+ auto layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddSwitchLayer(layerName.c_str());
+
+ armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo);
+
+ armnn::TensorInfo output1TensorInfo = ToTensorInfo(outputs[1]);
+ layer->GetOutputSlot(1).SetTensorInfo(output1TensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
} // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index df983d9086..dfa5b06057 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -114,6 +114,7 @@ private:
void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
+ void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot);
void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 4e5610c569..770f7a88b0 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -41,5 +41,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Splitter
* StridedSlice
* Subtraction
+* Switch
More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 8b275b6f17..e8d72fc997 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -119,7 +119,8 @@ enum LayerType : uint {
Lstm = 34,
Quantize = 35,
Dequantize = 36,
- Merge = 37
+ Merge = 37,
+ Switch = 38
}
// Base layer table to be used as part of other layers
@@ -529,6 +530,10 @@ table MergeLayer {
base:LayerBase;
}
+table SwitchLayer {
+ base:LayerBase;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -567,7 +572,8 @@ union Layer {
LstmLayer,
QuantizeLayer,
DequantizeLayer,
- MergeLayer
+ MergeLayer,
+ SwitchLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index fe30c3eee5..74d0c435c6 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -893,6 +893,14 @@ void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* la
CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
}
+void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+ auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
+ auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
+
+ CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
+}
+
fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
const serializer::LayerType layerType)
{
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 775df83966..4a718378b5 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -191,6 +191,9 @@ public:
void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+
+ void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
+ const char* name = nullptr) override;
private:
/// Creates the Input Slots and Output Slots and LayerBase for the layer.
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index a8335e1e68..5b54bfd7be 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -41,5 +41,6 @@ The Arm NN SDK Serializer currently supports the following layers:
* Splitter
* StridedSlice
* Subtraction
+* Switch
More machine learning layers will be supported in future releases.
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index a1ef9eef59..2724ba4d35 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -2113,6 +2113,56 @@ BOOST_AUTO_TEST_CASE(SerializeSubtraction)
deserializedNetwork->Accept(verifier);
}
+BOOST_AUTO_TEST_CASE(SerializeSwitch)
+{
+ class SwitchLayerVerifier : public LayerVerifierBase
+ {
+ public:
+ SwitchLayerVerifier(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos)
+ : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+ void VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ {
+ VerifyNameAndConnections(layer, name);
+ }
+
+ void VisitConstantLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ConstTensor& input,
+ const char *name) override {}
+ };
+
+ const std::string layerName("switch");
+ const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);
+
+ std::vector<float> constantData = GenerateRandomData<float>(info.GetNumElements());
+ armnn::ConstTensor constTensor(info, constantData);
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const constantLayer = network->AddConstantLayer(constTensor, "constant");
+ armnn::IConnectableLayer* const switchLayer = network->AddSwitchLayer(layerName.c_str());
+ armnn::IConnectableLayer* const trueOutputLayer = network->AddOutputLayer(0);
+ armnn::IConnectableLayer* const falseOutputLayer = network->AddOutputLayer(1);
+
+ inputLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0));
+ constantLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1));
+ switchLayer->GetOutputSlot(0).Connect(trueOutputLayer->GetInputSlot(0));
+ switchLayer->GetOutputSlot(1).Connect(falseOutputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(info);
+ constantLayer->GetOutputSlot(0).SetTensorInfo(info);
+ switchLayer->GetOutputSlot(0).SetTensorInfo(info);
+ switchLayer->GetOutputSlot(1).SetTensorInfo(info);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ SwitchLayerVerifier verifier(layerName, {info, info}, {info, info});
+ deserializedNetwork->Accept(verifier);
+}
+
BOOST_AUTO_TEST_CASE(SerializeDeserializeNonLinearNetwork)
{
class ConstantLayerVerifier : public LayerVerifierBase
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index fc2d502fbd..6cad7b93ab 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -397,4 +397,13 @@ bool LayerSupportBase::IsSubtractionSupported(const TensorInfo& input0,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsSwitchSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output0,
+ const TensorInfo& output1,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 7c38b67379..3c39f8919d 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -246,6 +246,12 @@ public:
const TensorInfo& input1,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
+ bool IsSwitchSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output0,
+ const TensorInfo& output1,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
};
} // namespace armnn
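
The base implementation above keeps Switch unsupported on every backend until a backend opts in. A minimal sketch of that opt-in, assuming a hypothetical backend (MyBackendLayerSupport and its Float32-only policy are examples, not part of this patch):

#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <backendsCommon/LayerSupportBase.hpp>

// Hypothetical backend layer-support class: LayerSupportBase already answers
// "not supported" for everything, so only the queries the backend can serve
// need to be overridden.
class MyBackendLayerSupport : public armnn::LayerSupportBase
{
public:
    bool IsSwitchSupported(const armnn::TensorInfo& input0,
                           const armnn::TensorInfo& input1,
                           const armnn::TensorInfo& output0,
                           const armnn::TensorInfo& output1,
                           armnn::Optional<std::string&> reasonIfUnsupported
                               = armnn::EmptyOptional()) const override
    {
        // Example policy only: accept Float32 on both inputs and both outputs.
        // Populating reasonIfUnsupported is omitted for brevity.
        return input0.GetDataType() == armnn::DataType::Float32 &&
               input1.GetDataType() == armnn::DataType::Float32 &&
               output0.GetDataType() == armnn::DataType::Float32 &&
               output1.GetDataType() == armnn::DataType::Float32;
    }
};
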
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 348c864863..b850a65acf 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -75,45 +75,23 @@ void ValidateTensorShapesMatch(const TensorInfo& first,
}
//---------------------------------------------------------------
-void ValidateNoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
- if (workloadInfo.m_InputTensorInfos.size() != 0)
+ if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
{
throw InvalidArgumentException(descName +
- ": Requires no inputs. " +
- to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided.");
- }
-}
-
-//---------------------------------------------------------------
-void ValidateSingleInput(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
- if (workloadInfo.m_InputTensorInfos.size() != 1)
- {
- throw InvalidArgumentException(descName +
- ": Requires exactly one input. " +
- to_string(workloadInfo.m_InputTensorInfos.size()) + " has been provided." );
- }
-}
-
-//---------------------------------------------------------------
-void ValidateTwoInputs(const WorkloadInfo& workloadInfo, std::string const& descName)
-{
- if (workloadInfo.m_InputTensorInfos.size() != 2)
- {
- throw InvalidArgumentException(descName +
- ": Requires exactly two workloadInfo.m_InputTensorInfos. " +
+ ": Requires exactly " + to_string(expectedSize) + "input(s). " +
to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
}
}
//---------------------------------------------------------------
-void ValidateSingleOutput(const WorkloadInfo& workloadInfo, std::string const& descName)
+void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
{
- if (workloadInfo.m_OutputTensorInfos.size() != 1)
+ if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
{
throw InvalidArgumentException(descName +
- ": Requires exactly one output. " +
+ ": Requires exactly " + to_string(expectedSize) + " output(s). " +
to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
}
}
@@ -242,6 +220,18 @@ void ValidateTensorQuantizationMultiplier(const TensorInfo& inputTensor1, const
}
}
+//---------------------------------------------------------------
+void ValidateDataTypes(const TensorInfo& info,
+ const std::vector<armnn::DataType>& supportedTypes,
+ std::string const& descName)
+{
+ auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
+ if (iterator == supportedTypes.end())
+ {
+ throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
+ }
+}
+
} //namespace
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
@@ -254,8 +244,8 @@ void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
//---------------------------------------------------------------
void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "MemCopyQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MemCopyQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MemCopyQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MemCopyQueueDescriptor" , 1);
if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
{
@@ -299,8 +289,8 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ActivationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ActivationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ActivationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ActivationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"ActivationQueueDescriptor",
@@ -311,8 +301,8 @@ void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SoftmaxQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SoftmaxQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "SoftmaxQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
@@ -324,7 +314,7 @@ void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SplitterQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SplitterQueueDescriptor", 1);
if (workloadInfo.m_OutputTensorInfos.size() <= 0)
{
@@ -372,7 +362,7 @@ void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleOutput(workloadInfo, "MergerQueueDescriptor");
+ ValidateNumOutputs(workloadInfo, "MergerQueueDescriptor", 1);
if (m_Inputs.size() <= 0)
{
@@ -444,8 +434,8 @@ void MergerQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FullyConnectedQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FullyConnectedQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FullyConnectedQueueDescriptor", 2, "output");
if (!(workloadInfo.m_InputTensorInfos[0].GetNumDimensions() == 2 ||
@@ -487,8 +477,8 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
//---------------------------------------------------------------
void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "NormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "NormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "NormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "NormalizationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"NormalizationQueueDescriptor",
@@ -498,8 +488,8 @@ void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "AdditionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "AdditionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -513,8 +503,8 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
//---------------------------------------------------------------
void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MultiplicationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MultiplicationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -526,8 +516,8 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "BatchNormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "BatchNormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "BatchNormalizationQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"BatchNormalizationQueueDescriptor",
@@ -554,8 +544,8 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf
void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "Convolution2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "Convolution2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "Convolution2dQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Convolution2dQueueDescriptor", 4, "output");
@@ -580,8 +570,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DepthwiseConvolution2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DepthwiseConvolution2dQueueDescriptor", 1);
ValidateTensorNumDimensions(
workloadInfo.m_InputTensorInfos[0], "DepthwiseConvolution2dQueueDescriptor", 4, "input");
@@ -625,8 +615,8 @@ void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "PermuteQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "PermuteQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "PermuteQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "PermuteQueueDescriptor", 1);
const PermutationVector& mapping = m_Parameters.m_DimMappings;
@@ -650,8 +640,8 @@ void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "Pooling2dQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "Pooling2dQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "Pooling2dQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "Pooling2dQueueDescriptor", 4, "output");
@@ -659,8 +649,8 @@ void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ResizeBilinearQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ResizeBilinearQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ResizeBilinearQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "ResizeBilinearQueueDescriptor", 4, "output");
@@ -694,8 +684,8 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FakeQuantizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FakeQuantizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FakeQuantizationQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "FakeQuantizationQueueDescriptor", 2, "output");
@@ -713,8 +703,8 @@ void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "L2NormalizationQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "L2NormalizationQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "L2NormalizationQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "L2NormalizationQueueDescriptor", 4, "output");
@@ -727,8 +717,8 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo)
void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateNoInputs(workloadInfo, "ConstantQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConstantQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConstantQueueDescriptor", 0);
+ ValidateNumOutputs(workloadInfo, "ConstantQueueDescriptor", 1);
if (!m_LayerOutput)
{
@@ -744,8 +734,8 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ReshapeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ReshapeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ReshapeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ReshapeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements())
{
@@ -757,8 +747,8 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "SpaceToBatchNdQueueDescriptor", 1);
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output");
@@ -804,8 +794,8 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "FloorQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "FlootQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "FloorQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "FlootQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0] != workloadInfo.m_OutputTensorInfos[0])
{
@@ -821,8 +811,8 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConvertFp32ToFp16QueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ConvertFp32ToFp16QueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
{
@@ -843,8 +833,8 @@ void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo
void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
- ValidateSingleOutput(workloadInfo, "ConvertFp16ToFp32QueueDescriptor");
+ ValidateNumInputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "ConvertFp16ToFp32QueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float16)
{
@@ -864,8 +854,8 @@ void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo
void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "DivisionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DivisionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -877,8 +867,8 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "SubtractionQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "SubtractionQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -890,8 +880,8 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MaximumQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MaximumQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -903,8 +893,8 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MeanQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MeanQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -929,8 +919,8 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "PadQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "PadQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "PadQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "PadQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
@@ -948,8 +938,8 @@ void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "QuantizeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "QuantizeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "QuantizeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "QuantizeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::Float32)
@@ -966,14 +956,14 @@ void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "BatchToSpaceNdQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "BatchToSpaceNdQueueDescriptor", 1);
}
void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "StridedSliceQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "StridedSliceQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "StridedSliceQueueDescriptor", 1);
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const uint32_t rank = input.GetNumDimensions();
@@ -1015,8 +1005,8 @@ void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MinimumQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MinimumQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1028,14 +1018,14 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DebugQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DebugQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DebugQueueDescriptor", 1);
}
void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "EqualQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "EqualQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1052,8 +1042,8 @@ void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "GreaterQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "GreaterQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "GreaterQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "GreaterQueueDescriptor", 1);
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1070,8 +1060,8 @@ void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "RsqrtQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "RsqrtQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "RsqrtQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "RsqrtQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_OutputTensorInfos[0],
"RsqrtQueueDescriptor",
@@ -1081,8 +1071,8 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "GatherQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "GatherQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "GatherQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "GatherQueueDescriptor", 1);
const TensorInfo& indices = workloadInfo.m_InputTensorInfos[1];
@@ -1102,7 +1092,7 @@ void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "DetectionPostProcessQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2);
if (workloadInfo.m_OutputTensorInfos.size() != 4)
{
@@ -1155,8 +1145,8 @@ void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadI
void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateSingleInput(workloadInfo, "DequantizeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "DequantizeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "DequantizeQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "DequantizeQueueDescriptor", 1);
if (workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedAsymm8 &&
workloadInfo.m_InputTensorInfos[0].GetDataType() != DataType::QuantisedSymm16)
@@ -1172,8 +1162,8 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
- ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
- ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+ ValidateNumInputs(workloadInfo, "MergeQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "MergeQueueDescriptor", 1);
ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
@@ -1192,6 +1182,42 @@ void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
}
+void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "SwitchQueueDescriptor", 2);
+ ValidateNumOutputs(workloadInfo, "SwitchQueueDescriptor", 2);
+
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SwitchQueueDescriptor");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "SwitchQueueDescriptor",
+ "input0",
+ "output0");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[1],
+ "SwitchQueueDescriptor",
+ "input0",
+ "output1");
+}
+
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
// This is internally generated so it should not need validation.
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 1bf735288d..1b5f86dde7 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -426,4 +426,9 @@ struct MergeQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct SwitchQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 4ea3ea9f9b..d9774b063d 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -729,6 +729,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Switch:
+ {
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
+ const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
+ result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output0, dataType),
+ OverrideDataType(output1, dataType),
+ reason);
+ break;
+ }
case LayerType::Mean:
{
auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
@@ -1041,4 +1054,10 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const Subtraction
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
}
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 889bc9d595..5c07b3af6f 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -177,6 +177,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSwitch(const SwitchQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
};
} //namespace armnn
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 0588607a82..a7d7b094cf 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -402,6 +402,8 @@ DECLARE_LAYER_POLICY_2_PARAM(StridedSlice)
DECLARE_LAYER_POLICY_1_PARAM(Subtraction)
+DECLARE_LAYER_POLICY_1_PARAM(Switch)
+
// Generic implementation to get the number of input slots for a given layer type;
template<armnn::LayerType Type>