diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-04-28 16:29:44 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-04-28 17:37:03 -0700 |
commit | 989cb050228b47085189b1c5cb0d9b705e1060e7 (patch) | |
tree | 317bfa75a3d3b79341f2d0e3c994bcef834ff179 /reference_model/src | |
parent | 550ccc52de231621c0bf0c05ae2a398eec37ff51 (diff) | |
download | reference_model-989cb050228b47085189b1c5cb0d9b705e1060e7.tar.gz |
Support mixed-precision input tensors for TOSA unit test.
Bring CONV2D/DEPTHWISE_CONV2D/TRANSPOSE_CONV2D/FULLY_CONNECTED up running.
Other minor fixes:
- reference model should bail out if shape is invalid, along with "goto done" cleanup.
- cleanup typos/duplicate in tosa_test_gen.py/tosa_serializer.py.
- wrong input_zp/output_zp being generated for RESCALE.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ic1f3fe0090482bdee8a61508be7c738714191e19
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/tensor.h | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d39cc7c..6c0622e 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -33,9 +33,7 @@ class GraphNode; class Tensor { public: - Tensor(std::string tensorName_, - DType tensorDtype__, - std::vector<int> shape_); + Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_); virtual ~Tensor(); @@ -240,9 +238,7 @@ template <class T> class TensorTemplate : public Tensor { public: - TensorTemplate(std::string tensorName_, - DType tensorDtype_, - std::vector<int> shape_) + TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_) : Tensor(tensorName_, tensorDtype_, shape_) { tensor = nullptr; @@ -606,11 +602,15 @@ int Tensor6<bool>::dumpTensor(FILE* out) const; class TensorFactory { public: - static Tensor* newTensor(std::string tensorName_, - DType tensorDtype_, - std::vector<int> shape_, - const uint32_t rank) + static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank) { + // Bail out if any dimension is invalid. + for (auto& dim : shape_) + { + if (dim <= 0) + goto done; + } + switch (tensorDtype_) { case DType_FLOAT: @@ -630,9 +630,8 @@ public: return new Tensor5<float>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<float>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT32: case DType_UINT8: case DType_INT4: @@ -654,9 +653,8 @@ public: return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT48: switch (rank) { @@ -674,9 +672,8 @@ public: return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_BOOL: switch (rank) { @@ -694,16 +691,22 @@ public: return new Tensor5<bool>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<bool>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; default: - goto done; + break; } done: - FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_], - rank); + std::string shape_str("["); + for (auto& dim : shape_) + { + shape_str += (std::to_string(dim) + ", "); + } + shape_str.append("]"); + + FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(), + EnumNamesDType()[tensorDtype_], rank, shape_str.c_str()); } static Tensor* newTensor(DType type, const std::vector<int> shape); |