Commit 163adf28ec78be3d551592fa51ed7a50c0aa6dcd

Authored by Junghwan Park
1 parent 979a88c36e
Exists in main

make it faster
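The speedup here comes from doing the one-hot expansion of the hourly gait labels once up front, indexing the result by (user, local_date), and replacing the per-day boolean filtering plus to_categorical() call with a plain index lookup. A minimal sketch of that pattern on toy data (the toy frame, the pd.get_dummies shortcut, and all variable names below are illustrative assumptions, not the notebook's exact code):

import pandas as pd

# Toy hourly table: one row per (user, local_date, hour), gait class in {0, 1, 2}.
padded_hours = pd.DataFrame({
    "user":       [1, 1, 2],
    "local_date": ["2021-01-01", "2021-01-01", "2021-01-01"],
    "hour":       [0, 1, 0],
    "walked":     [0, 2, 1],
})

# Old pattern (inside the per-day loop): scan the whole frame, then one-hot encode.
#   day_df = padded_hours[(padded_hours["user"] == u) & (padded_hours["local_date"] == d)]
#   one_hot = to_categorical(day_df["walked"].values, 3)

# New pattern: expand every row to its one-hot columns once, then index by
# (user, local_date) so each day becomes a cheap .loc lookup.
one_hot = pd.get_dummies(padded_hours["walked"]).reindex(columns=range(3), fill_value=0).astype(int)
encoded = padded_hours[["user", "local_date", "hour"]].join(one_hot)
encoded = encoded.set_index(["user", "local_date"]).sort_index()

key = (1, "2021-01-01")
if key in encoded.index:
    print(encoded.loc[key])               # every hour of that user/day, already one-hot
else:
    print("no movement data for", key)    # fall back to a default encoding

The diff's mass_one_hot_encoding builds the same kind of expanded table with explicit loops; either way, the per-day lookup in the inner loop becomes an indexed .loc probe on a sorted index instead of a full-frame boolean scan.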

Showing 1 changed file with 62 additions and 9 deletions

python-notebook/prepare_trteva_data.ipynb @ 163adf2
... ... @@ -11,7 +11,9 @@
11 11 "import os\n",
12 12 "from tools import *\n",
13 13 "from constants import *\n",
14   - "from tensorflow.keras.utils import to_categorical"
  14 + "from tensorflow.keras.utils import to_categorical\n",
  15 + "\n",
  16 + "# %load_ext line_profiler"
15 17 ]
16 18 },
17 19 {
18 20  
... ... @@ -41,12 +43,59 @@
41 43 "cell_type": "markdown",
42 44 "metadata": {},
43 45 "source": [
  46 + "# Expanding one-hot-encoded gaits"
  47 + ]
  48 + },
  49 + {
  50 + "cell_type": "code",
  51 + "execution_count": 13,
  52 + "metadata": {},
  53 + "outputs": [
  54 + {
  55 + "name": "stdout",
  56 + "output_type": "stream",
  57 + "text": [
  58 + "(42360, 4) -> (127080, 4)\n"
  59 + ]
  60 + }
  61 + ],
  62 + "source": [
  63 + "def mass_one_hot_encoding(padded_hours, colname, n_classes):\n",
  64 + " def __mass_one_hot_encoding(padded_hours, colname, n_classes, n):\n",
  65 + " temp = padded_hours[padded_hours[colname] == n]\n",
  66 + "\n",
  67 + " return_df = pd.DataFrame(dtype=int)\n",
  68 + "\n",
  69 + " for i in range(n_classes):\n",
  70 + " temp_2 = temp.copy(deep=True)\n",
  71 + " temp_2[\"var\"] = i\n",
  72 + " temp_2[\"value\"] = (n == i) if 1 else 0\n",
  73 + " return_df = pd.concat([return_df, temp_2], ignore_index=True)\n",
  74 + "\n",
  75 + " return return_df\n",
  76 + " \n",
  77 + " mass_encoded = pd.DataFrame(dtype=int)\n",
  78 + " for n in range(n_classes):\n",
  79 + " mass_encoded = pd.concat([mass_encoded, __mass_one_hot_encoding(padded_hours, colname, n_classes, n)], ignore_index=True)\n",
  80 + " return mass_encoded\n",
  81 + "\n",
  82 + "padded_hours_encoded = mass_one_hot_encoding(padded_hours, 'walked', 3)\n",
  83 + "padded_hours_encoded[\"local_date\"] = padded_hours_encoded[\"local_date\"].astype(str)\n",
  84 + "padded_hours_encoded = padded_hours_encoded.set_index(['user', 'local_date']).sort_index()\n",
  85 + "\n",
  86 + "print(\"{} -> {}\".format(padded_hours.shape, padded_hours_encoded.shape))"
  87 + ]
  88 + },
  89 + {
  90 + "cell_type": "markdown",
  91 + "metadata": {},
  92 + "source": [
44 93 "## Enumerating Output Data"
45 94 ]
46 95 },
47 96 {
48 97 "cell_type": "code",
49   - "execution_count": 3,
  98 + "execution_count": 15,
50 99 "metadata": {},
51 100 "outputs": [],
52 101 "source": [
53 102  
... ... @@ -75,12 +124,13 @@
75 124 " # gait movement\n",
76 125 " zero_move = 0\n",
77 126 " for a_date in date_range(start_date, end_date):\n",
78   - " day_df = padded_hours[(padded_hours[\"user\"] == user) & (padded_hours[\"local_date\"] == a_date)]\n",
79   - " if day_df.size == 0:\n",
80   - " gait = pd.concat([gait, pd.Series([1,0,0] * 24, dtype=int)])\n",
81   - " zero_move += 1\n",
  127 + " key = (user, a_date.strftime(\"%Y-%m-%d\"))\n",
  128 + " if key in padded_hours_encoded.index:\n",
  129 + " day_df = padded_hours_encoded.loc[key, \"value\"]\n",
  130 + " gait = pd.concat([gait, day_df], ignore_index=True)\n",
82 131 " else:\n",
83   - " gait = pd.concat([gait, pd.Series(to_categorical(day_df[\"walked\"].values, 3, dtype=int).reshape(24*3), dtype=int)])\n",
  132 + " gait = pd.concat([gait, pd.Series([1,0,0] * 24, dtype=int)], ignore_index=True)\n",
  133 + " zero_move += 1\n",
84 134 " if zero_move == 5 * 7:\n",
85 135 " raise Exception(\"No movement data\")\n",
86 136 "\n",
... ... @@ -116,7 +166,7 @@
116 166 " temp_series = pd.concat([temp_series, pd.Series(threehour_idx, dtype=int)])\n",
117 167 " temp_series = pd.concat([temp_series, pd.Series(hour_idx, dtype=int)])\n",
118 168 " temp_series = pd.concat([temp_series, pd.Series(output, dtype=int)])\n",
119   - " temp_series = pd.concat([temp_series, pd.Series(input, dtype=int)])\n",
  169 + " temp_series = pd.concat([temp_series, pd.Series(input, dtype=int)]).reset_index(drop=True)\n",
120 170 "\n",
121 171 " database = pd.concat([database, temp_series], axis=1)\n",
122 172 " # print(input)\n",
... ... @@ -125,7 +175,10 @@
125 175 " pass\n",
126 176 "\n",
127 177 " return database\n",
128   - "\n"
  178 + "\n",
  179 + "database = get_database(0, 100)\n",
  180 + "\n",
  181 + "database.to_pickle(os.path.join(data_dir, \"database.pkl\"))"
129 182 ]
130 183 },
131 184 {