diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 8 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 24 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 19 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp | 2 |
7 files changed, 63 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index c41f0b11ea..7d5555ce68 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -250,6 +250,14 @@ bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 495870e645..cb660f5c2b 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -152,6 +152,11 @@ public: const L2NormalizationDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ea0e5c82b8..b8d4f0dfff 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1294,8 +1294,6 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); - ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } @@ -1326,8 +1324,28 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); +} + +void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"LogSoftmaxQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16, + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 1bf3aa7509..5a3600fc71 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -317,6 +317,11 @@ struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters<L2Normaliz void Validate(const WorkloadInfo& workloadInfo) const; }; +struct LogSoftmaxQueueDescriptor : QueueDescriptorWithParameters<LogSoftmaxDescriptor> +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct ConstantQueueDescriptor : QueueDescriptor { ConstantQueueDescriptor() diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 98fe158fc5..f19b48491a 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -401,6 +401,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::LogSoftmax: + { + auto cLayer = boost::polymorphic_downcast<const LogSoftmaxLayer*>(&layer); + + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType), + OverrideDataType(output, dataType), + cLayer->GetParameters(), + reason); + break; + } case LayerType::Lstm: { auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer); @@ -1167,6 +1180,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2Norma return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 9fa0221f31..fa7a9d46a8 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -127,6 +127,9 @@ public: virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index c8604140ec..907285c5cf 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -439,6 +439,8 @@ DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization) DECLARE_LAYER_POLICY_2_PARAM(L2Normalization) +DECLARE_LAYER_POLICY_2_PARAM(LogSoftmax) + DECLARE_LAYER_POLICY_2_PARAM(Lstm) DECLARE_LAYER_POLICY_1_PARAM(Maximum) |