// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "Encoders.hpp" #include "Decoders.hpp" #include namespace armnn { class BatchMatMul { public: enum DataSlot { InputX = 0, InputY = 1, Output = 2 }; BatchMatMul(const BatchMatMulDescriptor& params, const TensorInfo& inputXInfo, const TensorInfo& inputYInfo, const TensorInfo& outputInfo, Decoder& inputXDecoder, Decoder& inputYDecoder, Encoder& outputEncoder) : params(params), inputXInfo(inputXInfo), inputYInfo(inputYInfo), outputInfo(outputInfo), inputXDecoder(inputXDecoder), inputYDecoder(inputYDecoder), outputEncoder(outputEncoder) {} void BatchMatMulImpl(); void RecurseBMM(std::vector& curIdx, unsigned int curDim); // Adjusts it for when input tensors are of unequal rank void AdjustAxesToMulForUnequalRanks( std::pair, std::pair>& axesToMul); float GetValueAt(DataSlot type, std::vector idx); void SetValueAt(float value, DataSlot type, std::vector idx); // Takes into account broadcasting void AdjustToSafeIdx(DataSlot type, std::vector& idx); unsigned int CalcFlatIdx(DataSlot type, const std::vector& idx); template std::string StringifyVec(const std::vector& vec); private: const BatchMatMulDescriptor& params; const TensorInfo& inputXInfo; const TensorInfo& inputYInfo; const TensorInfo& outputInfo; Decoder& inputXDecoder; Decoder& inputYDecoder; Encoder& outputEncoder; std::vector inputXData; std::vector inputYData; }; } // namespace armnn