aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp80
1 files changed, 31 insertions, 49 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index dca5778e0e..1f616f0b18 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -639,61 +639,43 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
// Inputs
- const TensorInfo& input = OverrideDataType(
- layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
- const TensorInfo& previousCellStateIn = OverrideDataType(
- layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
- const TensorInfo& previousOutputIn = OverrideDataType(
- layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
// Outputs
- const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
- const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
+ const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
// QuantizedLstm parameters
- const TensorInfo& inputToInputWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
- const TensorInfo& inputToForgetWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
- const TensorInfo& inputToCellWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
- const TensorInfo& inputToOutputWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
-
- const TensorInfo& recurrentToInputWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
- const TensorInfo& recurrentToForgetWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
- const TensorInfo& recurrentToCellWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
- const TensorInfo& recurrentToOutputWeights = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
-
- const TensorInfo& inputGateBias = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), dataType);
- const TensorInfo& forgetGateBias = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
- const TensorInfo& cellBias = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), dataType);
- const TensorInfo& outputGateBias = OverrideDataType(
- cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), dataType);
-
QuantizedLstmInputParamsInfo paramsInfo;
- paramsInfo.m_InputToInputWeights = &inputToInputWeights;
- paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
- paramsInfo.m_InputToCellWeights = &inputToCellWeights;
- paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
-
- paramsInfo.m_RecurrentToInputWeights = &recurrentToInputWeights;
- paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
- paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
- paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
-
- paramsInfo.m_InputGateBias = &inputGateBias;
- paramsInfo.m_ForgetGateBias = &forgetGateBias;
- paramsInfo.m_CellBias = &cellBias;
- paramsInfo.m_OutputGateBias = &outputGateBias;
+ paramsInfo.m_InputToInputWeights =
+ &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
+ paramsInfo.m_InputToForgetWeights =
+ &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
+ paramsInfo.m_InputToCellWeights =
+ &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
+ paramsInfo.m_InputToOutputWeights =
+ &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
+
+ paramsInfo.m_RecurrentToInputWeights =
+ &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToForgetWeights =
+ &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToCellWeights =
+ &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
+ paramsInfo.m_RecurrentToOutputWeights =
+ &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
+
+ paramsInfo.m_InputGateBias =
+ &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
+ paramsInfo.m_ForgetGateBias =
+ &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
+ paramsInfo.m_CellBias =
+ &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
+ paramsInfo.m_OutputGateBias =
+ &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
result = layerSupportObject->IsQuantizedLstmSupported(input,
previousCellStateIn,