aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BatchMatMulImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BatchMatMulImpl.hpp')
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.hpp69
1 files changed, 33 insertions, 36 deletions
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
index 25b6c85d77..19971a4af3 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.hpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -15,6 +15,15 @@ namespace armnn
class BatchMatMul {
public:
+ BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder);
+
+private:
enum DataSlot
{
InputX = 0,
@@ -22,31 +31,35 @@ public:
Output = 2
};
- BatchMatMul(const BatchMatMulDescriptor& params,
- const TensorInfo& inputXInfo,
- const TensorInfo& inputYInfo,
- const TensorInfo& outputInfo,
- Decoder<float>& inputXDecoder,
- Decoder<float>& inputYDecoder,
- Encoder<float>& outputEncoder)
- : params(params),
- inputXInfo(inputXInfo),
- inputYInfo(inputYInfo),
- outputInfo(outputInfo),
- inputXDecoder(inputXDecoder),
- inputYDecoder(inputYDecoder),
- outputEncoder(outputEncoder)
- {}
+ const BatchMatMulDescriptor& params;
+ TensorInfo inputXInfo;
+ TensorInfo inputYInfo;
+ TensorInfo outputInfo;
+ Decoder<float>& inputXDecoder;
+ Decoder<float>& inputYDecoder;
+ Encoder<float>& outputEncoder;
- void BatchMatMulImpl();
+ std::vector<float> inputXData;
+ std::vector<float> inputYData;
+
+ void ApplyBatchMatMul();
+
+ void ApplyParams();
+
+ void Transpose(DataSlot type);
- void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+ void Adjoint(DataSlot type);
+
+ void RecurseTensor(const TensorInfo& tensorInfo,
+ std::function<void(const std::vector<unsigned int>&)> const& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim);
// Adjusts it for when input tensors are of unequal rank
- void AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+ void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul);
- float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+ float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
@@ -54,22 +67,6 @@ public:
void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
-
- template <typename T>
- std::string StringifyVec(const std::vector<T>& vec);
-
-private:
- const BatchMatMulDescriptor& params;
- const TensorInfo& inputXInfo;
- const TensorInfo& inputYInfo;
- const TensorInfo& outputInfo;
- Decoder<float>& inputXDecoder;
- Decoder<float>& inputYDecoder;
- Encoder<float>& outputEncoder;
-
- std::vector<float> inputXData;
- std::vector<float> inputYData;
-
};
} // namespace armnn \ No newline at end of file