Diffstat (limited to 'src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp')
-rw-r--r--  src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp  40
1 file changed, 18 insertions(+), 22 deletions(-)
diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
index c911afb237..8dba719d91 100644
--- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -39,7 +39,7 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
GetGuid());
const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
- arm_compute::ITensor& output = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ arm_compute::ITensor& output = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
TensorInfo inputInfo = info.m_InputTensorInfos[0];
TensorInfo outputInfo = info.m_OutputTensorInfos[0];
@@ -49,7 +49,7 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
TensorShape inputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetShape();
TensorShape cellStateLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetShape();
- TensorShape outputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetShape();
+ TensorShape outputLayerShape = static_cast<IAclTensorHandle*>(m_Data.m_Outputs[2])->GetShape();
unsigned int maxTime = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
unsigned int batchSize = m_Data.m_Parameters.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
@@ -288,7 +288,7 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
// LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
if (maxTime == 1 && m_Data.m_Parameters.m_TimeMajor)
{
- TensorShape inputShape = GetTensorShape((&input)->info()->tensor_shape(), 1U);
+ TensorShape inputShape = GetTensorShape(input.info()->tensor_shape(), 1U);
TensorShape outputShape = GetTensorShape((&output)->info()->tensor_shape(), 1U);
TensorShape inputShapeShrink({inputShape[1], inputShape[2]});
@@ -297,10 +297,10 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
auto acl_input_shape_shrink = BuildArmComputeTensorShape(inputShapeShrink);
auto acl_output_shape_shrink = BuildArmComputeTensorShape(outputShapeShrink);
- (&input)->info()->set_tensor_shape(acl_input_shape_shrink);
+ input.info()->set_tensor_shape(acl_input_shape_shrink);
inputLSTM = const_cast<arm_compute::ITensor*>(&input);
- (&output)->info()->set_tensor_shape(acl_output_shape_shrink);
+ output.info()->set_tensor_shape(acl_output_shape_shrink);
outputLSTM = &output;
}
// If there is only one LSTM batch major batch, we will not concat, only permute.
@@ -432,9 +432,9 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
unsigned int aclAxisConcat = CalcAclAxis(concatDescriptor.GetNumDimensions(), concatDescriptor.GetConcatAxis());
if (!m_Data.m_Parameters.m_TimeMajor)
{
- TensorInfo concatOuputTensorInfo = outputInfo;
- concatOuputTensorInfo.SetShape(timeMajorShapeOutput);
- BuildArmComputeTensor(concat_out, concatOuputTensorInfo);
+ TensorInfo concatOutputTensorInfo = outputInfo;
+ concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
+ BuildArmComputeTensor(concat_out, concatOutputTensorInfo);
armcomputetensorutils::InitialiseArmComputeTensorEmpty(concat_out);
m_Concat->configure(m_ConcatInputs, &concat_out, aclAxisConcat);
@@ -452,11 +452,11 @@ NeonUnidirectionalSequenceLstmFloatWorkload::NeonUnidirectionalSequenceLstmFloat
{
if (!m_Data.m_Parameters.m_TimeMajor)
{
- (&output)->info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandBatchMajor));
+ output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandBatchMajor));
}
else
{
- (&output)->info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
+ output.info()->set_tensor_shape(BuildArmComputeTensorShape(shapeExpandTimeMajor));
}
}
@@ -510,14 +510,12 @@ arm_compute::Status
NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo)
{
- IgnoreUnused(hiddenStateOutput, cellStateOutput);
-
TensorShape inputLayerShape = input.GetShape();
TensorShape outputLayerShape = outputStateIn.GetShape();
@@ -612,8 +610,6 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
- const TensorInfo& outputStateOut = TensorInfo(outputStateIn.GetShape(), input.GetDataType());
- const TensorInfo& cellStateOut = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
// The inputs and outputs
const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
@@ -704,7 +700,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ? nullptr :
- &aclInputLayerNormWeightsInfo,
+ &aclInputLayerNormWeightsInfo,
&aclForgetLayerNormWeightsInfo,
&aclCellLayerNormWeightsInfo,
&aclOutputLayerNormWeightsInfo);
@@ -803,9 +799,9 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
TensorShape shapeExpandTimeMajor({1, shape[0], shape[1]});
TensorShape shapeExpandBatchMajor({shape[0], 1, shape[1]});
- TensorInfo concatOuputTensorInfo = TensorInfo(output);
- concatOuputTensorInfo.SetShape(timeMajorShapeOutput);
- arm_compute::TensorInfo aclConcatOuputTensorInfo= BuildArmComputeTensorInfo(concatOuputTensorInfo);
+ TensorInfo concatOutputTensorInfo = TensorInfo(output);
+ concatOutputTensorInfo.SetShape(timeMajorShapeOutput);
+ arm_compute::TensorInfo aclConcatOutputTensorInfo= BuildArmComputeTensorInfo(concatOutputTensorInfo);
if (maxTime != 1) // ACL concat does not work with only one element to concatenate.
{
@@ -819,7 +815,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
if (!descriptor.m_TimeMajor)
{
statusConcat = arm_compute::NEConcatenateLayer::validate(concatInputsTensorInfosPtr,
- &aclConcatOuputTensorInfo,
+ &aclConcatOutputTensorInfo,
aclAxisConcat);
}
else
@@ -853,7 +849,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
// Output now time major. Permute output back to batch major.
if (maxTime != 1)
{
- statusPermute2 = arm_compute::NEPermute::validate(&aclConcatOuputTensorInfo,
+ statusPermute2 = arm_compute::NEPermute::validate(&aclConcatOutputTensorInfo,
&aclOutputInfo,
arm_compute::PermutationVector(0U, 2U, 1U));
}
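
Note on the validate signature change above: NeonUnidirectionalSequenceLstmFloatWorkloadValidate now takes outputStateOut and cellStateOut as required TensorInfo arguments, rather than accepting optional hidden/cell state outputs and fabricating the infos internally from outputStateIn/cellStateIn. The following is a minimal caller sketch, not part of the patch; the wrapper name, variable names, and include paths are assumptions based on the ArmNN source layout.

// Hedged sketch of invoking the updated validate entry point with the new
// six-tensor ordering. Include paths are assumed from the ArmNN tree; the
// workload header is the one this patch modifies the .cpp for.
#include <armnn/Descriptors.hpp>   // UnidirectionalSequenceLstmDescriptor
#include <armnn/LstmParams.hpp>    // LstmInputParamsInfo
#include <armnn/Tensor.hpp>        // TensorInfo
#include "NeonUnidirectionalSequenceLstmFloatWorkload.hpp"

arm_compute::Status ValidateSequenceLstm(const armnn::TensorInfo& input,
                                         const armnn::TensorInfo& outputStateIn,
                                         const armnn::TensorInfo& cellStateIn,
                                         const armnn::TensorInfo& outputStateOut,
                                         const armnn::TensorInfo& cellStateOut,
                                         const armnn::TensorInfo& output,
                                         const armnn::UnidirectionalSequenceLstmDescriptor& descriptor,
                                         const armnn::LstmInputParamsInfo& paramsInfo)
{
    // outputStateOut and cellStateOut are now supplied by the caller instead of
    // being derived inside the validate from the input state shapes.
    return armnn::NeonUnidirectionalSequenceLstmFloatWorkloadValidate(input,
                                                                      outputStateIn,
                                                                      cellStateIn,
                                                                      outputStateOut,
                                                                      cellStateOut,
                                                                      output,
                                                                      descriptor,
                                                                      paramsInfo);
}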