ArmNN
 22.08
BatchMatMulImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "Encoders.hpp"
#include "Decoders.hpp"

#include <string>
#include <utility>
#include <vector>
12 
13 namespace armnn
14 {
15 
16 class BatchMatMul {
17 public:
18  enum DataSlot
19  {
20  InputX = 0,
21  InputY = 1,
22  Output = 2
23  };
24 
26  const TensorInfo& inputXInfo,
27  const TensorInfo& inputYInfo,
28  const TensorInfo& outputInfo,
29  Decoder<float>& inputXDecoder,
30  Decoder<float>& inputYDecoder,
31  Encoder<float>& outputEncoder)
32  : params(params),
33  inputXInfo(inputXInfo),
34  inputYInfo(inputYInfo),
35  outputInfo(outputInfo),
36  inputXDecoder(inputXDecoder),
37  inputYDecoder(inputYDecoder),
38  outputEncoder(outputEncoder)
39  {}
40 
41  void BatchMatMulImpl();
42 
43  void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
44 
45  // Adjusts it for when input tensors are of unequal rank
47  std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
48 
49  float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
50 
51  void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
52 
53  // Takes into account broadcasting
54  void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
55 
56  unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
57 
58  template <typename T>
59  std::string StringifyVec(const std::vector<T>& vec);
60 
61 private:
62  const BatchMatMulDescriptor& params;
63  const TensorInfo& inputXInfo;
64  const TensorInfo& inputYInfo;
65  const TensorInfo& outputInfo;
66  Decoder<float>& inputXDecoder;
67  Decoder<float>& inputYDecoder;
68  Encoder<float>& outputEncoder;
69 
70  std::vector<float> inputXData;
71  std::vector<float> inputYData;
72 
73 };
74 
75 } // namespace armnn
unsigned int CalcFlatIdx(DataSlot type, const std::vector< unsigned int > &idx)
void AdjustToSafeIdx(DataSlot type, std::vector< unsigned int > &idx)
Copyright (c) 2021 ARM Limited and Contributors.
std::string StringifyVec(const std::vector< T > &vec)
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
A BatchMatMulDescriptor for the BatchMatMul operator.
void RecurseBMM(std::vector< unsigned int > &curIdx, unsigned int curDim)
void AdjustAxesToMulForUnequalRanks(std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int >> &axesToMul)
void SetValueAt(float value, DataSlot type, std::vector< unsigned int > idx)
float GetValueAt(DataSlot type, std::vector< unsigned int > idx)