diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 5854bece3c..40ab798ba2 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -749,6 +749,94 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, result = layerSupportObject->IsQuantizeSupported(input, output, reason); break; } + case LayerType::QLstm: + { + auto cLayer = boost::polymorphic_downcast<const QLstmLayer*>(&layer); + const QLstmDescriptor& descriptor = cLayer->GetParameters(); + + // Inputs + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo(); + + // Outputs + const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo(); + + // Lstm parameters + LstmInputParamsInfo paramsInfo; + + // Basic parameters + paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(); + paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(); + paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(); + + paramsInfo.m_RecurrentToForgetWeights = + &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToCellWeights = + &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToOutputWeights = + &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(); + + paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(); + paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(); + paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(); + + if(!descriptor.m_CifgEnabled) + { + paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToInputWeights = + &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(); + paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(); + } + + if(descriptor.m_ProjectionEnabled) + { + paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(); + paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(); + } + + if(descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + paramsInfo.m_CellToInputWeights = + &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(); + } + + paramsInfo.m_CellToForgetWeights = + &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(); + paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(); + } + + if(descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + paramsInfo.m_InputLayerNormWeights = + &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(); + } + + paramsInfo.m_ForgetLayerNormWeights = + &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(); + paramsInfo.m_CellLayerNormWeights = + &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(); + paramsInfo.m_OutputLayerNormWeights = + &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(); + } + + result = layerSupportObject->IsQLstmSupported(input, + previousOutputIn, + previousCellStateIn, + outputStateOut, + cellStateOut, + output, + descriptor, + paramsInfo, + reason); + break; + } case LayerType::QuantizedLstm: { auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer); @@ -1387,6 +1475,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueD return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/, + const WorkloadInfo& /*info*/) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const { |