{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# t-SNE\n", "t-Distributed Stochastic Neighbor Embedding (t-SNE)\n", "\n", "Buen vídeo básico de introducción al t-SNE\n", "https://www.youtube.com/watch?v=NEaUSP4YerM\n", "\n", "https://towardsdatascience.com/an-introduction-to-t-sne-with-python-example-5a3a293108d1\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## t-SNE\n", "\n", "Para probar el método, vamos a crear un *dataset* formado por **tres grupos** de puntos generados con distintas localizaciones (y, opcionalmente, distintas varianzas)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Fixed seed so the clusters (and the random 2-D start later drawn with np.random) are reproducible.\n",
"np.random.seed(42)\n",
"\n",
"N_PER_CLUSTER = 10\n",
"# (loc, scale) of each cluster; the class labels are 1, 2, 3 in this order.\n",
"# Variant with different locations and variances: [(0., 1.), (2., 0.1), (5., 2.)]\n",
"CLUSTERS = [(0., 1.), (5., 1.), (10., 1.)]\n",
"\n",
"\n",
"def make_cluster(loc, scale, label, n=N_PER_CLUSTER):\n",
"    \"\"\"Return n 3-D points drawn from N(loc, scale**2) per coordinate, tagged with class `label`.\"\"\"\n",
"    points = pd.DataFrame(np.random.normal(loc=loc, scale=scale, size=(n, 3)), columns=[\"x\", \"y\", \"z\"])\n",
"    points[\"class\"] = label\n",
"    return points\n",
"\n",
"\n",
"# One frame with columns x, y, z, class and a fresh 0..n-1 index.\n",
"df = pd.concat(\n",
"    [make_cluster(loc, scale, label) for label, (loc, scale) in enumerate(CLUSTERS, start=1)],\n",
"    ignore_index=True,\n",
")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Veamos los puntos de forma tabulada.\n" ] }, {
"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | x | \n", "y | \n", "z | \n", "class | \n", "
---|---|---|---|---|
0 | \n", "0.166524 | \n", "0.476437 | \n", "-0.748905 | \n", "1 | \n", "
1 | \n", "-0.260890 | \n", "-1.210117 | \n", "1.111706 | \n", "1 | \n", "
2 | \n", "1.483352 | \n", "0.980338 | \n", "2.013879 | \n", "1 | \n", "
3 | \n", "-0.008549 | \n", "0.257644 | \n", "-1.536449 | \n", "1 | \n", "
4 | \n", "-1.467212 | \n", "-0.370246 | \n", "0.176544 | \n", "1 | \n", "
5 | \n", "-0.035987 | \n", "0.426323 | \n", "-1.165189 | \n", "1 | \n", "
6 | \n", "-1.098647 | \n", "0.331815 | \n", "1.964918 | \n", "1 | \n", "
7 | \n", "-0.272549 | \n", "-0.040916 | \n", "-0.370012 | \n", "1 | \n", "
8 | \n", "2.678083 | \n", "0.933871 | \n", "-0.728749 | \n", "1 | \n", "
9 | \n", "1.314680 | \n", "0.550044 | \n", "-1.232873 | \n", "1 | \n", "
10 | \n", "5.063146 | \n", "4.676758 | \n", "6.091572 | \n", "2 | \n", "
11 | \n", "4.005999 | \n", "4.868296 | \n", "5.885338 | \n", "2 | \n", "
12 | \n", "4.395932 | \n", "5.603112 | \n", "5.080207 | \n", "2 | \n", "
13 | \n", "5.908465 | \n", "5.186550 | \n", "6.903519 | \n", "2 | \n", "
14 | \n", "3.197914 | \n", "5.841706 | \n", "4.253857 | \n", "2 | \n", "
15 | \n", "5.017807 | \n", "4.312943 | \n", "3.940999 | \n", "2 | \n", "
16 | \n", "3.890963 | \n", "6.252370 | \n", "5.864753 | \n", "2 | \n", "
17 | \n", "4.683445 | \n", "5.941964 | \n", "4.850334 | \n", "2 | \n", "
18 | \n", "4.606564 | \n", "6.254284 | \n", "4.996100 | \n", "2 | \n", "
19 | \n", "3.944945 | \n", "4.605639 | \n", "5.809857 | \n", "2 | \n", "
20 | \n", "10.027336 | \n", "11.041746 | \n", "10.127516 | \n", "3 | \n", "
21 | \n", "8.923284 | \n", "9.785343 | \n", "11.444244 | \n", "3 | \n", "
22 | \n", "8.983430 | \n", "9.807272 | \n", "9.374686 | \n", "3 | \n", "
23 | \n", "8.087111 | \n", "8.450916 | \n", "8.720797 | \n", "3 | \n", "
24 | \n", "9.362836 | \n", "9.986216 | \n", "8.472149 | \n", "3 | \n", "
25 | \n", "10.091924 | \n", "9.338093 | \n", "9.719824 | \n", "3 | \n", "
26 | \n", "8.936871 | \n", "11.359872 | \n", "10.348214 | \n", "3 | \n", "
27 | \n", "10.474326 | \n", "11.126252 | \n", "10.260523 | \n", "3 | \n", "
28 | \n", "9.603450 | \n", "9.553276 | \n", "10.102165 | \n", "3 | \n", "
29 | \n", "8.888365 | \n", "10.522096 | \n", "8.616896 | \n", "3 | \n", "
`Q_points` será nuestro conjunto de puntos tridimensionales y `P_points` será nuestro conjunto de puntos bidimensionales con los que intentaremos visualizar el primer conjunto."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"Q_points = df[[\"x\", \"y\", \"z\"]].to_numpy()  # point coordinates only, without the class label\n",
"# Random 2-D starting positions, one per 3-D point, so the two sets always have the same size.\n",
"P_points = np.random.uniform(0, 10, size=(len(Q_points), 2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Si visualizamos el conjunto de puntos bidimensional, veremos que está totalmente desordenado (está generado al azar)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"`P` y `Q`. Los puntos `P` son los que tenemos que \"mover\", por eso requieren gradiente."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# CPU by default. Set USE_CUDA = True to use a GPU when one is available\n",
"# (the training loop fills tensors element by element from Python, so a GPU may not help).\n",
"USE_CUDA = False\n",
"device = 'cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu'\n",
"\n",
"# Only the 2-D embedding P is optimized, so only P requires gradients.\n",
"P = torch.tensor(P_points, requires_grad=True, dtype=torch.float, device=device)\n",
"Q = torch.tensor(Q_points, dtype=torch.float, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 loss: 39.479583740234375\n",
"Epoch: 10 loss: 15.956965446472168\n",
"Epoch: 20 loss: 13.391416549682617\n",
"Epoch: 30 loss: 12.188273429870605\n",
"Epoch: 40 loss: 10.87851333618164\n",
"Epoch: 50 loss: 10.16821575164795\n",
"Epoch: 60 loss: 9.989543914794922\n",
"Epoch: 70 loss: 9.744491577148438\n",
"Epoch: 80 loss: 9.489737510681152\n",
"Epoch: 90 loss: 8.615865707397461\n",
"Epoch: 100 loss: 8.053683280944824\n",
"Epoch: 110 loss: 7.864839553833008\n",
"Epoch: 120 loss: 7.308864593505859\n",
"Epoch: 130 loss: 6.854788780212402\n",
"Epoch: 140 loss: 5.771241188049316\n",
"Epoch: 150 loss: 4.689131736755371\n",
"Epoch: 160 loss: 4.214322566986084\n",
"Epoch: 170 loss: 4.072487831115723\n",
"Epoch: 180 loss: 4.051445007324219\n",
"Epoch: 190 loss: 4.03515625\n",
"Epoch: 200 loss: 4.021659851074219\n",
"Epoch: 210 loss: 4.013899803161621\n",
"Epoch: 220 loss: 3.9985787868499756\n",
"Epoch: 230 loss: 3.986757516860962\n",
"Epoch: 240 loss: 3.977696180343628\n",
"Epoch: 250 loss: 3.969468593597412\n",
"Epoch: 260 loss: 3.959566116333008\n",
"Epoch: 270 loss: 3.9487476348876953\n",
"Epoch: 280 loss: 3.9359965324401855\n",
"Epoch: 290 loss: 3.919861078262329\n",
"Epoch: 300 loss: 3.9039435386657715\n",
"Epoch: 310 loss: 3.8863189220428467\n",
"Epoch: 320 loss: 3.8603057861328125\n",
"Epoch: 330 loss: 3.8084962368011475\n",
"Epoch: 340 loss: 3.7537732124328613\n",
"Epoch: 350 loss: 3.695063591003418\n",
"Epoch: 360 loss: 3.590381383895874\n",
"Epoch: 370 loss: 3.4406182765960693\n",
"Epoch: 380 loss: 3.263678789138794\n",
"Epoch: 390 loss: 2.823730707168579\n",
"Epoch: 400 loss: 2.2120273113250732\n",
"Epoch: 410 loss: 0.6396467685699463\n",
"Epoch: 420 loss: 0.5637612342834473\n",
"Epoch: 430 loss: 0.5009985566139221\n",
"Epoch: 440 loss: 0.4680554270744324\n",
"Epoch: 450 loss: 0.4548220932483673\n",
"Epoch: 460 loss: 0.43580907583236694\n",
"Epoch: 470 loss: 0.436868280172348\n",
"Epoch: 480 loss: 0.4374592900276184\n",
"Epoch: 490 loss: 0.4361301362514496\n"
]
}
],
"source": [
"optimizer = torch.optim.RMSprop([P], lr=0.1)\n",
"\n",
"epochs = 500\n",
"\n",
"\n",
"def pairwise_distances(points):\n",
"    \"\"\"Return the (n, n) float64 matrix of `distance` between every pair of rows of `points`.\"\"\"\n",
"    n = len(points)\n",
"    matrix = torch.zeros((n, n), dtype=torch.float64, device=device)\n",
"    for i, a in enumerate(points):\n",
"        for j, b in enumerate(points):\n",
"            matrix[i, j] = distance(a, b)\n",
"    return matrix\n",
"\n",
"\n",
"def row_normalized_pdf(distances):\n",
"    \"\"\"Apply `pdf` to a distance matrix and scale every row so it sums to 1.\"\"\"\n",
"    densities = pdf(distances)\n",
"    return densities / torch.sum(densities, dim=1, keepdim=True)\n",
"\n",
"\n",
"# NOTE(review): the same `pdf` kernel is applied to both spaces; standard t-SNE uses a Gaussian\n",
"# in the input space and a Student-t in the embedding -- confirm this is intended.\n",
"# Q never changes during training, so its affinities are computed once, outside the loop.\n",
"QD = row_normalized_pdf(pairwise_distances(Q))\n",
"\n",
"for k in range(epochs):\n",
"    PD = row_normalized_pdf(pairwise_distances(P))\n",
"\n",
"    loss = torch.tensor([0.], device=device)\n",
"    # p_row/q_row rather than pd/qd: a loop variable named `pd` would shadow the pandas alias.\n",
"    for p_row, q_row in zip(PD, QD):\n",
"        loss += KL(p_row, q_row)\n",
"\n",
"    if k%10 == 0:\n",
"        print(\"Epoch:\", k, \"loss:\", loss.item())\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualizamos el resultado:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"