aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.hpp')
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.hpp22
1 files changed, 6 insertions, 16 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
index f87f24d88a..c116cdd967 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
@@ -43,6 +43,11 @@ private:
std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
+ std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
+
void FreeUnusedTensors();
};
@@ -50,21 +55,6 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor &descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights);
+ const LstmInputParamsInfo& paramsInfo);
} //namespace armnn