aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-08 09:57:55 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-10 09:15:04 +0000
commitad5293a86e315049de36afd723dcd1a7e70681a7 (patch)
treeb9003cd1fba00c267a971d899284b3fcbd5ce6f5 /src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
parent8b797a84f1e8f9d1d5d064afbc4fc12c21b8ffed (diff)
downloadarmnn-ad5293a86e315049de36afd723dcd1a7e70681a7.tar.gz
IVGCVSW-3337 Add Neon backend support for LSTM layer normalisation
* Update neon lstm workload * Add unit tests * Add isLstmSupported Change-Id: I493c159137f6544b0f2532d16d4fafd7a7e587e5 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/neon/workloads/NeonLstmFloatWorkload.hpp')
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.hpp22
1 files changed, 6 insertions, 16 deletions
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
index f87f24d88a..c116cdd967 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
@@ -43,6 +43,11 @@ private:
std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
+ std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
+
void FreeUnusedTensors();
};
@@ -50,21 +55,6 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor &descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights);
+ const LstmInputParamsInfo& paramsInfo);
} //namespace armnn