ArmNN
 24.02
BatchMatMulImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "Encoders.hpp"
#include "Decoders.hpp"

#include <armnn/Descriptors.hpp>

#include <functional>
#include <utility>
#include <vector>

13 namespace armnn
14 {
15 
16 class BatchMatMul {
17 public:
18  BatchMatMul(const BatchMatMulDescriptor& params,
19  const TensorInfo& inputXInfo,
20  const TensorInfo& inputYInfo,
21  const TensorInfo& outputInfo,
22  Decoder<float>& inputXDecoder,
23  Decoder<float>& inputYDecoder,
24  Encoder<float>& outputEncoder);
25 
26 private:
27  enum DataSlot
28  {
29  InputX = 0,
30  InputY = 1,
31  Output = 2
32  };
33 
34  const BatchMatMulDescriptor& params;
35  TensorInfo inputXInfo;
36  TensorInfo inputYInfo;
37  TensorInfo outputInfo;
38  Decoder<float>& inputXDecoder;
39  Decoder<float>& inputYDecoder;
40  Encoder<float>& outputEncoder;
41 
42  std::vector<float> inputXData;
43  std::vector<float> inputYData;
44 
45  void ApplyBatchMatMul();
46 
47  void ApplyParams();
48 
49  void Transpose(DataSlot type);
50 
51  void Adjoint(DataSlot type);
52 
53  void RecurseTensor(const TensorInfo& tensorInfo,
54  std::function<void(const std::vector<unsigned int>&)> const& operation,
55  std::vector<unsigned int>& curIdx,
56  unsigned int curDim);
57 
58  // Adjusts it for when input tensors are of unequal rank
59  void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
60  std::pair<unsigned int, unsigned int>& axesYToMul);
61 
62  float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
63 
64  void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
65 
66  // Takes into account broadcasting
67  void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
68 
69  unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
70 };
71 
72 } // namespace armnn
armnn::Decoder< float >
WorkloadData.hpp
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::Encoder< float >
armnn::BatchMatMulDescriptor
A BatchMatMulDescriptor for the BatchMatMul operator.
Definition: Descriptors.hpp:1584
armnn::BatchMatMul
Definition: BatchMatMulImpl.hpp:16
Decoders.hpp
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::BatchMatMul::BatchMatMul
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
Definition: BatchMatMulImpl.cpp:15
Encoders.hpp