blob: 213c7c792e6253b592733d0f15c5d5ae26012360 [file] [log] [blame]
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-bad-import-order
"""Build and train neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import datetime
import os # pylint: disable=duplicate-code
from data_load import DataLoader
import numpy as np # pylint: disable=duplicate-code
import tensorflow as tf
# One TensorBoard run directory per invocation, keyed by start timestamp.
_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/scalars/" + _run_stamp
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
def reshape_function(data, label):
  """Adds a trailing channel axis to `data`; passes `label` through.

  The CNN's Conv2D input expects shape (seq, 3, 1), while the raw dataset
  yields flat accelerometer samples.
  """
  return tf.reshape(data, [-1, 3, 1]), label
def calculate_model_size(model):
  """Prints the model summary and the size of its trainable weights in KB.

  Args:
    model: A Keras model — anything exposing `summary()` and a
      `trainable_variables` sequence whose members have `shape` and
      `dtype.size` (bytes per element).
  """
  print(model.summary())
  # np.prod replaces np.product, a deprecated alias removed in NumPy 2.0.
  var_sizes = [
      np.prod(list(map(int, v.shape))) * v.dtype.size
      for v in model.trainable_variables
  ]
  print("Model size:", sum(var_sizes) / 1024, "KB")
def build_cnn(seq_length):
  """Builds a convolutional neural network in Keras.

  Args:
    seq_length: Length of the input time series (accelerometer samples per
      example); the model input shape is (seq_length, 3, 1).

  Returns:
    A tuple (model, model_path): `model` is an uncompiled
    `tf.keras.Sequential` 4-class classifier, `model_path` is the directory
    used for its checkpoints (created if absent).
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          8, (4, 3),
          padding="same",
          activation="relu",
          input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
      tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
      tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
      tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                             activation="relu"),  # (batch, 42, 1, 16)
      tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
      tf.keras.layers.Flatten(),  # (batch, 224)
      tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
      tf.keras.layers.Dropout(0.1),  # (batch, 16)
      tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
  ])
  model_path = os.path.join("./netmodels", "CNN")
  print("Built CNN.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  # Bug fix: only restore weights when a checkpoint actually exists. The
  # original called load_weights unconditionally and crashed on a fresh
  # checkout with no pretrained weights file.
  weights_file = os.path.join(model_path, "weights.h5")
  if os.path.exists(weights_file):
    model.load_weights(weights_file)
  return model, model_path
def build_lstm(seq_length):
  """Builds a bidirectional-LSTM classifier in Keras.

  Args:
    seq_length: Length of the input time series; the model input shape is
      (seq_length, 3).

  Returns:
    A tuple (model, model_path): `model` is an uncompiled
    `tf.keras.Sequential` 4-class classifier, `model_path` is the directory
    used for its checkpoints (created if absent).
  """
  recurrent = tf.keras.layers.Bidirectional(
      tf.keras.layers.LSTM(22),
      input_shape=(seq_length, 3))  # output_shape=(batch, 44)
  head = tf.keras.layers.Dense(4, activation="sigmoid")  # (batch, 4)
  model = tf.keras.Sequential([recurrent, head])
  model_path = os.path.join("./netmodels", "LSTM")
  print("Built LSTM.")
  if not os.path.exists(model_path):
    os.makedirs(model_path)
  return model, model_path
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
  """Loads and formats the train/validation/test splits via DataLoader.

  Returns:
    A 6-tuple:
    (train_len, train_data, valid_len, valid_data, test_len, test_data).
  """
  loader = DataLoader(
      train_data_path, valid_data_path, test_data_path, seq_length=seq_length)
  loader.format()
  return (loader.train_len, loader.train_data, loader.valid_len,
          loader.valid_data, loader.test_len, loader.test_data)
def build_net(args, seq_length):
  """Builds the network architecture selected by `args.model`.

  Args:
    args: Parsed CLI arguments; `args.model` must be "CNN" or "LSTM".
    seq_length: Input sequence length forwarded to the builder.

  Returns:
    (model, model_path) from the matching builder.

  Raises:
    ValueError: If `args.model` names an unknown architecture. (The
      original printed a message and then crashed with UnboundLocalError
      on the return.)
  """
  if args.model == "CNN":
    model, model_path = build_cnn(seq_length)
  elif args.model == "LSTM":
    model, model_path = build_lstm(seq_length)
  else:
    raise ValueError("Please input correct model name.(CNN LSTM)")
  return model, model_path
def train_net(
    model,
    model_path,  # pylint: disable=unused-argument
    train_len,  # pylint: disable=unused-argument
    train_data,
    valid_len,
    valid_data,
    test_len,
    test_data,
    kind):
  """Trains the model, evaluates it, and exports TensorFlow Lite models.

  Args:
    model: Uncompiled Keras model to train.
    model_path: Checkpoint directory (currently unused).
    train_len: Number of training examples (currently unused).
    train_data: Unbatched tf.data.Dataset of (sample, label) pairs.
    valid_len: Number of validation examples, used to size validation steps.
    valid_data: Unbatched validation tf.data.Dataset.
    test_len: Number of test examples.
    test_data: Unbatched test tf.data.Dataset.
    kind: "CNN" or "LSTM"; the CNN path reshapes samples to (seq, 3, 1).

  Side effects:
    Writes "model.tflite" and "model_quantized.tflite" to the working
    directory and logs training scalars to the TensorBoard callback.
  """
  calculate_model_size(model)
  epochs = 50
  batch_size = 64
  model.compile(
      optimizer="adam",
      loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])
  if kind == "CNN":
    # The CNN expects a trailing channel axis: (seq, 3) -> (seq, 3, 1).
    train_data = train_data.map(reshape_function)
    test_data = test_data.map(reshape_function)
    valid_data = valid_data.map(reshape_function)
  # Collect ground-truth labels before batching so the confusion matrix
  # below can be computed against per-example predictions.
  test_labels = np.zeros(test_len)
  idx = 0
  for data, label in test_data:  # pylint: disable=unused-variable
    test_labels[idx] = label.numpy()
    idx += 1
  train_data = train_data.batch(batch_size).repeat()
  valid_data = valid_data.batch(batch_size)
  test_data = test_data.batch(batch_size)
  model.fit(
      train_data,
      epochs=epochs,
      validation_data=valid_data,
      steps_per_epoch=1000,
      validation_steps=int((valid_len - 1) / batch_size + 1),
      callbacks=[tensorboard_callback])
  loss, acc = model.evaluate(test_data)
  pred = np.argmax(model.predict(test_data), axis=1)
  confusion = tf.math.confusion_matrix(
      labels=tf.constant(test_labels),
      predictions=tf.constant(pred),
      num_classes=4)
  print(confusion)
  print("Loss {}, Accuracy {}".format(loss, acc))
  # Convert the model to the TensorFlow Lite format without quantization.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  tflite_model = converter.convert()
  # Bug fix: use context managers so the output files are flushed and closed
  # deterministically (the original leaked handles via open(...).write(...)).
  with open("model.tflite", "wb") as model_file:
    model_file.write(tflite_model)
  # Convert the model to the TensorFlow Lite format with quantization.
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
  tflite_model = converter.convert()
  with open("model_quantized.tflite", "wb") as model_file:
    model_file.write(tflite_model)
  basic_model_size = os.path.getsize("model.tflite")
  print("Basic model is %d bytes" % basic_model_size)
  quantized_model_size = os.path.getsize("model_quantized.tflite")
  print("Quantized model is %d bytes" % quantized_model_size)
  difference = basic_model_size - quantized_model_size
  print("Difference is %d bytes" % difference)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--model", "-m")
  parser.add_argument("--person", "-p")
  args = parser.parse_args()

  seq_length = 128

  print("Start to load data...")
  # "--person true" selects the person-wise split; anything else uses the
  # random split under ./data.
  data_root = "./person_split" if args.person == "true" else "./data"
  train_len, train_data, valid_len, valid_data, test_len, test_data = \
      load_data(data_root + "/train", data_root + "/valid",
                data_root + "/test", seq_length)

  print("Start to build net...")
  model, model_path = build_net(args, seq_length)

  print("Start training...")
  train_net(model, model_path, train_len, train_data, valid_len, valid_data,
            test_len, test_data, args.model)
  print("Training finished!")