/*
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "main_functions.h"
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include "constants.h"
#include "model.h"
#include "output_handler.h"
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
TfLiteTensor *output = nullptr;
int inference_count = 0;
/* Create an area of memory to use for input, output, and intermediate arrays.
 * This is the minimum arena size, at the time of writing. After allocating
 * tensors you can retrieve the value actually used by invoking
 * interpreter->arena_used_bytes().
 */
const int kModelArenaSize = 2468;
/* Extra headroom for model + alignment + future interpreter changes. */
const int kExtraArenaSize = 560 + 16 + 100;
const int kTensorArenaSize = kModelArenaSize + kExtraArenaSize;
uint8_t tensor_arena[kTensorArenaSize];
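/* Note: the interpreter aligns individual tensor allocations inside the
 * arena itself; the "+16" headroom above accounts for that. Some ports also
 * place the buffer itself on a 16-byte boundary, e.g. (a sketch, not
 * required by this sample):
 *
 * alignas(16) uint8_t tensor_arena[kTensorArenaSize];
 */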
} /* namespace */
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
/* Set up logging. Google style is to avoid globals or statics because of
 * lifetime uncertainty, but since this has a trivial destructor it's okay.
 */
/* NOLINTNEXTLINE(runtime-global-variables) */
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
/* Map the model into a usable data structure. This doesn't involve any
 * copying or parsing; it's a very lightweight operation.
 */
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
/* This pulls in all the operation implementations we need. */
/* NOLINTNEXTLINE(runtime-global-variables) */
static tflite::AllOpsResolver resolver;
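/* Alternatively, code space can be saved by registering only the operations
 * the model actually needs with a MicroMutableOpResolver. A minimal sketch,
 * assuming the model consists of fully-connected layers and that
 * <tensorflow/lite/micro/micro_mutable_op_resolver.h> is included:
 *
 * static tflite::MicroMutableOpResolver<1> micro_op_resolver;
 * micro_op_resolver.AddFullyConnected();
 */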
/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
/* Allocate memory from the tensor_arena for the model's tensors. */
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
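/* Optional: log how much of the arena was actually consumed; this is how
 * the kModelArenaSize figure above can be re-derived (a sketch):
 *
 * TF_LITE_REPORT_ERROR(error_reporter, "Arena used: %d of %d bytes",
 *                      static_cast<int>(interpreter->arena_used_bytes()),
 *                      kTensorArenaSize);
 */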
/* Obtain pointers to the model's input and output tensors. */
input = interpreter->input(0);
output = interpreter->output(0);
/* Keep track of how many inferences we have performed. */
inference_count = 0;
}
/* The name of this function is important for Arduino compatibility. */
void loop(void)
{
/* Calculate an x value to feed into the model. We compare the current
* inference_count to the number of inferences per cycle to determine
* our position within the range of possible x values the model was
* trained on, and use this to calculate a value.
*/
float position = static_cast<float>(inference_count) /
                 static_cast<float>(kInferencesPerCycle);
float x_val = position * kXrange;
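/* For example, with kInferencesPerCycle == 20 (a typical value for this
 * sample) and kXrange == 2 * pi, inference_count == 5 gives
 * position == 0.25 and x_val of roughly 1.57, a quarter of the way
 * through the sine cycle.
 */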
/* Place our calculated x value in the model's input tensor */
input->data.f[0] = x_val;
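/* Note: this model takes float input directly. A quantized (int8) variant
 * of the model would need the value scaled to its quantization parameters
 * first, e.g. (a sketch):
 *
 * input->data.int8[0] = static_cast<int8_t>(
 *         x_val / input->params.scale + input->params.zero_point);
 */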
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x_val: %f\n",
static_cast < double > (x_val));
return;
}
/* Read the predicted y value from the model's output tensor */
float y_val = output->data.f[0];
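/* With a quantized model, the inverse transform would apply on the way
 * out (a sketch):
 *
 * float y_val = (output->data.int8[0] - output->params.zero_point) *
 *               output->params.scale;
 */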
/* Output the results. A custom HandleOutput function can be implemented
* for each supported hardware target.
*/
HandleOutput(error_reporter, x_val, y_val);
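/* A minimal HandleOutput (in output_handler.cc) might simply log both
 * values, e.g. (a sketch):
 *
 * void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
 *                   float y_value)
 * {
 *     TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
 *                          static_cast<double>(x_value),
 *                          static_cast<double>(y_value));
 * }
 */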
/* Increment the inference_count, and reset it once we have reached
 * the total number per cycle.
 */
inference_count += 1;
if (inference_count >= kInferencesPerCycle) {
inference_count = 0;
}
}