| # Lint as: python3 |
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=g-bad-import-order |
| |
| """Build and train neural networks.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import argparse |
| import datetime |
| import os # pylint: disable=duplicate-code |
| from data_load import DataLoader |
| |
| import numpy as np # pylint: disable=duplicate-code |
| import tensorflow as tf |
| |
# TensorBoard scalars are written below logs/scalars/ into a directory named
# after the local time at import (second resolution).
logdir = "logs/scalars/{}".format(
    datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
| |
| |
def reshape_function(data, label):
  """Reshapes `data` to `[-1, 3, 1]` and returns it together with `label`.

  `train_net` maps this over the datasets when the CNN is selected, adding the
  trailing channel axis that the Conv2D input expects.

  Args:
    data: Tensor to reshape.
    label: Label paired with `data`; passed through unchanged.

  Returns:
    A `(reshaped_data, label)` tuple.
  """
  return tf.reshape(data, [-1, 3, 1]), label
| |
| |
def calculate_model_size(model):
  """Prints the model summary and the size of its trainable variables in KB.

  Args:
    model: Object exposing `summary()` and `trainable_variables`. Every
      variable must provide `shape` (iterable of dimensions) and `dtype.size`
      (bytes per element).
  """
  # `summary()` prints the table itself and returns None, so wrapping it in
  # print() would emit an extra "None" line.
  model.summary()
  # `np.prod` replaces `np.product`, a deprecated alias removed in NumPy 2.0.
  # The product of an empty shape is 1, which is correct for scalars.
  total_bytes = sum(
      int(np.prod([int(dim) for dim in var.shape])) * var.dtype.size
      for var in model.trainable_variables)
  print("Model size:", total_bytes / 1024, "KB")
| |
| |
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  The model directory `./netmodels/CNN` is created when missing. If it already
  holds a `weights.h5` file, those weights are loaded into the new model;
  otherwise the model keeps its freshly initialised weights.

  Args:
    seq_length: Number of time steps per sample; the network input shape is
      `(seq_length, 3, 1)`.

  Returns:
    A `(model, model_path)` tuple holding the `tf.keras.Sequential` model and
    the model directory path.
  """
  # The shape comments below are written for seq_length == 128.
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  os.makedirs(model_path, exist_ok=True)
  # Loading unconditionally would fail on a first run, where the directory has
  # only just been created and cannot contain a checkpoint yet.
  weights_path = os.path.join(model_path, "weights.h5")
  if os.path.isfile(weights_path):
    model.load_weights(weights_path)
  else:
    print("No saved weights at {}; using fresh weights.".format(weights_path))
  return model, model_path
| |
| |
def build_lstm(seq_length):
  """Creates the bidirectional-LSTM classifier and its model directory.

  Args:
    seq_length: Number of time steps per sample; the network input shape is
      `(seq_length, 3)`.

  Returns:
    A `(model, model_path)` tuple holding the `tf.keras.Sequential` model and
    the directory `./netmodels/LSTM`, which is created if it does not exist.
  """
  recurrent = tf.keras.layers.Bidirectional(
      tf.keras.layers.LSTM(22),
      input_shape=(seq_length, 3))  # output_shape=(batch, 44)
  classifier = tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  model = tf.keras.Sequential([recurrent, classifier])

  print("Built LSTM.")
  model_path = os.path.join("./netmodels", "LSTM")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  return model, model_path
