Commit 163adf28ec78be3d551592fa51ed7a50c0aa6dcd
1 parent
979a88c36e
Exists in
main
make it faster
Showing 1 changed file with 62 additions and 9 deletions Side-by-side Diff
python-notebook/prepare_trteva_data.ipynb
View file @
163adf2
... | ... | @@ -11,7 +11,9 @@ |
11 | 11 | "import os\n", |
12 | 12 | "from tools import *\n", |
13 | 13 | "from constants import *\n", |
14 | - "from tensorflow.keras.utils import to_categorical" | |
14 | + "from tensorflow.keras.utils import to_categorical\n", | |
15 | + "\n", | |
16 | + "# %load_ext line_profiler" | |
15 | 17 | ] |
16 | 18 | }, |
17 | 19 | { |
18 | 20 | |
... | ... | @@ -41,12 +43,59 @@ |
41 | 43 | "cell_type": "markdown", |
42 | 44 | "metadata": {}, |
43 | 45 | "source": [ |
46 | + "# Expanding one-hot-encoded gaits" | |
47 | + ] | |
48 | + }, | |
49 | + { | |
50 | + "cell_type": "code", | |
51 | + "execution_count": 13, | |
52 | + "metadata": {}, | |
53 | + "outputs": [ | |
54 | + { | |
55 | + "name": "stdout", | |
56 | + "output_type": "stream", | |
57 | + "text": [ | |
58 | + "(42360, 4) -> (127080, 4)\n" | |
59 | + ] | |
60 | + } | |
61 | + ], | |
62 | + "source": [ | |
def mass_one_hot_encoding(padded_hours, colname, n_classes):
    """Expand every row of ``padded_hours`` into ``n_classes`` one-hot rows.

    Each input row whose ``colname`` value equals one of ``0 .. n_classes - 1``
    is repeated ``n_classes`` times.  Two columns are added to the copies:

    * ``"var"``   -- the class index ``0 .. n_classes - 1`` of that copy;
    * ``"value"`` -- integer ``1`` when ``var`` equals the row's ``colname``
      value, otherwise integer ``0``.

    Rows whose ``colname`` value is not one of those classes (including NaN)
    are dropped.  The input frame is not modified.

    Args:
        padded_hours: pandas DataFrame containing the column ``colname``.
        colname: name of the column holding the class label of each row.
        n_classes: number of classes; labels are compared against
            ``range(n_classes)``.

    Returns:
        A new DataFrame with a fresh ``0..N-1`` index, the original columns
        plus ``"var"`` and ``"value"``.  The copies of one input row are
        adjacent and ordered by ``var``, and input rows keep their relative
        order, so reading ``"value"`` top to bottom yields one complete
        one-hot vector per input row (e.g. ``1, 0, 0`` for class 0).
    """
    if n_classes <= 0:
        # Nothing to expand into; an empty frame is the natural result.
        return pd.DataFrame(dtype=int)

    # Keep only rows carrying a valid class label, on a clean positional index
    # so the label-based repeat below is safe even if the input index has
    # duplicates.
    has_valid_label = padded_hours[colname].isin(list(range(n_classes)))
    kept = padded_hours[has_valid_label].reset_index(drop=True)

    # One vectorised repeat instead of concatenating inside nested loops.
    # Repeating row-by-row (rather than class-by-class) keeps each row's
    # one-hot vector contiguous, which is the layout a caller needs when it
    # later reads "value" per group after an order-preserving sort.
    # NOTE(review): callers that sort afterwards should confirm the sort they
    # use keeps the within-group order (or sort on an explicit tie-breaker).
    expanded = kept.loc[kept.index.repeat(n_classes)].reset_index(drop=True)

    expanded["var"] = pd.Series(list(range(n_classes)) * len(kept), dtype=int)
    # Real 0/1 integers; the previous `(n == i) if 1 else 0` always took the
    # first branch and therefore stored booleans.
    expanded["value"] = (expanded[colname] == expanded["var"]).astype(int)

    return expanded
81 | + "\n", | |
# Expand the 'walked' label of every row into 3 one-hot rows, store the date
# as a string, and index the result by (user, local_date) in sorted order.
padded_hours_encoded = (
    mass_one_hot_encoding(padded_hours, 'walked', 3)
    .assign(local_date=lambda frame: frame["local_date"].astype(str))
    .set_index(['user', 'local_date'])
    .sort_index()
)

# Report how the expansion changed the table's shape.
print("{} -> {}".format(padded_hours.shape, padded_hours_encoded.shape))
87 | + ] | |
88 | + }, | |
89 | + { | |
90 | + "cell_type": "markdown", | |
91 | + "metadata": {}, | |
92 | + "source": [ | |
44 | 93 | "## Enumerating Output Data" |
45 | 94 | ] |
46 | 95 | }, |
47 | 96 | { |
48 | 97 | "cell_type": "code", |
49 | - "execution_count": 3, | |
98 | + "execution_count": 15, | |
50 | 99 | "metadata": {}, |
51 | 100 | "outputs": [], |
52 | 101 | "source": [ |
53 | 102 | |
... | ... | @@ -75,12 +124,13 @@ |
75 | 124 | " # gait movement\n", |
76 | 125 | " zero_move = 0\n", |
77 | 126 | " for a_date in date_range(start_date, end_date):\n", |
78 | - " day_df = padded_hours[(padded_hours[\"user\"] == user) & (padded_hours[\"local_date\"] == a_date)]\n", | |
79 | - " if day_df.size == 0:\n", | |
80 | - " gait = pd.concat([gait, pd.Series([1,0,0] * 24, dtype=int)])\n", | |
81 | - " zero_move += 1\n", | |
127 | + " key = (user, a_date.strftime(\"%Y-%m-%d\"))\n", | |
128 | + " if key in padded_hours_encoded.index:\n", | |
129 | + " day_df = padded_hours_encoded.loc[key, \"value\"]\n", | |
130 | + " gait = pd.concat([gait, day_df], ignore_index=True)\n", | |
82 | 131 | " else:\n", |
83 | - " gait = pd.concat([gait, pd.Series(to_categorical(day_df[\"walked\"].values, 3, dtype=int).reshape(24*3), dtype=int)])\n", | |
132 | + " gait = pd.concat([gait, pd.Series([1,0,0] * 24, dtype=int)], ignore_index=True)\n", | |
133 | + " zero_move += 1\n", | |
84 | 134 | " if zero_move == 5 * 7:\n", |
85 | 135 | " raise Exception(\"No movement data\")\n", |
86 | 136 | "\n", |
... | ... | @@ -116,7 +166,7 @@ |
116 | 166 | " temp_series = pd.concat([temp_series, pd.Series(threehour_idx, dtype=int)])\n", |
117 | 167 | " temp_series = pd.concat([temp_series, pd.Series(hour_idx, dtype=int)])\n", |
118 | 168 | " temp_series = pd.concat([temp_series, pd.Series(output, dtype=int)])\n", |
119 | - " temp_series = pd.concat([temp_series, pd.Series(input, dtype=int)])\n", | |
169 | + " temp_series = pd.concat([temp_series, pd.Series(input, dtype=int)]).reset_index(drop=True)\n", | |
120 | 170 | "\n", |
121 | 171 | " database = pd.concat([database, temp_series], axis=1)\n", |
122 | 172 | " # print(input)\n", |
... | ... | @@ -125,7 +175,10 @@ |
125 | 175 | " pass\n", |
126 | 176 | "\n", |
127 | 177 | " return database\n", |
128 | - "\n" | |
178 | + "\n", | |
179 | + "database = get_database(0, 100)\n", | |
180 | + "\n", | |
181 | + "database.to_pickle(os.path.join(data_dir, \"database.pkl\"))" | |
129 | 182 | ] |
130 | 183 | }, |
131 | 184 | { |