| # Lint as: python3 |
| # coding=utf-8 |
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Mix and split data. |
| |
| Mix different people's data together and randomly split it into train, |
| validation and test sets. Each set is saved separately under "./data". |
| It will generate new files with the following structure: |
| |
| ├── data |
| │ ├── complete_data |
| │ ├── test |
| │ ├── train |
| │ └── valid |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import json |
| import random |
| from data_prepare import write_data |
| |
| |
| # Read data |
def read_data(path):
    """Reads a JSON-lines file into a list of decoded records.

    Args:
      path: Path of a text file holding one JSON document per line.

    Returns:
      A list with one decoded object per non-blank line, in file order.

    Raises:
      OSError: If the file cannot be opened.
      json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(path, "r") as f:
        # Iterate the file object directly so the whole file is not held in
        # memory as a second list of raw lines.
        for line in f:
            # A blank line (e.g. an extra newline at the end of the file)
            # carries no record; json.loads would raise on it.
            if not line.strip():
                continue
            data.append(json.loads(line))
    print("data_length:" + str(len(data)))
    return data
| |
| |
def split_data(data, train_ratio, valid_ratio):
    """Splits data into train, validation and test according to ratio.

    The split is done per gesture label: for a label occurring `count` times,
    int(train_ratio * count) samples go to train, int(valid_ratio * count) go
    to validation and the remaining ones go to test. Samples are visited in a
    shuffled order that is reproducible because the global `random` generator
    is seeded with 30 on every call.

    Args:
      data: Sequence of dicts, each holding its label under the "gesture" key.
        It is left untouched; a copy is shuffled.
      train_ratio: Fraction of each label's samples assigned to train.
      valid_ratio: Fraction of each label's samples assigned to validation.

    Returns:
      A tuple (train_data, valid_data, test_data) of lists of samples.

    Raises:
      KeyError: If a sample has no "gesture" key.
    """
    # Start from the four expected gestures so they are always reported, even
    # with a count of zero. Any other label found in the data is added as it
    # is met, so its samples are split too instead of being silently dropped.
    num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
    for item in data:
        label = item["gesture"]
        num_dic[label] = num_dic.get(label, 0) + 1
    print(num_dic)

    # Remaining per-label quota for the train and validation splits.
    train_num_dic = {}
    valid_num_dic = {}
    for label, count in num_dic.items():
        train_num_dic[label] = int(train_ratio * count)
        valid_num_dic[label] = int(valid_ratio * count)

    # Shuffle a copy: same seed and same length give the same permutation as
    # shuffling in place, but the caller's list keeps its order.
    shuffled = list(data)
    random.seed(30)
    random.shuffle(shuffled)

    train_data = []
    valid_data = []
    test_data = []
    for item in shuffled:
        label = item["gesture"]
        if train_num_dic[label] > 0:
            train_data.append(item)
            train_num_dic[label] -= 1
        elif valid_num_dic[label] > 0:
            valid_data.append(item)
            valid_num_dic[label] -= 1
        else:
            # Both quotas for this label are used up; the rest is test data.
            test_data.append(item)
    print("train_length:" + str(len(train_data)))
    print("valid_length:" + str(len(valid_data)))
    print("test_length:" + str(len(test_data)))
    return train_data, valid_data, test_data
| |
| |
if __name__ == "__main__":
    # Load the merged dataset, split it 60/20/20 and store every subset in
    # its own file next to the complete data.
    complete_data = read_data("./data/complete_data")
    subsets = split_data(complete_data, 0.6, 0.2)
    targets = ("./data/train", "./data/valid", "./data/test")
    for subset, target in zip(subsets, targets):
        write_data(subset, target)