| |
| |
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Loads and formats the three data splits through `DataLoader`.

  Args:
    train_data_path: Location of the training split.
    valid_data_path: Location of the validation split.
    test_data_path: Location of the test split.
    seq_length: Forwarded to `DataLoader` as `seq_length`.

  Returns:
    The tuple `(train_len, train_data, valid_len, valid_data, test_len,
    test_data)` read from the loader after `format()` has been called.
  """
  loader = DataLoader(
      train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
  loader.format()
  return (loader.train_len, loader.train_data,
          loader.valid_len, loader.valid_data,
          loader.test_len, loader.test_data)
| |
| |
def build_net(args, seq_length):
  """Builds the network selected by `args.model`.

  Args:
    args: Parsed command-line namespace; only `args.model` is read.
    seq_length: Number of time steps per sample, forwarded to the builder.

  Returns:
    The `(model, model_path)` tuple produced by `build_cnn` or `build_lstm`.

  Raises:
    ValueError: If `args.model` is neither "CNN" nor "LSTM".
  """
  if args.model == "CNN":
    model, model_path = build_cnn(seq_length)
    return model, model_path
  if args.model == "LSTM":
    model, model_path = build_lstm(seq_length)
    return model, model_path
  # Fail with the explanatory message instead of falling through to a return
  # of unbound locals, which used to surface as an UnboundLocalError.
  raise ValueError(
      "Please input correct model name.(CNN LSTM) Got: {!r}".format(args.model))
| |
| |
def _write_binary(path, payload):
  """Writes `payload` to `path` and closes the file before returning."""
  with open(path, "wb") as out_file:
    out_file.write(payload)


def _export_tflite(model):
  """Saves plain and quantized TFLite versions of `model` and prints sizes."""
  # Convert the model to the TensorFlow Lite format without quantization
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  _write_binary("model.tflite", converter.convert())

  # Convert the model to the TensorFlow Lite format with quantization.
  # Optimize.DEFAULT supersedes the deprecated OPTIMIZE_FOR_SIZE alias.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  _write_binary("model_quantized.tflite", converter.convert())

  basic_model_size = os.path.getsize("model.tflite")
  print("Basic model is %d bytes" % basic_model_size)
  quantized_model_size = os.path.getsize("model_quantized.tflite")
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)


def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind,
    epochs=50,
    batch_size=64,
    steps_per_epoch=1000):
  """Trains the model, evaluates it and exports it to TensorFlow Lite.

  The model is compiled with Adam and sparse categorical cross-entropy, fitted
  on the repeated training set, evaluated on the test set (loss, accuracy and
  a 4-class confusion matrix are printed) and finally written to
  `model.tflite` and `model_quantized.tflite` in the working directory.

  Args:
    model: Keras model to train.
    model_path: Unused.
    train_len: Unused.
    train_data: Unbatched dataset of `(data, label)` training pairs.
    valid_len: Number of validation samples; determines `validation_steps`.
    valid_data: Unbatched dataset of `(data, label)` validation pairs.
    test_len: Number of test samples; sizes the label buffer.
    test_data: Unbatched dataset of `(data, label)` test pairs.
    kind: Model kind; "CNN" makes every dataset go through `reshape_function`.
    epochs: Number of training epochs.
    batch_size: Batch size applied to all three datasets.
    steps_per_epoch: Training steps per epoch on the repeated training set.
  """
  calculate_model_size(model)
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)

  # Collect the ground-truth labels while the test set is still unbatched so
  # they line up one-to-one with the predictions computed later. Labels are
  # class indices, hence the integer buffer.
  test_labels = np.zeros(test_len, dtype=np.int64)
  for idx, (_, label) in enumerate(test_data):
    test_labels[idx] = label.numpy()

  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=steps_per_epoch,
      # Ceiling of valid_len / batch_size: one pass over the validation set.
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))
  _export_tflite(model)
| |
| |
if __name__ == "__main__":
  parser = argparse.ArgumentParser(allow_abbrev=False)
  # Validate the model name while parsing so that a missing or misspelled
  # value is rejected before any data is loaded.
  parser.add_argument(
      "--model", "-m", choices=("CNN", "LSTM"), required=True,
      help="Network architecture to build and train.")
  parser.add_argument(
      "--person", "-p",
      help='Pass "true" to use ./person_split instead of ./data.')
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  # Both layouts share the train/valid/test sub-directories; only the root
  # differs.
  data_root = "./person_split" if args.person == "true" else "./data"
  train_len, train_data, valid_len, valid_data, test_len, test_data = \
      load_data(data_root + "/train", data_root + "/valid",
                data_root + "/test", seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)

  print("Training finished!")