diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- src/backends/backendsCommon/WorkloadData.cpp | 2
-rw-r--r-- src/backends/reference/workloads/BatchMatMulImpl.cpp | 12
2 files changed, 9 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 7055092be2..5334641803 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -4305,7 +4305,7 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
     auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
                                                           inputXInfoAfterParams.GetShape());
     auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
-                                                          inputXInfoBeforeParams.GetShape());
+                                                          inputYInfoBeforeParams.GetShape());
 
     if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
         != inputYInfoAfterParams.GetShape()[axesYToMul.first])
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 8e169cbab8..e0c36c5db8 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -42,13 +42,16 @@ void BatchMatMul::ApplyBatchMatMul()
                                                           inputXInfo.GetShape());
     auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
                                                           inputYInfo.GetShape());
+
+    // the inputYRowSize (or inputXColSize) needs to be obtained using the original (unadjusted) axis value,
+    // because it's obtained from the original tensor shape
+    unsigned int inputYRowSize = inputYInfo.GetShape()[axesYToMul.first];
+
     AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
 
     unsigned int inputXColDim = axesXToMul.second;
     unsigned int inputYRowDim = axesYToMul.first;
 
-    unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
-
     auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
     {
         float sum = 0.0f;
@@ -437,10 +440,11 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
 {
     unsigned int result = idx[idx.size()-1];
     unsigned int dimMultiplier = 1;
-    unsigned int offset;
+    unsigned int offset = 0;
 
     // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
-    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
+    // Check offset in relation to i, to stop calculating flat index once all input shape fields considered
+    for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
     {
         switch(type)
         {