diff options
author | Tai Ly <tai.ly@arm.com> | 2024-02-14 22:35:44 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-02-27 11:17:44 -0800 |
commit | 8ead6c48d878346dfadc7fb48ee9ec94ab418a88 (patch) | |
tree | 85778af30da4253a57d50018a9525d7750cf3c64 /reference_model | |
parent | 08965d35f728d93d8b215753b4b270a422fe39c9 (diff) | |
download | reference_model-8ead6c48d878346dfadc7fb48ee9ec94ab418a88.tar.gz |
[reference_model] tosa.fb name changes
This patch adjusts reference model for attribute name changes in
tosa.fb schema, and for obsoleted slice/tile/reshape attributes
also updated examples due to the breaking tosa flatbuffers changes
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I105eb99a4c35f289c5078aed0a7f9cbb6dfe9123
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 12 | ||||
-rw-r--r-- | reference_model/src/ops/control_flow.cc | 38 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.h | 1 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.h | 15 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 4 |
5 files changed, 34 insertions, 36 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index b78be71..4b435ca 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -101,7 +101,7 @@ public: else return 0.f; } - uint32_t nextIndex() + uint32_t nextIndex() override { ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") return _set_data0.nextIndex(); @@ -134,7 +134,7 @@ public: else return (_B * _B / (_KS + 1)) * v; } - uint32_t nextIndex() + uint32_t nextIndex() override { return _set_data.nextIndex(); } @@ -167,7 +167,7 @@ public: else return 0.f; } - uint32_t nextIndex() + uint32_t nextIndex() override { return _set_data.nextIndex(); } @@ -199,7 +199,7 @@ public: else return 0.f; } - uint32_t nextIndex() + uint32_t nextIndex() override { return _set_data.nextIndex(); } @@ -246,7 +246,7 @@ public: else return 0.f; } - uint32_t nextIndex() + uint32_t nextIndex() override { ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") return _set_data0.nextIndex(); @@ -280,7 +280,7 @@ public: else return 0.f; } - uint32_t nextIndex() + uint32_t nextIndex() override { return _set_data.nextIndex(); } diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 6bbc587..ac09bbb 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -181,26 +181,26 @@ int OpCondIf::checkTensorAttributes() cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); ASSERT_MEM(cond); - auto then_region = tsh->GetRegionByName(attribute->then_branch()); - auto else_region = tsh->GetRegionByName(attribute->else_branch()); + auto then_region = tsh->GetRegionByName(attribute->then_graph()); + auto else_region = tsh->GetRegionByName(attribute->else_graph()); if (then_region && else_region) { - // new serialization: then_branch and else_branch point to regions + // new serialization: then_graph and else_graph point to regions then_block = then_region->GetBlocks().front(); else_block = else_region->GetBlocks().front(); } else { - // old serialization: then_branch and else_branch point to blocks in curr_region + // old serialization: then_graph and else_graph point to blocks in curr_region auto region_name = getParentSGT()->getRegionName(); auto curr_region = tsh->GetRegionByName(region_name); - then_block = curr_region->GetBlockByName(attribute->then_branch()); - else_block = curr_region->GetBlockByName(attribute->else_branch()); + then_block = curr_region->GetBlockByName(attribute->then_graph()); + else_block = curr_region->GetBlockByName(attribute->else_graph()); } - ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); + ERROR_IF(!then_block, "OpCondIf: fail to resolve then_graph %s", attribute->then_graph().c_str()); - ERROR_IF(!else_block, "OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str()); + ERROR_IF(!else_block, "OpCondIf: fail to resolve else_graph %s", attribute->else_graph().c_str()); // Make sure operator input/output matches block input/output // Skip the first rank 0 bool tensor on input list @@ -276,7 +276,7 @@ int OpCondIf::eval() { if (evalBlock(then_block, block_inputs, getOutputs())) { - WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str()); + WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_graph().c_str()); return 1; } } @@ -284,7 +284,7 @@ int OpCondIf::eval() { if (evalBlock(else_block, block_inputs, getOutputs())) { - WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str()); + WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_graph().c_str()); return 1; } } @@ -327,11 +327,11 @@ int OpWhileLoop::checkTensorAttributes() return 1; } - auto cond_region = tsh->GetRegionByName(attribute->cond_branch()); - auto body_region = tsh->GetRegionByName(attribute->body_branch()); + auto cond_region = tsh->GetRegionByName(attribute->cond_graph()); + auto body_region = tsh->GetRegionByName(attribute->body_graph()); if (cond_region && body_region) { - // new serialization: then_branch and else_branch point to regions + // new serialization: then_graph and else_graph point to regions cond_block = cond_region->GetBlocks().front(); body_block = body_region->GetBlocks().front(); } @@ -339,12 +339,12 @@ int OpWhileLoop::checkTensorAttributes() { auto region_name = getParentSGT()->getRegionName(); auto curr_region = tsh->GetRegionByName(region_name); - cond_block = curr_region->GetBlockByName(attribute->cond_branch()); - body_block = curr_region->GetBlockByName(attribute->body_branch()); + cond_block = curr_region->GetBlockByName(attribute->cond_graph()); + body_block = curr_region->GetBlockByName(attribute->body_graph()); } - ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str()); - ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str()); + ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_graph %s", attribute->cond_graph().c_str()); + ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_graph %s", attribute->body_graph().c_str()); // Make sure operator input/output matches block input/output int32_t num_block_tensor = getInputs().size(); @@ -418,7 +418,7 @@ int OpWhileLoop::eval() { if (evalBlock(cond_block, getInputs(), cond_block_outputs)) { - WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str()); + WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_graph().c_str()); return 1; } bool cond_val = cond_output_ctensor.getTensor()(0); @@ -428,7 +428,7 @@ int OpWhileLoop::eval() { if (evalBlock(body_block, getInputs(), getOutputs())) { - WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str()); + WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_graph().c_str()); return 1; } diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index dee2ae0..802c8a0 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -178,7 +178,6 @@ public: using TOut = Eigen::Tensor<OutEigenType, Rank>; protected: - TosaTileAttribute* attribute; TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TInMultiples>* multiples; TosaReference::TensorTemplate<TOut>* out; diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 06ef36e..1d20066 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -62,10 +62,10 @@ return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \ } -#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \ - if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ +#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACC_TYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_TYPE) \ { \ - return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACC_TYPE>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ @@ -80,12 +80,11 @@ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \ } -#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ +#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACC_TYPE) \ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \ - ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ + ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_TYPE) \ { \ - return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \ - id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACC_TYPE>(sgt, attribute, id); \ } #define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ @@ -103,7 +102,7 @@ { \ auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \ ASSERT_MEM(attr); \ - accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \ + accumDType = tosa::EnumValuesDType()[attr->acc_type()]; \ } \ else \ { \ diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 124dc87..609265c 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -496,11 +496,11 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL"); LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL"); - TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype()); + TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->acc_type()); DEBUG_INFO(OP, "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " - "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", + "stride=[%d,%d], pad=[%d,%d,%d,%d], acc_type=%s", in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_y, kernel_x, stride_y, stride_x, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); |