aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp112
-rw-r--r--src/backends/reference/RefLayerSupport.hpp4
-rw-r--r--src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp8
3 files changed, 41 insertions, 83 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 66661cb521..919c6db6ff 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -465,57 +465,15 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
"hiddenStateOutputVal, cellStateOutputVal, output}");
}
auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
-
- bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
- bool isCellStateOutput = (infos[5] == TensorInfo());
- if (isHiddenStateOutputOptional && isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isHiddenStateOutputOptional)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
+ return IsUnidirectionalSequenceLstmSupported(infos[0],
+ infos[1],
+ infos[2],
+ infos[3],
+ infos[4],
+ infos[5],
+ desc,
+ lstmParamsInfo.value(),
+ reasonIfUnsupported);
}
case LayerType::Pooling3d:
return IsPooling3dSupported(infos[0],
@@ -2841,9 +2799,9 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported) const
@@ -2852,17 +2810,14 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
IgnoreUnused(paramsInfo);
IgnoreUnused(outputStateIn);
IgnoreUnused(cellStateIn);
+ IgnoreUnused(outputStateOut);
+ IgnoreUnused(cellStateOut);
bool supported = true;
- if (hiddenStateOutput.has_value() || cellStateOutput.has_value())
+ std::array<DataType, 2> supportedTypes =
{
- reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output "
- "and cell state output are not supported at the moment.";
- }
-
- std::array<DataType, 1> supportedTypes =
- {
- DataType::Float32
+ DataType::Float32,
+ DataType::QAsymmS8
};
std::array<DataType, 2> supportedWeightTypes =
@@ -2871,16 +2826,19 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
DataType::QAsymmS8
};
+ std::array<DataType, 3> supportedBiasTypes =
+ {
+ DataType::Float32,
+ DataType::QAsymmS8,
+ DataType::Signed32
+ };
+
// check inputs and outputs
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: input is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched");
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: output is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and output types are mismatched");
// check layer parameters
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes),
reasonIfUnsupported,
@@ -2905,14 +2863,13 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights "
"is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types "
- "are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and OutputGateBias types "
- "are mismatched");
+
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: CellBias is not a supported type.");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type.");
if (!descriptor.m_CifgEnabled)
{
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes),
@@ -2923,9 +2880,8 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported(
reasonIfUnsupported,
"Reference UnidirectionalSequenceLstm: RecurrentToInputWeights "
"is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported,
- "Reference UnidirectionalSequenceLstm: input and InputGateBias types "
- "are mismatched");
+ supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported,
+ "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type.");
if (descriptor.m_PeepholeEnabled)
{
supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes),
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 98770ad64a..aa8bd8dda4 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -367,9 +367,9 @@ public:
const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
index d447a46b23..c4345d4978 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
@@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
TensorInfo inputInfo = GetTensorInfo(inputs[0]);
const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
- TensorInfo outputInfo = GetTensorInfo(outputs[0]);
+ TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
+ TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
+ TensorInfo outputInfo = GetTensorInfo(outputs[2]);
TensorShape& inputShape = inputInfo.GetShape();
TensorShape& outputShape= outputInfo.GetShape();
auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
@@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
- auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
@@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
{
// Permute Output back to batch major
const PermutationVector& mappings = {1U, 0U, 2U};
- auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
outputInfo.SetShape(outputShape);