22.08
|
#include <BatchMatMulImpl.hpp>
Public Types | |
enum | DataSlot { InputX = 0, InputY = 1, Output = 2 } |
Public Member Functions | |
BatchMatMul (const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder) | |
void | BatchMatMulImpl () |
void | RecurseBMM (std::vector< unsigned int > &curIdx, unsigned int curDim) |
void | AdjustAxesToMulForUnequalRanks (std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int >> &axesToMul) |
float | GetValueAt (DataSlot type, std::vector< unsigned int > idx) |
void | SetValueAt (float value, DataSlot type, std::vector< unsigned int > idx) |
void | AdjustToSafeIdx (DataSlot type, std::vector< unsigned int > &idx) |
unsigned int | CalcFlatIdx (DataSlot type, const std::vector< unsigned int > &idx) |
template<typename T > | |
std::string | StringifyVec (const std::vector< T > &vec) |
Definition at line 16 of file BatchMatMulImpl.hpp.
enum DataSlot |
Enumerator | |
---|---|
InputX | |
InputY | |
Output |
Definition at line 18 of file BatchMatMulImpl.hpp.
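BatchMatMul (const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder) |
inline |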
Definition at line 25 of file BatchMatMulImpl.hpp.
References BatchMatMul::AdjustAxesToMulForUnequalRanks(), BatchMatMul::AdjustToSafeIdx(), BatchMatMul::BatchMatMulImpl(), BatchMatMul::CalcFlatIdx(), BatchMatMul::GetValueAt(), BatchMatMul::RecurseBMM(), BatchMatMul::SetValueAt(), and BatchMatMul::StringifyVec().
void AdjustAxesToMulForUnequalRanks | ( | std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int >> & | axesToMul | ) |
Definition at line 73 of file BatchMatMulImpl.cpp.
References TensorInfo::GetNumDimensions().
Referenced by BatchMatMul::BatchMatMul(), and BatchMatMul::RecurseBMM().
void AdjustToSafeIdx | ( | DataSlot | type, |
std::vector< unsigned int > & | idx | ||
) |
Definition at line 147 of file BatchMatMulImpl.cpp.
References TensorInfo::GetNumDimensions(), TensorInfo::GetShape(), and armnn::Output.
Referenced by BatchMatMul::BatchMatMul(), BatchMatMul::GetValueAt(), and BatchMatMul::SetValueAt().
void BatchMatMulImpl | ( | ) |
Definition at line 14 of file BatchMatMulImpl.cpp.
References Decoder< IType >::DecodeTensor(), TensorInfo::GetNumDimensions(), TensorInfo::GetShape(), and BatchMatMul::RecurseBMM().
Referenced by BatchMatMul::BatchMatMul().
unsigned int CalcFlatIdx | ( | DataSlot | type, |
const std::vector< unsigned int > & | idx | ||
) |
Definition at line 186 of file BatchMatMulImpl.cpp.
References TensorInfo::GetNumDimensions(), TensorInfo::GetShape(), and armnn::Output.
Referenced by BatchMatMul::BatchMatMul(), BatchMatMul::GetValueAt(), and BatchMatMul::SetValueAt().
float GetValueAt | ( | DataSlot | type, |
std::vector< unsigned int > | idx | ||
) |
Definition at line 96 of file BatchMatMulImpl.cpp.
References BatchMatMul::AdjustToSafeIdx(), BatchMatMul::CalcFlatIdx(), Encoder< IType >::Get(), and armnn::Output.
Referenced by BatchMatMul::BatchMatMul(), and BatchMatMul::RecurseBMM().
void RecurseBMM | ( | std::vector< unsigned int > & | curIdx, |
unsigned int | curDim | ||
) |
Definition at line 29 of file BatchMatMulImpl.cpp.
References BatchMatMul::AdjustAxesToMulForUnequalRanks(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), TensorInfo::GetShape(), BatchMatMul::GetValueAt(), armnn::Output, and BatchMatMul::SetValueAt().
Referenced by BatchMatMul::BatchMatMul(), and BatchMatMul::BatchMatMulImpl().
void SetValueAt | ( | float | value, |
DataSlot | type, | ||
std::vector< unsigned int > | idx | ||
) |
Definition at line 124 of file BatchMatMulImpl.cpp.
References BatchMatMul::AdjustToSafeIdx(), BatchMatMul::CalcFlatIdx(), armnn::Output, and Encoder< IType >::Set().
Referenced by BatchMatMul::BatchMatMul(), and BatchMatMul::RecurseBMM().
std::string StringifyVec | ( | const std::vector< T > & | vec | ) |
Definition at line 219 of file BatchMatMulImpl.cpp.
Referenced by BatchMatMul::BatchMatMul().