aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMFixture.h7
1 files changed, 6 insertions, 1 deletions
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index be3a3cd735..884b13da80 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -169,7 +169,12 @@ protected:
memcpy(c.data() + i * n, c.data(), n * sizeof(T));
}
}
-
+
+ /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
+ therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
+ in order to be able to call reference implementation that works with (B x M x K) input.
+ Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */
+
// Define transposed shapes
TensorShape a_transposed_shape(a.shape().y(), a.shape().x());
TensorShape b_transposed_shape(b.shape().y(), b.shape().x());