aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-05-12 10:44:49 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-05-12 14:51:16 -0700
commit14d7f7a2b5d0d85b83d8c84a5456828feb1a0ea1 (patch)
treec0e5eaa7e119d7998f5780a6c90947e875eddf84 /reference_model/src/ops
parentd267dd9418374d49ac1e1a1a9c9b1d30b5733ee9 (diff)
downloadreference_model-14d7f7a2b5d0d85b83d8c84a5456828feb1a0ea1.tar.gz
Update to v0.22.0
- remove identityN and placeholder - add div - update serialization_lib hash - update apply_scale_16() assertion - regenerate examples/ due to serialization_lib change Change-Id: I7183d92bec33697c65adfc07cb8eb89c6882675a
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/control_flow.cc5
-rw-r--r--reference_model/src/ops/data_nodes.cc78
-rw-r--r--reference_model/src/ops/data_nodes.h30
-rw-r--r--reference_model/src/ops/ewise_binary.cc23
-rw-r--r--reference_model/src/ops/ewise_binary.h1
-rw-r--r--reference_model/src/ops/op_factory.cc12
6 files changed, 31 insertions, 118 deletions
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 827e01f..1a6a63a 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -93,6 +93,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
return 1;
}
+ tensor->setIsValid();
+
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{
@@ -292,8 +294,7 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
- TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL,
- std::vector<int32_t>({}));
+ TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 883cd1b..baae019 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -42,29 +42,6 @@ int OpConst::eval()
return GraphNode::eval();
}
-OpPlaceholder::OpPlaceholder(uint64_t id_)
- : GraphNode(Op_PLACEHOLDER, id_)
-{
- setRequiredOperands(0, 1);
-}
-
-OpPlaceholder::~OpPlaceholder()
-{}
-
-int OpPlaceholder::checkTensorAttributes()
-{
- if (validateRequiredOperands())
- return 1;
-
- return 0;
-}
-
-int OpPlaceholder::eval()
-{
- // Evaluation is trivial for placeholders
- return GraphNode::eval();
-}
-
template <int Rank, DType Dtype>
OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_IDENTITY, id_)
@@ -107,64 +84,11 @@ int OpIdentity<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
-OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_IDENTITYN, id_)
-{
- setRequiredRank(0, 6);
-}
-
-template <int Rank, DType Dtype>
-OpIdentityN<Rank, Dtype>::~OpIdentityN()
-{}
-
-template <int Rank, DType Dtype>
-int OpIdentityN<Rank, Dtype>::checkTensorAttributes()
-{
-
- if (inputs.size() != outputs.size())
- {
- printNodeValidationError("Input and output tensor list lengths are not equal");
- return 1;
- }
-
- for (size_t i = 0; i < inputs.size(); i++)
- {
- ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
- outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i]));
-
- if (ins[i]->matchRankTypeShape(*outs[i]))
- {
- printNodeValidationError("Input and output tensor rank, type, or shape do not match");
- return 1;
- }
- }
-
- return 0;
-}
-
-template <int Rank, DType Dtype>
-int OpIdentityN<Rank, Dtype>::eval()
-{
- for (size_t i = 0; i < ins.size(); i++)
- {
- outs[i]->getTensor() = ins[i]->getTensor();
- }
-
- return GraphNode::eval();
-}
-
// template explicit instantiation
-// note OpConst and OpPlaceholder are not templated
+// note OpConst is not templated
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index bec4669..a02d441 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -31,16 +31,6 @@ public:
virtual int eval();
};
-class OpPlaceholder : public GraphNode
-{
-public:
- OpPlaceholder(uint64_t id_);
- virtual ~OpPlaceholder();
-
- virtual int checkTensorAttributes();
- virtual int eval();
-};
-
template <int Rank, DType Dtype>
class OpIdentity : public GraphNode
{
@@ -61,26 +51,6 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
-class OpIdentityN : public GraphNode
-{
-public:
- OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
- virtual ~OpIdentityN();
-
- virtual int checkTensorAttributes();
- virtual int eval();
-
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, Rank>;
-
-protected:
- std::vector<TosaReference::TensorTemplate<TIn>*> ins;
- std::vector<TosaReference::TensorTemplate<TOut>*> outs;
-};
-
}; // namespace TosaReference
#endif
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index fc587f1..76cebeb 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -298,6 +298,27 @@ int OpBitwiseXor<Rank, Dtype>::register_fcn()
}
template <int Rank, DType Dtype>
+int OpDiv<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_INT32:
+ this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
+ ASSERT_MSG_NODE(b != 0, "OpDiv: divisor must be non-zero value");
+ int64_t res_in_64 = static_cast<int64_t>(a) / b;
+ int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
+ ASSERT_MSG_NODE(a <= i32_max_in_64, "OpDiv: result not in i32 range");
+ return static_cast<InEigenType>(res_in_64);
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
int OpLogicalAnd<Rank, Dtype>::register_fcn()
{
switch (Dtype)
@@ -579,6 +600,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpDiv, INT32);
+
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 5bc5630..6b9c98d 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -125,6 +125,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseAnd, BITWISE_AND)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseOr, BITWISE_OR)
DEF_TEMPLATE_BINARY_OP_DEFAULT(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_DEFAULT(Div, DIV)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalAnd, LOGICAL_AND)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
DEF_TEMPLATE_BINARY_OP_DEFAULT(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index b326c63..440d624 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -134,6 +134,9 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
break;
+ case Op_DIV:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpDiv, INT32);
+ break;
case Op_LOGICAL_AND:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
break;
@@ -346,8 +349,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// data_nodes
case Op_CONST:
return new OpConst(id);
- case Op_PLACEHOLDER:
- return new OpPlaceholder(id);
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
@@ -355,13 +356,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
break;
- case Op_IDENTITYN:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
- break;
// type_conversion
case Op_CAST: