// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "Encoders.hpp" #include "Decoders.hpp" #include namespace armnn { class BatchMatMul { public: BatchMatMul(const BatchMatMulDescriptor& params, const TensorInfo& inputXInfo, const TensorInfo& inputYInfo, const TensorInfo& outputInfo, Decoder& inputXDecoder, Decoder& inputYDecoder, Encoder& outputEncoder); private: enum DataSlot { InputX = 0, InputY = 1, Output = 2 }; const BatchMatMulDescriptor& params; TensorInfo inputXInfo; TensorInfo inputYInfo; TensorInfo outputInfo; Decoder& inputXDecoder; Decoder& inputYDecoder; Encoder& outputEncoder; std::vector inputXData; std::vector inputYData; void ApplyBatchMatMul(); void ApplyParams(); void Transpose(DataSlot type); void Adjoint(DataSlot type); void RecurseTensor(const TensorInfo& tensorInfo, std::function&)> const& operation, std::vector& curIdx, unsigned int curDim); // Adjusts it for when input tensors are of unequal rank void AdjustAxesToMulForUnequalRanks(std::pair& axesXToMul, std::pair& axesYToMul); float GetValueAt(DataSlot type, std::vector idx, const std::vector& customData = {}); void SetValueAt(float value, DataSlot type, std::vector idx); // Takes into account broadcasting void AdjustToSafeIdx(DataSlot type, std::vector& idx); unsigned int CalcFlatIdx(DataSlot type, const std::vector& idx); }; } // namespace armnn