{
"cells": [
{
"source": [
"Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1BtkMGSYQOTQ"
},
"source": [
"# Train a gesture recognition model for microcontroller use"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BaFfr7DHRmGF"
},
"source": [
"This notebook demonstrates how to train a 20 KB gesture recognition model for [TensorFlow Lite for Microcontrollers](https://tensorflow.org/lite/microcontrollers/overview). It produces the same model used in the [magic_wand](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/examples/magic_wand) example application.\n",
"\n",
"This notebook is designed to be used with [Google Colaboratory](https://colab.research.google.com).\n",
"\n",
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/magic_wand/train/train_magic_wand_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xXgS6rxyT7Qk"
},
"source": [
"Training is much faster using GPU acceleration. Before you proceed, ensure you are using a GPU runtime by going to **Runtime -> Change runtime type** and selecting **GPU**. Training will take around 5 minutes on a GPU runtime."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LG6ErX5FRIaV"
},
"source": [
"## Configure dependencies\n",
"\n",
"This notebook uses the TensorFlow version that is preinstalled in the Colab runtime."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "STNft9TrfoVh"
},
"source": [
"We'll also clone the TensorFlow repository, which contains the training scripts, and copy them into our workspace."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ygkWw73dRNda"
},
"outputs": [],
"source": [
"# Clone the repository from GitHub\n",
"!git clone --depth 1 -q https://github.com/tensorflow/tensorflow\n",
"# Copy the training scripts into our workspace\n",
"!cp -r tensorflow/tensorflow/lite/micro/examples/magic_wand/train train"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pXI7R4RehFdU"
},
"source": [
"## Prepare the data\n",
"\n",
"Next, we'll download the data and extract it into the expected location within the training scripts' directory."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "W2Sg2AKzVr2L"
},
"outputs": [],
"source": [
"# Download the data we will use to train the model\n",
"!wget http://download.tensorflow.org/models/tflite/magic_wand/data.tar.gz\n",
"# Extract the data into the train directory\n",
"!tar xzf data.tar.gz -C train"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DNjukI1Sgl2C"
},
"source": [
"We'll then run the scripts that prepare the data and split it by person into training, validation, and test sets."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XBqSVpi6Vxss"
},
"outputs": [],
"source": [
"# The scripts must be run from within the train directory\n",
"%cd train\n",
"# Prepare the data\n",
"!python data_prepare.py\n",
"# Split the data by person\n",
"!python data_split_person.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "5-cmVbFvhTvy"
},
"source": [
"## Load TensorBoard\n",
"\n",
"Now, we set up TensorBoard so that we can graph our accuracy and loss as training proceeds."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "CCx6SN9NWRPw"
},
"outputs": [],
"source": [
"# Load TensorBoard\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir logs/scalars"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ERC2Cr4PhaOl"
},
"source": [
"## Begin training\n",
"\n",
"The following cell will begin the training process. Training will take around 5 minutes on a GPU runtime. You'll see the metrics in TensorBoard after a few epochs."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DXmQZgbuWQFO"
},
"outputs": [],
"source": [
"!python train.py --model CNN --person true"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4gXbVzcXhvGD"
},
"source": [
"## Create a C source file\n",
"\n",
"The `train.py` script writes a model, `model.tflite`, to the training scripts' directory.\n",
"\n",
"In the following cell, we use `xxd` to convert this model into a C source file, `model.cc`, that we can use with TensorFlow Lite for Microcontrollers."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "8wgei4OGe3Nz"
},
"outputs": [],
"source": [
"# Install xxd if it is not available\n",
"!apt-get -qq install xxd\n",
"# Save the file as a C source file\n",
"!xxd -i model.tflite > /content/model.cc\n",
"# Print the source file\n",
"!cat /content/model.cc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Train a gesture recognition model for microcontroller use",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}