From e5339e7013cf24e5a34509fb0a60377e5f8a244e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 28 Jul 2021 17:33:28 +0100 Subject: MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport * Add RefUnidirectionalSequenceLstmWorkload * Refactor Lstm to be able to use for Lstm and SequenceLstm * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c --- .../RefUnidirectionalSequenceLstmWorkload.hpp | 56 ++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp') diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp new file mode 100644 index 0000000000..8ba7bdc0c6 --- /dev/null +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp @@ -0,0 +1,56 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include +#include + +#include "Encoders.hpp" +#include "Decoders.hpp" + +namespace armnn +{ + +class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload +{ +public: + explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; + + +private: + void Execute(std::vector inputs, std::vector outputs) const; + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + std::unique_ptr m_CellToInputWeightsTensor; + std::unique_ptr m_CellToForgetWeightsTensor; + std::unique_ptr m_CellToOutputWeightsTensor; + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + std::unique_ptr m_ProjectionWeightsTensor; + std::unique_ptr m_ProjectionBiasTensor; + std::unique_ptr m_InputLayerNormWeights; + std::unique_ptr m_ForgetLayerNormWeights; + std::unique_ptr m_CellLayerNormWeights; + std::unique_ptr m_OutputLayerNormWeights; + + float m_LayerNormEpsilon = static_cast(1e-8); +}; + +} //namespace armnn -- cgit v1.2.1