about summary refs log tree commit diff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp  76
1 files changed, 76 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index a24a325b2d..cbaae4075c 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -631,6 +631,76 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
result = layerSupportObject->IsQuantizeSupported(input, output, reason);
break;
}
+ case LayerType::QuantizedLstm:
+ {
+ auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
+
+ // Inputs
+ const TensorInfo& input = OverrideDataType(
+ layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
+ const TensorInfo& previousCellStateIn = OverrideDataType(
+ layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
+ const TensorInfo& previousOutputIn = OverrideDataType(
+ layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
+
+ // Outputs
+ const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+ const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
+
+ // QuantizedLstm parameters
+ const TensorInfo& inputToInputWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToForgetWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToCellWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToOutputWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+
+ const TensorInfo& recurrentToInputWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToForgetWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToCellWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToOutputWeights = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
+
+ const TensorInfo& inputGateBias = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType);
+ const TensorInfo& forgetGateBias = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+ const TensorInfo& cellBias = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType);
+ const TensorInfo& outputGateBias = OverrideDataType(
+ cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+ QuantizedLstmInputParamsInfo paramsInfo;
+
+ paramsInfo.m_InputToInputWeights = &inputToInputWeights;
+ paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
+ paramsInfo.m_InputToCellWeights = &inputToCellWeights;
+ paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
+
+ paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ paramsInfo.m_InputGateBias = &inputGateBias;
+ paramsInfo.m_ForgetGateBias = &forgetGateBias;
+ paramsInfo.m_CellBias = &cellBias;
+ paramsInfo.m_OutputGateBias = &outputGateBias;
+
+ result = layerSupportObject->IsQuantizedLstmSupported(input,
+ previousCellStateIn,
+ previousOutputIn,
+ cellStateOut,
+ output,
+ paramsInfo,
+ reason);
+ break;
+ }
case LayerType::Division:
{
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
@@ -1109,6 +1179,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueD
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{