From f3e016f7aedad5bc995932e75821f2eeb2ec45ea Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 2 Nov 2021 01:15:50 +0000 Subject: more ERROR_IF fixes - TRANSPOSE: move perm attribute check to compile-time checker - TABLE: add output type checker Signed-off-by: Kevin Cheng Change-Id: I834a5f290fbc384ef339b624060e6e5c77072c36 --- reference_model/src/ops/data_layout.cc | 20 ++++++++++++++------ reference_model/src/ops/ewise_binary.cc | 1 + 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 05a11e0..387152f 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -661,18 +661,26 @@ int OpTranspose::checkTensorAttributes() ASSERT_MEM(in && out); + ERROR_IF(attribute->perm().size() != Rank, "OpTranspose: perm array size needs to match rank(input)"); + + std::array index_used; + index_used.fill(false); + for (int32_t d = 0; d < Rank; d++) + { + int32_t index = attribute->perm()[d]; + ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary"); + ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute"); + index_used[index] = true; + ERROR_IF(in->getShape()[index] != out->getShape()[d], "OpTranspose: input output shape mismatch"); + perm_array[d] = index; + } + return 0; } template int OpTranspose::eval() { - for (int32_t d = 0; d < Rank; d++) - { - perm_array[d] = attribute->perm()[d]; - ERROR_IF(perm_array[d] < 0 or perm_array[d] >= Rank, "OpTranspose: index out of boundary"); - } - out->getTensor() = in->getTensor().shuffle(perm_array); return GraphNode::eval(); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 415cd1c..b199f69 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -512,6 +512,7 @@ int OpTable::checkTensorAttributes() } ERROR_IF(inputs[0]->getDtype() != InDtype, "OpTable: Unexpected input type"); + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTable: Unexpected output type"); ERROR_IF(attribute->table().size() != TableNumEntries, "OpTable: table attribute size must be %u", TableNumEntries); for (uint32_t i = 0; i < TableNumEntries; i++) -- cgit v1.2.1