diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 80 |
1 files changed, 31 insertions, 49 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index dca5778e0e..1f616f0b18 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -639,61 +639,43 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer); // Inputs - const TensorInfo& input = OverrideDataType( - layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType); - const TensorInfo& previousCellStateIn = OverrideDataType( - layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType); - const TensorInfo& previousOutputIn = OverrideDataType( - layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo(); // Outputs - const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); - const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType); + const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo(); // QuantizedLstm parameters - const TensorInfo& inputToInputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToForgetWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToCellWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType); - const TensorInfo& inputToOutputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType); - - const TensorInfo& recurrentToInputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToForgetWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToCellWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType); - const TensorInfo& recurrentToOutputWeights = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType); - - const TensorInfo& inputGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType); - const TensorInfo& forgetGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType); - const TensorInfo& cellBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType); - const TensorInfo& outputGateBias = OverrideDataType( - cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType); - QuantizedLstmInputParamsInfo paramsInfo; - paramsInfo.m_InputToInputWeights = &inputToInputWeights; - paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; - paramsInfo.m_InputToCellWeights = &inputToCellWeights; - paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; - - paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights; - paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; - paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; - paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; - - paramsInfo.m_InputGateBias = &inputGateBias; - paramsInfo.m_ForgetGateBias = &forgetGateBias; - paramsInfo.m_CellBias = &cellBias; - paramsInfo.m_OutputGateBias = &outputGateBias; + paramsInfo.m_InputToInputWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(); + paramsInfo.m_InputToForgetWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(); + paramsInfo.m_InputToCellWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(); + paramsInfo.m_InputToOutputWeights = + &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(); + + paramsInfo.m_RecurrentToInputWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToForgetWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToCellWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(); + paramsInfo.m_RecurrentToOutputWeights = + &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(); + + paramsInfo.m_InputGateBias = + &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(); + paramsInfo.m_ForgetGateBias = + &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(); + paramsInfo.m_CellBias = + &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(); + paramsInfo.m_OutputGateBias = + &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();; result = layerSupportObject->IsQuantizedLstmSupported(input, previousCellStateIn, |