aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/LayerSupportBase.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/LayerSupportBase.cpp')
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp47
1 files changed, 46 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 220590e197..89a0772602 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -4,13 +4,13 @@
//
#include <armnn/Deprecated.hpp>
-#include <armnn/Descriptors.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/Types.hpp>
#include <backendsCommon/LayerSupportBase.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
+#include <armnn/utility/PolymorphicDowncast.hpp>
namespace
{
@@ -37,6 +37,51 @@ bool DefaultLayerSupport(const char* func,
namespace armnn
{
+bool LayerSupportBase::IsLayerSupported(const LayerType& type,
+ const std::vector<TensorInfo>& infos,
+ const BaseDescriptor& descriptor,
+ const Optional<LstmInputParamsInfo>&,
+ const Optional<QuantizedLstmInputParamsInfo>&,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ switch(type)
+ {
+ case LayerType::MemCopy:
+ return IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
+ case LayerType::MemImport:
+ return IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
+ case LayerType::StandIn:
+ {
+ auto desc = *(PolymorphicDowncast<const StandInDescriptor*>(&descriptor));
+
+ if (infos.size() != (desc.m_NumInputs + desc.m_NumOutputs))
+ {
+ throw InvalidArgumentException("Number of StandIn layer TensorInfos does not equal "
+ "the combined number of input and output slots assigned "
+ "to the StandIn descriptor");
+ }
+
+ std::vector<const TensorInfo*> inputInfos;
+ for (uint32_t i = 0; i < desc.m_NumInputs; i++)
+ {
+ inputInfos.push_back(&infos[i]);
+ }
+ std::vector<const TensorInfo*> outputInfos;
+ for (uint32_t i = desc.m_NumInputs; i < infos.size(); i++)
+ {
+ outputInfos.push_back(&infos[i]);
+ }
+
+ return IsStandInSupported(inputInfos,
+ outputInfos,
+ desc,
+ reasonIfUnsupported);
+ }
+ default:
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+ }
+}
+
bool LayerSupportBase::IsActivationSupported(const TensorInfo&, // input
const TensorInfo&, //output
const ActivationDescriptor&, // descriptor