diff options
author | James Conroy <james.conroy@arm.com> | 2019-07-17 11:27:46 +0100 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-07-24 10:40:13 +0100 |
commit | ee18dc8d1725f472850ab0c398fd7cbc4b850891 (patch) | |
tree | b57738b18781d512f5438ca5154652571393e4e8 /src/backends/backendsCommon/WorkloadFactory.cpp | |
parent | 7b1845206d723a91aec811edaf7cb0cf832dfd25 (diff) | |
download | armnn-ee18dc8d1725f472850ab0c398fd7cbc4b850891.tar.gz |
IVGCVSW-3469 Add front end for Quantized LSTM layer
* Added new layer QuantizedLstm (Android Q)
* Made necessary changes to APIs
* Added unit tests
Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a
Signed-off-by: James Conroy <james.conroy@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 76 |
1 file changed, 76 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index a24a325b2d..cbaae4075c 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -631,6 +631,76 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, result = layerSupportObject->IsQuantizeSupported(input, output, reason); break; } + case LayerType::QuantizedLstm: + { + auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer); + + // Inputs + const TensorInfo& input = OverrideDataType( + layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& previousCellStateIn = OverrideDataType( + layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType); + const TensorInfo& previousOutputIn = OverrideDataType( + layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType); + + // Outputs + const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); + const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType); + + // QuantizedLstm parameters + const TensorInfo& inputToInputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToForgetWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToCellWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType); + const TensorInfo& inputToOutputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType); + + const TensorInfo& recurrentToInputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToForgetWeights = OverrideDataType( + 
cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToCellWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType); + const TensorInfo& recurrentToOutputWeights = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType); + + const TensorInfo& inputGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType); + const TensorInfo& forgetGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType); + const TensorInfo& cellBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType); + const TensorInfo& outputGateBias = OverrideDataType( + cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType); + + QuantizedLstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToInputWeights = &inputToInputWeights; + paramsInfo.m_InputToForgetWeights = &inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + + paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + + paramsInfo.m_InputGateBias = &inputGateBias; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + result = layerSupportObject->IsQuantizedLstmSupported(input, + previousCellStateIn, + previousOutputIn, + cellStateOut, + output, + paramsInfo, + reason); + break; + } case LayerType::Division: { const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); @@ -1109,6 +1179,12 @@ 
std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueD return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { |