about | summary | refs | log | tree | commit | diff
path: root/src/backends
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends')
-rw-r--r-- src/backends/backendsCommon/WorkloadData.cpp | 2
-rw-r--r-- src/backends/reference/workloads/BatchMatMulImpl.cpp | 12
2 files changed, 9 insertions, 5 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 7055092be2..5334641803 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -4305,7 +4305,7 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
inputXInfoAfterParams.GetShape());
auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
- inputXInfoBeforeParams.GetShape());
+ inputYInfoBeforeParams.GetShape());
if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
!= inputYInfoAfterParams.GetShape()[axesYToMul.first])
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 8e169cbab8..e0c36c5db8 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -42,13 +42,16 @@ void BatchMatMul::ApplyBatchMatMul()
inputXInfo.GetShape());
auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
inputYInfo.GetShape());
+
+ // inputYRowSize (equivalently inputXColSize) must be read using the original (unadjusted) axis value,
+ // because it is looked up in the original tensor shape, before the axes are adjusted for unequal ranks
+ unsigned int inputYRowSize = inputYInfo.GetShape()[axesYToMul.first];
+
AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
unsigned int inputXColDim = axesXToMul.second;
unsigned int inputYRowDim = axesYToMul.first;
- unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
-
auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
{
float sum = 0.0f;
@@ -437,10 +440,11 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
{
unsigned int result = idx[idx.size()-1];
unsigned int dimMultiplier = 1;
- unsigned int offset;
+ unsigned int offset = 0;
// -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
- for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
+ // Check offset in relation to i, to stop calculating the flat index once all input shape fields have been considered
+ for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
{
switch(type)
{