From 86c403b654fe6038f26ed7dccb982ffca970b920 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Mon, 6 Jun 2022 20:46:01 -0700 Subject: Align the serialization schema with TOSA 0.24.0 specification The operators are pool, conv, reshape, slice, transpose, and table. Signed-off-by: TatWai Chong Change-Id: I13f8d626df59be14361068222746347ba69d2fb5 --- reference_model/src/ops/data_layout.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'reference_model/src/ops/data_layout.cc') diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 24c86ed..df7084d 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -201,6 +201,9 @@ int OpPad::eval() case DType_FLOAT: pad_value = (InEigenType)attribute->pad_const_fp(); break; + default: + printNodeValidationError("Unsupported data type"); + break; } if (this->qinfo && Dtype == DType_INT8) @@ -256,7 +259,7 @@ int OpReshape::checkTensorAttributes() for (uint32_t d = 0; d < OutRank; d++) { - ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d], + ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape"); } @@ -271,7 +274,7 @@ int OpReshape::eval() { for (int32_t d = 0; d < OutRank; d++) { - array_shape[d] = attribute->shape()[OutRank - 1 - d]; + array_shape[d] = attribute->new_shape()[OutRank - 1 - d]; out_reverser[d] = OutRank - 1 - d; } @@ -418,13 +421,13 @@ int OpSlice::checkTensorAttributes() in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); - ERROR_IF((int32_t)attribute->begin().size() != in->getRank(), + ERROR_IF((int32_t)attribute->start().size() != in->getRank(), "OpSlice: begin array length needs to be rank(input)"); ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); for (int32_t i = 0; i < in->getRank(); i++) { - int32_t b = attribute->begin()[i]; + int32_t b = attribute->start()[i]; int32_t s = attribute->size()[i]; ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary"); ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary"); @@ -629,13 +632,13 @@ int OpTranspose::checkTensorAttributes() ASSERT_MEM(in && out); - ERROR_IF(attribute->perm().size() != Rank, "OpTranspose: perm array size needs to match rank(input)"); + ERROR_IF(attribute->perms().size() != Rank, "OpTranspose: perms array size needs to match rank(input)"); std::array index_used; index_used.fill(false); for (int32_t d = 0; d < Rank; d++) { - int32_t index = attribute->perm()[d]; + int32_t index = attribute->perms()[d]; ERROR_IF(index < 0 or index >= Rank, "OpTranspose: index out of boundary"); ERROR_IF(index_used[index], "OpTranspose: index duplicated in perm attribute"); index_used[index] = true; -- cgit v1.2.1