From f982deaefbe5fe5814487b27f7099829839b8666 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Fri, 11 Oct 2019 14:07:53 +0100 Subject: IVGCVSW-3973 Add frontend for LOG_SOFTMAX Signed-off-by: Aron Virginas-Tar Change-Id: Ic6acc7176deea3753b32ce6340f642d19dce0e9f --- src/backends/backendsCommon/LayerSupportBase.cpp | 8 ++++++++ src/backends/backendsCommon/LayerSupportBase.hpp | 5 +++++ src/backends/backendsCommon/WorkloadData.cpp | 24 +++++++++++++++++++--- src/backends/backendsCommon/WorkloadData.hpp | 5 +++++ src/backends/backendsCommon/WorkloadFactory.cpp | 19 +++++++++++++++++ src/backends/backendsCommon/WorkloadFactory.hpp | 3 +++ .../test/IsLayerSupportedTestImpl.hpp | 2 ++ 7 files changed, 63 insertions(+), 3 deletions(-) (limited to 'src/backends/backendsCommon') diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index c41f0b11ea..7d5555ce68 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -250,6 +250,14 @@ bool LayerSupportBase::IsL2NormalizationSupported(const TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 495870e645..cb660f5c2b 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -152,6 +152,11 @@ public: const L2NormalizationDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsLogSoftmaxSupported(const TensorInfo& input, + const TensorInfo& output, + const LogSoftmaxDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index ea0e5c82b8..b8d4f0dfff 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1294,8 +1294,6 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); - ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } @@ -1326,8 +1324,28 @@ void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); +} + +void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"LogSoftmaxQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 1); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + std::vector supportedTypes = + { + DataType::Float32, + DataType::Float16, + }; + + ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 1bf3aa7509..5a3600fc71 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -317,6 +317,11 @@ struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct ConstantQueueDescriptor : QueueDescriptor { ConstantQueueDescriptor() diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 98fe158fc5..f19b48491a 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -401,6 +401,19 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::LogSoftmax: + { + auto cLayer = boost::polymorphic_downcast(&layer); + + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType), + OverrideDataType(output, dataType), + cLayer->GetParameters(), + reason); + break; + } case LayerType::Lstm: { auto cLayer = boost::polymorphic_downcast(&layer); @@ -1167,6 +1180,12 @@ std::unique_ptr IWorkloadFactory::CreateL2Normalization(const L2Norma return std::unique_ptr(); } +std::unique_ptr IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr(); +} + std::unique_ptr IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index 9fa0221f31..fa7a9d46a8 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -127,6 +127,9 @@ public: virtual std::unique_ptr CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index c8604140ec..907285c5cf 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -439,6 +439,8 @@ DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization) DECLARE_LAYER_POLICY_2_PARAM(L2Normalization) +DECLARE_LAYER_POLICY_2_PARAM(LogSoftmax) + DECLARE_LAYER_POLICY_2_PARAM(Lstm) DECLARE_LAYER_POLICY_1_PARAM(Maximum) -- cgit v1.2.1