/* * Copyright (c) 2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include "input.h" #include "model.h" #include "output.h" #include "test_assertions.hpp" using namespace EthosU; namespace { int64_t defaultTimeout = 60000000000; void testPing(const Device &device) { int r; try { r = device.ioctl(ETHOSU_IOCTL_PING); } catch (std::exception &e) { throw TestFailureException("Ping test: ", e.what()); } TEST_ASSERT(r == 0); } void testVersion(const Device &device) { int r; try { r = device.ioctl(ETHOSU_IOCTL_VERSION_REQ); } catch (std::exception &e) { throw TestFailureException("Version test: ", e.what()); } TEST_ASSERT(r == 0); } void testCapabilties(const Device &device) { Capabilities capabilities; try { capabilities = device.capabilities(); } catch (std::exception &e) { throw TestFailureException("Capabilities test: ", e.what()); } TEST_ASSERT(capabilities.hwId.architecture > SemanticVersion()); } void testNetworkInfoNotExistentIndex(const Device &device) { try { Network(device, 0); FAIL(); } catch (Exception &e) { // good it should have thrown } catch (std::exception &e) { throw TestFailureException("NetworkInfo no index test: ", e.what()); } } void testNetworkInfoBuffer(const Device &device) { try { std::shared_ptr buffer = std::make_shared(device, sizeof(networkModelData)); buffer->resize(sizeof(networkModelData)); std::memcpy(buffer->data(), networkModelData, sizeof(networkModelData)); Network network(device, buffer); TEST_ASSERT(network.getIfmDims().size() == 1); TEST_ASSERT(network.getOfmDims().size() == 1); } catch (std::exception &e) { throw TestFailureException("NetworkInfo buffer test: ", e.what()); } } void testNetworkInfoUnparsableBuffer(const Device &device) { try { auto buffer = std::make_shared(device, sizeof(networkModelData) / 4); buffer->resize(sizeof(networkModelData) / 4); std::memcpy(buffer->data(), networkModelData + sizeof(networkModelData) / 4, sizeof(networkModelData) / 4); try { Network network(device, buffer); FAIL(); } catch (Exception) { // good, it should have thrown! } } catch (std::exception &e) { throw TestFailureException("NetworkInfo unparsable buffer test: ", e.what()); } } void testRunInferenceBuffer(const Device &device) { try { auto networkBuffer = std::make_shared(device, sizeof(networkModelData)); networkBuffer->resize(sizeof(networkModelData)); std::memcpy(networkBuffer->data(), networkModelData, sizeof(networkModelData)); auto network = std::make_shared(device, networkBuffer); std::vector> inputBuffers; std::vector> outputBuffers; auto inputBuffer = std::make_shared(device, sizeof(inputData)); inputBuffer->resize(sizeof(inputData)); std::memcpy(inputBuffer->data(), inputData, sizeof(inputData)); inputBuffers.push_back(inputBuffer); outputBuffers.push_back(std::make_shared(device, sizeof(expectedOutputData))); std::vector enabledCounters(Inference::getMaxPmuEventCounters()); auto inference = std::make_shared(network, inputBuffers.begin(), inputBuffers.end(), outputBuffers.begin(), outputBuffers.end(), enabledCounters, false); bool timedout = inference->wait(defaultTimeout); TEST_ASSERT(!timedout); InferenceStatus status = inference->status(); TEST_ASSERT(status == InferenceStatus::OK); bool success = inference->cancel(); TEST_ASSERT(!success); TEST_ASSERT(std::memcmp(expectedOutputData, outputBuffers[0]->data(), sizeof(expectedOutputData)) == 0); } catch (std::exception &e) { throw TestFailureException("Inference run test: ", e.what()); } } } // namespace int main() { Device device; try { testPing(device); testVersion(device); testCapabilties(device); testNetworkInfoNotExistentIndex(device); testNetworkInfoBuffer(device); testNetworkInfoUnparsableBuffer(device); testRunInferenceBuffer(device); } catch (TestFailureException &e) { std::cerr << "Test failure: " << e.what() << std::endl; return 1; } return 0; }