aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2020-03-20 08:49:33 +0000
committerJames Conroy <james.conroy@arm.com>2020-03-20 14:53:44 +0000
commit586a9aac99312eb9cb304cbbd18cec46b9158e23 (patch)
tree6d620eae6dcfb920ac04eae43424548dc602a1eb /src/backends/backendsCommon/WorkloadFactory.cpp
parentc94d3f7107b84b586791aa096f8641e6efa18c90 (diff)
downloadarmnn-586a9aac99312eb9cb304cbbd18cec46b9158e23.tar.gz
IVGCVSW-4549 Add front end for new QLSTM layer
* Added new layer QLstm (Android R HAL 1.3) * Made necessary updates to APIs * Added unit tests * This layer is functionally equivalent to the original unquantized LSTM layer with some additonal quantization features added. Due to this, original LstmParams are used for this layer. Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I5b7f2d2fb6e17e81573b41a31bc55f49ae79608f
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp94
1 files changed, 94 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 5854bece3c..40ab798ba2 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -749,6 +749,94 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
result = layerSupportObject->IsQuantizeSupported(input, output, reason);
break;
}
+ case LayerType::QLstm:
+ {
+ auto cLayer = boost::polymorphic_downcast<const QLstmLayer*>(&layer);
+ const QLstmDescriptor& descriptor = cLayer->GetParameters();
+
+ // Inputs
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
+
+ // Outputs
+ const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
+ const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
+
+ // Lstm parameters
+ LstmInputParamsInfo paramsInfo;
+
+ // Basic parameters
+ paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
+ paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
+ paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
+
+ paramsInfo.m_RecurrentToForgetWeights =
+ &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToCellWeights =
+ &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToOutputWeights =
+ &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
+
+ paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
+ paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
+ paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
+
+ if(!descriptor.m_CifgEnabled)
+ {
+ paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToInputWeights =
+ &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
+ paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
+ }
+
+ if(descriptor.m_ProjectionEnabled)
+ {
+ paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
+ paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
+ }
+
+ if(descriptor.m_PeepholeEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ paramsInfo.m_CellToInputWeights =
+ &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
+ }
+
+ paramsInfo.m_CellToForgetWeights =
+ &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
+ paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
+ }
+
+ if(descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ paramsInfo.m_InputLayerNormWeights =
+ &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
+ }
+
+ paramsInfo.m_ForgetLayerNormWeights =
+ &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
+ paramsInfo.m_CellLayerNormWeights =
+ &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
+ paramsInfo.m_OutputLayerNormWeights =
+ &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
+ }
+
+ result = layerSupportObject->IsQLstmSupported(input,
+ previousOutputIn,
+ previousCellStateIn,
+ outputStateOut,
+ cellStateOut,
+ output,
+ descriptor,
+ paramsInfo,
+ reason);
+ break;
+ }
case LayerType::QuantizedLstm:
{
auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
@@ -1387,6 +1475,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueD
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
+ const WorkloadInfo& /*info*/) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
const WorkloadInfo& /*info*/) const
{