aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/include/operators.h4
-rw-r--r--reference_model/src/operators.cc33
-rw-r--r--reference_model/src/ops/activation_funcs.cc29
-rw-r--r--reference_model/src/ops/activation_funcs.h14
-rw-r--r--reference_model/src/ops/op_factory.cc6
5 files changed, 84 insertions, 2 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 6efb655..b12604f 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -150,6 +150,8 @@ extern "C"
tosa_status_t tosa_run_tanh(tosa_tensor_t client_input, tosa_tensor_t client_output);
+ tosa_status_t tosa_run_erf(tosa_tensor_t client_input, tosa_tensor_t client_output);
+
tosa_status_t tosa_run_add(tosa_tensor_t client_input1, tosa_tensor_t client_input2, tosa_tensor_t client_output);
tosa_status_t tosa_run_arithmetic_right_shift(tosa_tensor_t client_input1,
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index a0b5013..5796129 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -580,6 +580,37 @@ extern "C"
return tosa_status_valid;
}
+ tosa_status_t tosa_run_erf(tosa_tensor_t client_input, tosa_tensor_t client_output)
+ {
+ // Create operator attributes
+ TosaNoneAttribute attr;
+
+ // Create tensors
+ tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
+ tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
+
+ // Create operator
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_ERF, tosa::Attribute::Attribute_NONE, &attr,
+ { input->GetName() }, { output->GetName() });
+
+ // Create a tosa single-op basic block
+ tosa::TosaSerializationBasicBlock block("erf", "main", { op }, { input, output }, { input->GetName() },
+ { output->GetName() });
+
+ // Setup model
+ TosaReference::ModelRunnerImpl runner;
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size));
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ TOSA_RETURN_ON_ERROR(runner.getOutput(output->GetName(), client_output.data, client_output.size));
+
+ return tosa_status_valid;
+ }
+
tosa_status_t tosa_run_add(tosa_tensor_t client_input1, tosa_tensor_t client_input2, tosa_tensor_t client_output)
{
// Create operator attributes
@@ -2324,4 +2355,4 @@ extern "C"
return tosa_status_valid;
}
-} // extern "C" \ No newline at end of file
+} // extern "C"
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 6681d6d..12d0697 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -124,6 +124,30 @@ int OpTanh<Rank, Dtype>::register_fcn()
return 0;
}
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpErf<Rank, Dtype>::register_fcn()
+{
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be similar than or equal to MAX_RANK");
+
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(erff(a)); };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return erf(a); };
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return 0;
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
@@ -141,3 +165,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP64);
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 2372fcb..a7e1275 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -77,6 +77,20 @@ public:
virtual int register_fcn();
};
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpErf : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpErf(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_ERF, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
}; // namespace TosaReference
#endif
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0a78884..a3069dc 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -161,6 +161,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
break;
+ case Op_ERF:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpErf, FP64);
+ break;
// ewise_binary
case Op_ADD: