From f982deaefbe5fe5814487b27f7099829839b8666 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Fri, 11 Oct 2019 14:07:53 +0100 Subject: IVGCVSW-3973 Add frontend for LOG_SOFTMAX Signed-off-by: Aron Virginas-Tar Change-Id: Ic6acc7176deea3753b32ce6340f642d19dce0e9f --- Android.mk | 1 + CMakeLists.txt | 2 + include/armnn/Descriptors.hpp | 3 ++ include/armnn/DescriptorsFwd.hpp | 12 +++--- include/armnn/ILayerSupport.hpp | 5 +++ include/armnn/ILayerVisitor.hpp | 8 ++++ include/armnn/INetwork.hpp | 7 +++ include/armnn/LayerVisitorBase.hpp | 4 ++ src/armnn/InternalTypes.cpp | 1 + src/armnn/InternalTypes.hpp | 1 + src/armnn/LayersFwd.hpp | 2 + src/armnn/Network.cpp | 6 +++ src/armnn/Network.hpp | 3 ++ src/armnn/layers/LogSoftmaxLayer.cpp | 50 ++++++++++++++++++++++ src/armnn/layers/LogSoftmaxLayer.hpp | 44 +++++++++++++++++++ .../test/TestNameAndDescriptorLayerVisitor.cpp | 28 ++++++++++++ .../test/TestNameAndDescriptorLayerVisitor.hpp | 26 +++++++++++ src/armnnSerializer/Serializer.cpp | 13 +++++- src/armnnSerializer/Serializer.hpp | 4 ++ src/backends/backendsCommon/LayerSupportBase.cpp | 8 ++++ src/backends/backendsCommon/LayerSupportBase.hpp | 5 +++ src/backends/backendsCommon/WorkloadData.cpp | 24 +++++++++-- src/backends/backendsCommon/WorkloadData.hpp | 5 +++ src/backends/backendsCommon/WorkloadFactory.cpp | 19 ++++++++ src/backends/backendsCommon/WorkloadFactory.hpp | 3 ++ .../test/IsLayerSupportedTestImpl.hpp | 2 + 26 files changed, 276 insertions(+), 10 deletions(-) create mode 100644 src/armnn/layers/LogSoftmaxLayer.cpp create mode 100644 src/armnn/layers/LogSoftmaxLayer.hpp diff --git a/Android.mk b/Android.mk index 108e01107a..c27b707f12 100644 --- a/Android.mk +++ b/Android.mk @@ -140,6 +140,7 @@ LOCAL_SRC_FILES := \ src/armnn/layers/InputLayer.cpp \ src/armnn/layers/InstanceNormalizationLayer.cpp \ src/armnn/layers/L2NormalizationLayer.cpp \ + src/armnn/layers/LogSoftmaxLayer.cpp \ src/armnn/layers/LstmLayer.cpp \ src/armnn/layers/MaximumLayer.cpp \ src/armnn/layers/MeanLayer.cpp \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 0430643494..e69d29c5aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -284,6 +284,8 @@ list(APPEND armnn_sources src/armnn/layers/InstanceNormalizationLayer.cpp src/armnn/layers/L2NormalizationLayer.hpp src/armnn/layers/L2NormalizationLayer.cpp + src/armnn/layers/LogSoftmaxLayer.hpp + src/armnn/layers/LogSoftmaxLayer.cpp src/armnn/layers/LstmLayer.cpp src/armnn/layers/LstmLayer.hpp src/armnn/layers/MaximumLayer.cpp diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 5bf4043afa..e2e59741a3 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -74,6 +74,9 @@ struct SoftmaxDescriptor int m_Axis; }; +/// A LogSoftmaxDescriptor for the LogSoftmaxLayer +using LogSoftmaxDescriptor = SoftmaxDescriptor; + /// @brief An OriginsDescriptor for the ConcatLayer. /// Descriptor to configure the concatenation process. Number of views must be equal to the number of inputs, and /// their order must match - e.g. first view corresponds to the first input, second view to the second input, etc. diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 2cc95828e6..6f1c0e0a6e 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -7,6 +7,7 @@ namespace armnn { + struct ActivationDescriptor; struct ArgMinMaxDescriptor; struct BatchNormalizationDescriptor; @@ -38,10 +39,11 @@ struct StridedSliceDescriptor; struct TransposeConvolution2dDescriptor; struct ViewsDescriptor; +using ConcatDescriptor = OriginsDescriptor; using DepthToSpaceDescriptor = SpaceToDepthDescriptor; +using LogSoftmaxDescriptor = SoftmaxDescriptor; +// MergerDescriptor is deprecated, use ConcatDescriptor instead +using MergerDescriptor = OriginsDescriptor; +using SplitterDescriptor = ViewsDescriptor; -// MergerDescriptor is deprecated use ConcatDescriptor instead -using MergerDescriptor = OriginsDescriptor; -using ConcatDescriptor = OriginsDescriptor; -using SplitterDescriptor = ViewsDescriptor; -} +} // namespace armnn diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index fef7595b54..31b5e134e9 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -168,6 +168,11 @@ public: const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index b9c96d5448..e99e10f800 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -224,6 +224,14 @@ public: const L2NormalizationDescriptor& desc, const char* name = nullptr) = 0; + /// Function that a log softmax layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param logSoftmaxDescriptor - LogSoftmaxDescriptor to configure the log softmax. + /// @param name - Optional name for the layer. + virtual void VisitLogSoftmaxLayer(const IConnectableLayer* layer, + const LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name = nullptr) = 0; + /// Function an Lstm layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param descriptor - Parameters controlling the operation of the Lstm operation. diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index dc831db864..d12f5c239c 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -344,6 +344,13 @@ public: virtual IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, const char* name = nullptr) = 0; + /// Adds a log softmax layer to the network. + /// @param logSoftmaxDescriptor - LogSoftmaxDescriptor to configure the log softmax. + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name = nullptr) = 0; + /// Adds a layer with no inputs and a single output, which always corresponds to /// the passed in constant tensor. /// @param input - Tensor to be provided as the only output of the layer. The layer will maintain diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 719e59d39c..912f25500c 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -120,6 +120,10 @@ public: const L2NormalizationDescriptor&, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitLogSoftmaxLayer(const IConnectableLayer*, + const LogSoftmaxDescriptor&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitLstmLayer(const IConnectableLayer*, const LstmDescriptor&, const LstmInputParams&, diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp index 612d00be5f..7c39128bec 100644 --- a/src/armnn/InternalTypes.cpp +++ b/src/armnn/InternalTypes.cpp @@ -40,6 +40,7 @@ char const* GetLayerTypeAsCString(LayerType type) case LayerType::Input: return "Input"; case LayerType::InstanceNormalization: return "InstanceNormalization"; case LayerType::L2Normalization: return "L2Normalization"; + case LayerType::LogSoftmax: return "LogSoftmax"; case LayerType::Lstm: return "Lstm"; case LayerType::Maximum: return "Maximum"; case LayerType::Mean: return "Mean"; diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp index 039d0f8ac8..895fe3235d 100644 --- a/src/armnn/InternalTypes.hpp +++ b/src/armnn/InternalTypes.hpp @@ -40,6 +40,7 @@ enum class LayerType Input, InstanceNormalization, L2Normalization, + LogSoftmax, Lstm, Maximum, Mean, diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp index 1f539f3076..7bb9c64818 100644 --- a/src/armnn/LayersFwd.hpp +++ b/src/armnn/LayersFwd.hpp @@ -32,6 +32,7 @@ #include "layers/InputLayer.hpp" #include "layers/InstanceNormalizationLayer.hpp" #include "layers/L2NormalizationLayer.hpp" +#include "layers/LogSoftmaxLayer.hpp" #include "layers/LstmLayer.hpp" #include "layers/MaximumLayer.hpp" #include "layers/MeanLayer.hpp" @@ -116,6 +117,7 @@ DECLARE_LAYER(Greater) DECLARE_LAYER(Input) DECLARE_LAYER(InstanceNormalization) DECLARE_LAYER(L2Normalization) +DECLARE_LAYER(LogSoftmax) DECLARE_LAYER(Lstm) DECLARE_LAYER(Maximum) DECLARE_LAYER(Mean) diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 9d10b9ace1..b2fc1a6389 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1236,6 +1236,12 @@ IConnectableLayer* Network::AddL2NormalizationLayer(const L2NormalizationDescrip return m_Graph->AddLayer(desc, name); } +IConnectableLayer* Network::AddLogSoftmaxLayer(const LogSoftmaxDescriptor& desc, + const char* name) +{ + return m_Graph->AddLayer(desc, name); +} + IConnectableLayer* Network::AddConstantLayer(const ConstTensor& input, const char* name) { auto layer = m_Graph->AddLayer(name); diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index e11f3d2185..ad1e7c456e 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -158,6 +158,9 @@ public: IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, const char* name = nullptr) override; + IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name = nullptr) override; + IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override; IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor, diff --git a/src/armnn/layers/LogSoftmaxLayer.cpp b/src/armnn/layers/LogSoftmaxLayer.cpp new file mode 100644 index 0000000000..6ca15b2d6f --- /dev/null +++ b/src/armnn/layers/LogSoftmaxLayer.cpp @@ -0,0 +1,50 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "LogSoftmaxLayer.hpp" + +#include "LayerCloneBase.hpp" + +#include + +#include +#include + +namespace armnn +{ + +LogSoftmaxLayer::LogSoftmaxLayer(const LogSoftmaxDescriptor ¶m, const char* name) + : LayerWithParameters(1, 1, LayerType::LogSoftmax, param, name) {} + +std::unique_ptr LogSoftmaxLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const +{ + LogSoftmaxQueueDescriptor descriptor; + return factory.CreateLogSoftmax(descriptor, PrepInfoAndDesc(descriptor, graph)); +} + +LogSoftmaxLayer* LogSoftmaxLayer::Clone(Graph& graph) const +{ + return CloneBase(graph, m_Param, GetName()); +} + +void LogSoftmaxLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(1, CHECK_LOCATION()); + + auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + BOOST_ASSERT(inferredShapes.size() == 1); + + ConditionalThrowIfNotEqual( + "LogSoftmaxLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShapes[0]); +} + +void LogSoftmaxLayer::Accept(ILayerVisitor& visitor) const +{ + visitor.VisitLogSoftmaxLayer(this, GetParameters(), GetName()); +} + +} // namespace armnn diff --git a/src/armnn/layers/LogSoftmaxLayer.hpp b/src/armnn/layers/LogSoftmaxLayer.hpp new file mode 100644 index 0000000000..13da542139 --- /dev/null +++ b/src/armnn/layers/LogSoftmaxLayer.hpp @@ -0,0 +1,44 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "LayerWithParameters.hpp" + +namespace armnn +{ + +/// This layer represents a log softmax operation. +class LogSoftmaxLayer : public LayerWithParameters +{ +public: + /// Makes a workload for the LogSoftmax type. + /// @param [in] graph The graph where this layer can be found. + /// @param [in] factory The workload factory which will create the workload. + /// @return A pointer to the created workload, or nullptr if not created. + virtual std::unique_ptr CreateWorkload(const Graph& graph, + const IWorkloadFactory& factory) const override; + + /// Creates a dynamically-allocated copy of this layer. + /// @param [in] graph The graph into which this layer is being cloned. + LogSoftmaxLayer* Clone(Graph& graph) const override; + + /// Check if the input tensor shape(s) + /// will lead to a valid configuration of @ref LogSoftmaxLayer. + void ValidateTensorShapesFromInputs() override; + + void Accept(ILayerVisitor& visitor) const override; + +protected: + /// Constructor to create a LogSoftmaxLayer. + /// @param [in] param LogSoftmaxDescriptor to configure the log softmax operation. + /// @param [in] name Optional name for the layer. + LogSoftmaxLayer(const LogSoftmaxDescriptor& param, const char* name); + + /// Default destructor + ~LogSoftmaxLayer() = default; +}; + +} // namespace diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp index dcc5dc4cfb..e2bfb01733 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp @@ -328,6 +328,34 @@ BOOST_AUTO_TEST_CASE(CheckL2NormalizationLayerVisitorNameNullAndDescriptor) layer->Accept(visitor); } +BOOST_AUTO_TEST_CASE(CheckLogSoftmaxLayerVisitorNameAndDescriptor) +{ + const char* layerName = "LogSoftmaxLayer"; + + LogSoftmaxDescriptor descriptor; + descriptor.m_Beta = 2.0f; + descriptor.m_Axis = 1; + + TestLogSoftmaxLayerVisitor visitor(descriptor, layerName); + Network net; + + IConnectableLayer *const layer = net.AddLogSoftmaxLayer(descriptor, layerName); + layer->Accept(visitor); +} + +BOOST_AUTO_TEST_CASE(CheckLogSoftmaxLayerVisitorNameNullAndDescriptor) +{ + LogSoftmaxDescriptor descriptor; + descriptor.m_Beta = 2.0f; + descriptor.m_Axis = 1; + + TestLogSoftmaxLayerVisitor visitor(descriptor); + Network net; + + IConnectableLayer *const layer = net.AddLogSoftmaxLayer(descriptor); + layer->Accept(visitor); +} + BOOST_AUTO_TEST_CASE(CheckReshapeLayerVisitorNameAndDescriptor) { const char* layerName = "ReshapeLayer"; diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp index aa0b3597fa..e46aa34e29 100644 --- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp +++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp @@ -479,6 +479,32 @@ public: }; }; +class TestLogSoftmaxLayerVisitor : public TestLayerVisitor +{ +private: + LogSoftmaxDescriptor m_VisitorDescriptor; + +public: + explicit TestLogSoftmaxLayerVisitor(const LogSoftmaxDescriptor& descriptor, const char* name = nullptr) + : TestLayerVisitor(name) + , m_VisitorDescriptor(descriptor) {} + + void CheckDescriptor(const LogSoftmaxDescriptor& descriptor) + { + BOOST_CHECK_EQUAL(descriptor.m_Beta, m_VisitorDescriptor.m_Beta); + BOOST_CHECK_EQUAL(descriptor.m_Axis, m_VisitorDescriptor.m_Axis); + } + + void VisitLogSoftmaxLayer(const IConnectableLayer* layer, + const LogSoftmaxDescriptor& descriptor, + const char* name = nullptr) override + { + CheckLayerPointer(layer); + CheckDescriptor(descriptor); + CheckLayerName(name); + }; +}; + class TestReshapeLayerVisitor : public TestLayerVisitor { private: diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 5949d1d9fa..0e8c894e46 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -463,8 +463,17 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer); } -void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor, - const armnn::LstmInputParams& params, const char* name) +void SerializerVisitor::VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer, + const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name) +{ + throw armnn::UnimplementedException("SerializerVisitor::VisitLogSoftmaxLayer() is not implemented"); +} + +void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, + const armnn::LstmDescriptor& descriptor, + const armnn::LstmInputParams& params, + const char* name) { auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm); diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index f98bd17895..8c13245aeb 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -133,6 +133,10 @@ public: const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor, const char* name = nullptr) override; + void VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer, + const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name = nullptr) override; + void VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor, const armnn::LstmInputParams& params, diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index c41f0b11ea..7d5555ce68 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -250,6 +250,14 @@ bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 495870e645..cb660f5c2b 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -152,6 +152,11 @@ public: const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ea0e5c82b8..b8d4f0dfff 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1294,8 +1294,6 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); - ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } @@ -1326,8 +1324,28 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); +} + +void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"LogSoftmaxQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + std::vector supportedTypes = + { + DataType::Float32, + DataType::Float16, + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 1bf3aa7509..5a3600fc71 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -317,6 +317,11 @@ struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct ConstantQueueDescriptor : QueueDescriptor { ConstantQueueDescriptor() diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 98fe158fc5..f19b48491a 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -401,6 +401,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::LogSoftmax: + { + auto cLayer = boost::polymorphic_downcast(&layer); + + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType), + OverrideDataType(output, dataType), + cLayer->GetParameters(), + reason); + break; + } case LayerType::Lstm: { auto cLayer = boost::polymorphic_downcast(&layer); @@ -1167,6 +1180,12 @@ std::unique_ptr IWorkloadFactory::CreateL2Normalization(const L2Norma return std::unique_ptr(); } +std::unique_ptr IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr(); +} + std::unique_ptr IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 9fa0221f31..fa7a9d46a8 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -127,6 +127,9 @@ public: virtual std::unique_ptr CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index c8604140ec..907285c5cf 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -439,6 +439,8 @@ DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization) DECLARE_LAYER_POLICY_2_PARAM(L2Normalization) +DECLARE_LAYER_POLICY_2_PARAM(LogSoftmax) + DECLARE_LAYER_POLICY_2_PARAM(Lstm) DECLARE_LAYER_POLICY_1_PARAM(Maximum) -- cgit v1.2.1