aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc20
1 files changed, 14 insertions, 6 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 05a11e0..387152f 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -661,18 +661,26 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
+ ERROR_IF(attribute->perm().size() != Rank, "OpTranspose: perm array size needs to match rank(input)");
+
+ std::array<bool, Rank> index_used;
+ index_used.fill(false);
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ int32_t index = attribute->perm()[d];
+ ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary");
+ ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute");
+ index_used[index] = true;
+ ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch");
+ perm_array[d] = index;
+ }
+
return 0;
}
template <int Rank, DType Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
- for (int32_t d = 0; d < Rank; d++)
- {
- perm_array[d] = attribute->perm()[d];
- ERROR_IF(perm_array[d] < 0 or perm_array[d] >= Rank, "OpTranspose: index out of boundary");
- }
-
out->getTensor() = in->getTensor().shuffle(perm_array);
return GraphNode::eval();