diff options
Diffstat (limited to 'src/backends/reference/workloads/BatchMatMulImpl.hpp')
-rw-r--r-- | src/backends/reference/workloads/BatchMatMulImpl.hpp | 69 |
1 files changed, 33 insertions, 36 deletions
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp index 25b6c85d77..19971a4af3 100644 --- a/src/backends/reference/workloads/BatchMatMulImpl.hpp +++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp @@ -15,6 +15,15 @@ namespace armnn class BatchMatMul { public: + BatchMatMul(const BatchMatMulDescriptor& params, + const TensorInfo& inputXInfo, + const TensorInfo& inputYInfo, + const TensorInfo& outputInfo, + Decoder<float>& inputXDecoder, + Decoder<float>& inputYDecoder, + Encoder<float>& outputEncoder); + +private: enum DataSlot { InputX = 0, @@ -22,31 +31,35 @@ public: Output = 2 }; - BatchMatMul(const BatchMatMulDescriptor& params, - const TensorInfo& inputXInfo, - const TensorInfo& inputYInfo, - const TensorInfo& outputInfo, - Decoder<float>& inputXDecoder, - Decoder<float>& inputYDecoder, - Encoder<float>& outputEncoder) - : params(params), - inputXInfo(inputXInfo), - inputYInfo(inputYInfo), - outputInfo(outputInfo), - inputXDecoder(inputXDecoder), - inputYDecoder(inputYDecoder), - outputEncoder(outputEncoder) - {} + const BatchMatMulDescriptor& params; + TensorInfo inputXInfo; + TensorInfo inputYInfo; + TensorInfo outputInfo; + Decoder<float>& inputXDecoder; + Decoder<float>& inputYDecoder; + Encoder<float>& outputEncoder; - void BatchMatMulImpl(); + std::vector<float> inputXData; + std::vector<float> inputYData; + + void ApplyBatchMatMul(); + + void ApplyParams(); + + void Transpose(DataSlot type); - void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim); + void Adjoint(DataSlot type); + + void RecurseTensor(const TensorInfo& tensorInfo, + std::function<void(const std::vector<unsigned int>&)> const& operation, + std::vector<unsigned int>& curIdx, + unsigned int curDim); // Adjusts it for when input tensors are of unequal rank - void AdjustAxesToMulForUnequalRanks( - std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul); + void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul, + std::pair<unsigned int, unsigned int>& axesYToMul); - float GetValueAt(DataSlot type, std::vector<unsigned int> idx); + float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {}); void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx); @@ -54,22 +67,6 @@ public: void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx); unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx); - - template <typename T> - std::string StringifyVec(const std::vector<T>& vec); - -private: - const BatchMatMulDescriptor& params; - const TensorInfo& inputXInfo; - const TensorInfo& inputYInfo; - const TensorInfo& outputInfo; - Decoder<float>& inputXDecoder; - Decoder<float>& inputYDecoder; - Encoder<float>& outputEncoder; - - std::vector<float> inputXData; - std::vector<float> inputYData; - }; } // namespace armnn
\ No newline at end of file |