aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp33
1 files changed, 32 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index b74b6afeb3..8ef5985fb3 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -396,6 +396,10 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
const TensorInfo* projectionBias = nullptr;
const TensorInfo* cellToForgetWeights = nullptr;
const TensorInfo* cellToOutputWeights = nullptr;
+ const TensorInfo* inputLayerNormWeights = nullptr;
+ const TensorInfo* forgetLayerNormWeights = nullptr;
+ const TensorInfo* cellLayerNormWeights = nullptr;
+ const TensorInfo* outputLayerNormWeights = nullptr;
TensorInfo optInputToInputWeights;
TensorInfo optRecurrentToInputWeights;
@@ -405,6 +409,10 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
TensorInfo optProjectionBias;
TensorInfo optCellToForgetWeights;
TensorInfo optCellToOutputWeights;
+ TensorInfo optInputLayerNormWeights;
+ TensorInfo optForgetLayerNormWeights;
+ TensorInfo optCellLayerNormWeights;
+ TensorInfo optOutputLayerNormWeights;
if(!descriptor.m_CifgEnabled)
{
@@ -449,6 +457,25 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
cellToOutputWeights = &optCellToOutputWeights;
}
+ if(descriptor.m_LayerNormEnabled)
+ {
+ optInputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+ inputLayerNormWeights = &optInputLayerNormWeights;
+
+ optForgetLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
+ forgetLayerNormWeights = &optForgetLayerNormWeights;
+
+ optCellLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
+ cellLayerNormWeights = &optCellLayerNormWeights;
+
+ optOutputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
+ outputLayerNormWeights = &optOutputLayerNormWeights;
+ }
+
result = layerSupportObject->IsLstmSupported(
input,
outputStateIn,
@@ -475,7 +502,11 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
projectionBias,
cellToForgetWeights,
cellToOutputWeights,
- reason);
+ reason,
+ inputLayerNormWeights,
+ forgetLayerNormWeights,
+ cellLayerNormWeights,
+ outputLayerNormWeights);
break;
}
case LayerType::Maximum: