42 unsigned int inputXColDim = axesToMul.first.second;
43 unsigned int inputYRowDim = axesToMul.second.first;
45 unsigned int inputYRowSize = inputYInfo.
GetShape()[inputYRowDim];
50 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
52 xIdx[inputXColDim] = inputYRowIdx;
55 yIdx[inputYRowDim] = inputYRowIdx;
66 for (
unsigned int i = 0; i < outputInfo.
GetShape()[curDim]; i++)
74 std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
85 axesToMul.first.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
86 axesToMul.first.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
91 axesToMul.second.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
92 axesToMul.second.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
107 case DataSlot::InputX:
108 value = inputXData[flatIdx];
110 case DataSlot::InputY:
111 value = inputYData[flatIdx];
114 outputEncoder[flatIdx];
115 value = outputEncoder.
Get();
132 case DataSlot::InputX:
133 inputXData[flatIdx] = value;
135 case DataSlot::InputY:
136 inputYData[flatIdx] = value;
139 outputEncoder[flatIdx];
140 outputEncoder.
Set(value);
149 for(
unsigned int dim = 0; dim < idx.size(); dim++)
153 case DataSlot::InputX:
158 idx[dim] > inputXInfo.
GetShape()[dim-xDiff]-1)
164 case DataSlot::InputY:
169 idx[dim] > inputYInfo.
GetShape()[dim-yDiff]-1)
188 unsigned int result = idx[idx.size()-1];
190 unsigned int dimMultiplier = 1;
195 for(
unsigned int i = static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0; i--)
199 case DataSlot::InputX:
201 dimMultiplier *= inputXInfo.
GetShape()[i + 1 - offset];
203 case DataSlot::InputY:
205 dimMultiplier *= inputYInfo.
GetShape()[i + 1 - offset];
208 dimMultiplier *= outputInfo.
GetShape()[i+1];
213 result += (idx[i] * dimMultiplier);
218 template <
typename T>
221 std::string res =
"{ ";
224 res += std::to_string(x);
const TensorShape & GetShape() const
unsigned int CalcFlatIdx(DataSlot type, const std::vector< unsigned int > &idx)
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
virtual void Set(IType right)=0
void AdjustToSafeIdx(DataSlot type, std::vector< unsigned int > &idx)
Copyright (c) 2021 ARM Limited and Contributors.
std::string StringifyVec(const std::vector< T > &vec)
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
Static helper to get the two axes (for each input) for multiplication.
virtual IType Get() const =0
void RecurseBMM(std::vector< unsigned int > &curIdx, unsigned int curDim)
void AdjustAxesToMulForUnequalRanks(std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int >> &axesToMul)
void SetValueAt(float value, DataSlot type, std::vector< unsigned int > idx)
unsigned int GetNumDimensions() const
float GetValueAt(DataSlot type, std::vector< unsigned int > idx)