aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp147
1 files changed, 147 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 1b05c4e0f4..2603371927 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1242,6 +1242,7 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
"Reference Lstm: input and outputStateOut types are mismatched");
supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
"Reference Lstm: input and cellStateOut types are mismatched");
+
supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference Lstm: input and output types are mismatched");
// check layer parameters
@@ -2288,4 +2289,150 @@ bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
+ const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& output,
+ const Optional<TensorInfo>& hiddenStateOutput,
+ const Optional<TensorInfo>& cellStateOutput,
+ const UnidirectionalSequenceLstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+ IgnoreUnused(paramsInfo);
+ IgnoreUnused(outputStateIn);
+ IgnoreUnused(cellStateIn);
+ bool supported = true;
+
+ if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
+ {
+ reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
+ "and cell state output are not supported at the moment.";
+ }
+
+ std::array<DataType, 1> supportedTypes =
+ {
+ DataType::Float32
+ };
+
+ std::array<DataType, 1> supportedWeightTypes =
+ {
+ DataType::Float32
+ };
+
+ // check inputs and outputs
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input is not a supported type.");
+ supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
+ // check layer parameters
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputToForgetWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputToOutputWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
+ "are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
+ "are mismatched");
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputToInputWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
+ "are mismatched");
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellToInputWeights "
+ "is not a supported type.");
+ }
+ }
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellToForgetWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellToOutputWeights "
+ "is not a supported type.");
+ }
+ if (descriptor.m_ProjectionEnabled)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: ProjectionWeights "
+ "is not a supported type.");
+ if (paramsInfo.m_ProjectionBias != nullptr)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: input and ProjectionBias types "
+ "are mismatched");
+ }
+ }
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputLayerNormWeights "
+ "is not a supported type.");
+ }
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellLayerNormWeights "
+ "is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes),
+ reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights "
+ "is not a supported type.");
+ }
+
+ return supported;
+}
+
} // namespace armnn