diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 76 |
1 file changed, 76 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index a24a325b2d..cbaae4075c 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -631,6 +631,76 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, result = layerSupportObject->IsQuantizeSupported(input, output, reason); break; } + case LayerType::QuantizedLstm: + { + auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer); + + // Inputs + const TensorInfo& input = OverrideDataType( + layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& previousCellStateIn = OverrideDataType( + layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& previousOutputIn = OverrideDataType( + layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType); + + // Outputs + const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); + const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType); + + // QuantizedLstm parameters + const TensorInfo& inputToInputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToForgetWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToCellWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToOutputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType); + + const TensorInfo& recurrentToInputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToForgetWeights = OverrideDataType( + 
cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToCellWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToOutputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType); + + const TensorInfo& inputGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType); + const TensorInfo& forgetGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType); + const TensorInfo& cellBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType); + const TensorInfo& outputGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType); + + QuantizedLstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToInputWeights = &inputToInputWeights; + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + + paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + paramsInfo.m_InputGateBias = &inputGateBias; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + result = layerSupportObject->IsQuantizedLstmSupported(input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + output, + paramsInfo, + reason); + break; + } case LayerType::Division: { const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); @@ -1109,6 +1179,12 @@ 
std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueD return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { |