aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp112
1 files changed, 34 insertions, 78 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 66661cb521..919c6db6ff 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -465,57 +465,15 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
"hiddenStateOutputVal, cellStateOutputVal, output}");
}
auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
-
- bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
- bool isCellStateOutput = (infos[5] == TensorInfo());
- if (isHiddenStateOutputOptional && isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isHiddenStateOutputOptional)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
+ return IsUnidirectionalSequenceLstmSupported(infos[0],
+ infos[1],
+ infos[2],
+ infos[3],
+ infos[4],
+ infos[5],
+ desc,
+ lstmParamsInfo.value(),
+ reasonIfUnsupported);
}
case LayerType::Pooling3d:
return IsPooling3dSupported(infos[0],
@@ -2841,9 +2799,9 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported) const
@@ -2852,17 +2810,14 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
IgnoreUnused(paramsInfo);
IgnoreUnused(outputStateIn);
IgnoreUnused(cellStateIn);
+ IgnoreUnused(outputStateOut);
+ IgnoreUnused(cellStateOut);
bool supported = true;
- if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
+ std::array<DataType, 2> supportedTypes =
{
- reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
- "and cell state output are not supported at the moment.";
- }
-
- std::array<DataType, 1> supportedTypes =
- {
- DataType::Float32
+ DataType::Float32,
+ DataType::QAsymmS8
};
std::array<DataType, 2> supportedWeightTypes =
@@ -2871,16 +2826,19 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
DataType::QAsymmS8
};
+ std::array<DataType, 3> supportedBiasTypes =
+ {
+ DataType::Float32,
+ DataType::QAsymmS8,
+ DataType::Signed32
+ };
+
// check inputs and outputs
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: input is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: output is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
// check layer parameters
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
reasonIfUnsupported,
@@ -2905,14 +2863,13 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
"is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
- "are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
- "are mismatched");
+
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
if (!descriptor.m_CifgEnabled)
{
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
@@ -2923,9 +2880,8 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
"is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
- "are mismatched");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
if (descriptor.m_PeepholeEnabled)
{
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),