{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CONTINUOUS OPTIMIZATION COURSE - Antonin Chambolle - 2020: Examples of stochastic algorithms: coordinate descent, gradient descent.\n", "\n", "## Classification of handwritten digits using a \"Support Vector Machine\" and stochastic optimization\n", "\n", "Digits from the MNIST database [LeCun et al., 1998a: Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. \"Gradient-based learning applied to document recognition.\" Proceedings of the IEEE, 86(11):2278-2324, November 1998.]\n", "\n", "In this notebook we show how to solve a simple convex problem either by stochastic gradient descent or by (proximal) stochastic block coordinate descent. We address a very simplified task: distinguishing one subset of digits (one or several) from another.\n", "\n", "### The problem:\n", "\n", "In practice, the learning set consists of $N$ vectors $(X_i)$ (here, small $28\\times 28$ images of handwritten digits, i.e. vectors of $\\mathbb{R}^{784}$) and of labels $y_i\\in \\{-1,1\\}$. (Initially, the labels are in $\\{0,\\dots,9\\}$ and give the value of the handwritten digit. We define two lists of digits, \"mylabelplus\" and \"mylabelminus\": the digits in the first list are labelled $+1$ and those in the second $-1$.)\n", "\n", "One wants to learn a hyperplane that best separates $\\{X_i : y_i=+1\\}$ and $\\{X_i : y_i=-1\\}$ in $\\mathbb{R}^{784}$. That is, we want to find $(w,b)\\in\\mathbb{R}^{784}\\times \\mathbb{R}$ such that $w\\cdot X_i>b$ for $y_i=+1$ and $w\\cdot X_i<b$ for $y_i=-1$.\n", "Asking for a margin, we require $w\\cdot X_i\\ge b+1$ each time $y_i=+1$\n", "and $w\\cdot X_i\\le b-1$ each time $y_i=-1$, i.e. $y_i(w\\cdot X_i-b)\\ge 1$ for all $i$.\n", "\n", "In practice, $h(t) = (1-t)^+$, so that $h^*(s) = s$ if $s\\in [-1,0]$ and $+\\infty$ otherwise. We will also use\n", "a regularized version: for $\\varepsilon>0$, we let\n", "$h_{\\varepsilon}^*(s) = s + \\varepsilon s^2/2$ for $s\\in [-1,0]$ and $+\\infty$ otherwise. 
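
The sketch below is a minimal numerical sanity check of these definitions. It is not part of the original notebook, and the helper names `h_eps_numeric`, `h_eps_closed` and `dh_eps_closed`, like the value $\\varepsilon=0.1$, are illustrative assumptions. It evaluates $h_\\varepsilon$ by brute force as the convex conjugate of $h_\\varepsilon^*$, that is $h_\\varepsilon(t)=\\sup_{s\\in[-1,0]}\\big(st-s-\\varepsilon s^2/2\\big)$ over a fine grid of $s$. It then compares the result with the closed-form expression of the exercise that follows, and checks that the slope of $h'_\\varepsilon$ never exceeds $1/\\varepsilon$.

```python
import numpy as np

# Illustrative sketch (hypothetical helpers, not from the notebook).
# h_eps^*(s) = s + eps*s**2/2 on [-1, 0], +inf elsewhere, so its conjugate is
# h_eps(t) = max over s in [-1, 0] of s*t - s - eps*s**2/2.


def h_eps_numeric(t, eps, n_grid=5001):
    # brute-force conjugate: maximise over a uniform grid of s in [-1, 0]
    s = np.linspace(-1.0, 0.0, n_grid)
    vals = np.outer(t, s) - (s + 0.5 * eps * s ** 2)  # shape (len(t), n_grid)
    return vals.max(axis=1)


def h_eps_closed(t, eps):
    # closed form: 1 - t - eps/2, then (t - 1)**2 / (2 eps), then 0
    return np.where(t <= 1 - eps, 1 - t - eps / 2,
                    np.where(t <= 1, (t - 1) ** 2 / (2 * eps), 0.0))


def dh_eps_closed(t, eps):
    # derivative: -1, then (t - 1)/eps, then 0, i.e. clip((t - 1)/eps, -1, 0)
    return np.clip((t - 1) / eps, -1.0, 0.0)


eps = 0.1
t = np.linspace(-2.0, 3.0, 201)

# gap between the brute-force conjugate and the closed form
err = np.max(np.abs(h_eps_numeric(t, eps) - h_eps_closed(t, eps)))
print('max |numeric - closed form| =', err)

# finite-difference slopes of h_eps' stay below 1/eps (its Lipschitz constant)
d = dh_eps_closed(t, eps)
ratios = np.abs(np.diff(d)) / np.diff(t)
print('largest slope of h_eps prime:', ratios.max(), ' bound 1/eps:', 1 / eps)
```

With a uniform grid of step $\\delta$ on $[-1,0]$, the brute-force supremum underestimates $h_\\varepsilon(t)$ by at most $\\varepsilon\\delta^2/8$. This holds because the maximized function is a concave quadratic in $s$ with curvature $\\varepsilon$, and the grid contains both endpoints. The reported error should therefore be tiny.
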
Show in this case that\n", "$h_\\varepsilon'$ is $1/\\varepsilon$-Lipschitz and that:\n", "$$\n", "h_\\varepsilon(t) = \\begin{cases} 1-t-\\frac{\\varepsilon}{2} & \\text{ if } t\\le 1-\\varepsilon \\\\\n", "\\frac{(t-1)^2}{2\\varepsilon} & \\text{ if } 1-\\varepsilon \\le t \\le 1 \\\\\n", "0 & \\text{ if } t\\ge 1\n", "\\end{cases}\\quad ,\\quad\n", "h'_\\varepsilon(t) = \\begin{cases} -1 & \\text{ if } t\\le 1-\\varepsilon \\\\\n", "\\frac{t-1}{\\varepsilon} & \\text{ if } 1-\\varepsilon \\le t \\le 1 \\\\\n", "0 & \\text{ if } t\\ge 1\n", "\\end{cases}\n", "$$\n", "We will normalize the images so that $|X_i|=1$. In the dual problem, the derivative of the quadratic term of the energy with respect to $z_i$ is then $1$-Lipschitz.\n", "In the primal problem, if we replace $h$ with $h_\\varepsilon$, the gradient of the objective (which is now $C^1$) with respect to $w$ is Lipschitz\n", "with constant at most $1+\\lambda/\\varepsilon$.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import random\n", "import gzip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we download the MNIST database. The code in the next cell, adapted from https://www.python-course.eu/neural_network_mnist.php, reads the four gzipped IDX files `train-images-idx3-ubyte.gz`, `train-labels-idx1-ubyte.gz`, `t10k-images-idx3-ubyte.gz` and `t10k-labels-idx1-ubyte.gz`\n", "from the directory data/mnist/: please adapt the code by setting the variable `mnist_path` below accordingly."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load MNIST data\n", "\n", "# training data\n", "mnist_path = \"data/mnist/\"\n", "train_images_path = mnist_path+\"train-images-idx3-ubyte.gz\"\n", "with gzip.open(train_images_path, 'rb') as train_imgpath:\n", "    # skip the 16-byte IDX header, scale pixels to [0, 1]\n", "    train_imgs = np.frombuffer(train_imgpath.read(), dtype=np.uint8,\n", "            offset=16).reshape(-1, 784)/255.0\n", "\n", "train_labels_path = mnist_path+\"train-labels-idx1-ubyte.gz\"\n", "with gzip.open(train_labels_path, 'rb') as train_lbpath:\n", "    # skip the 8-byte IDX header\n", "    train_labels = np.frombuffer(train_lbpath.read(), dtype=np.uint8,\n", "            offset=8)\n", "\n", "# test data\n", "test_images_path = mnist_path+\"t10k-images-idx3-ubyte.gz\"\n", "with gzip.open(test_images_path, 'rb') as test_imgpath:\n", "    test_imgs = np.frombuffer(test_imgpath.read(), dtype=np.uint8,\n", "            offset=16).reshape(-1, 784)/255.0\n", "\n", "test_labels_path = mnist_path+\"t10k-labels-idx1-ubyte.gz\"\n", "with gzip.open(test_labels_path, 'rb') as test_lbpath:\n", "    test_labels = np.frombuffer(test_lbpath.read(), dtype=np.uint8,\n", "            offset=8)\n", "# remove spurious dimensions\n", "train_labels=np.squeeze(train_labels)\n", "test_labels=np.squeeze(test_labels)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# reduce the size of the training set if needed\n", "# this might be necessary for the random coordinate ascent, which starts\n", "# by computing an N*N matrix of scalar products. 
We convert before this the\n", "# values to float32 (4 bytes), yet N=60000 values (the size of the training set)\n", "# require 14.4 Gb\n", "# depending on the number of figures selected below and your memory you should\n", "# adjust the value of r (the coordinate descent on the primal does not need this)\n", "# (the speed of your processor can also be a bottleneck, computing the matrix is best\n", "# with many cores)\n", "\n", "r = 60/100 # 30% \n", "\n", "indices = (np.random.random(train_labels.shape)" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAN9ElEQVR4nO3df6hcdXrH8c/HGBXNBmNztTf+ylYEjYVml0EK/iBladBgjBEs+oeoUaOg4EKUBvuHEUWkdA0KZfFuFZO6dVV2gwZCu0EWo/+YjCE1SUMbldTNemPuJcgqKjbJ0z/ucbnGO2euM2fmTPK8X3CZmfPMOefJ5H7umZnvmfk6IgTgxHdS3Q0A6A/CDiRB2IEkCDuQBGEHkji5nzubO3duzJ8/v5+7BFLZt2+fxsfHPVWtq7DbvkbS05JmSPqXiHiy7P7z589Xs9nsZpcASjQajZa1jp/G254h6Z8lXStpgaRbbC/odHsAequb1+yXS3o/Ij6MiK8l/UrSsmraAlC1bsJ+rqTfT7q9v1j2LbZX2m7abo6NjXWxOwDd6CbsU70J8J1zbyNiJCIaEdEYGhrqYncAutFN2PdLOn/S7fMkfdxdOwB6pZuwb5N0se0f2j5F0s2SXq+mLQBV63joLSIO275f0n9oYujt+YjYXVlnACrV1Th7RGyStKmiXgD0EKfLAkkQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kERXs7hiMIyPj7esHT58uHTdrVu3ltaXLVtWWj/ppME9Xtxxxx0ta88++2zpujNmzKi6ndp1FXbb+yR9JumIpMMR0aiiKQDVq+LI/jcR0frQAmAgDO5zMACV6jbsIem3tt+1vXKqO9heabtpuzk2Ntbl7gB0qtuwXxERP5Z0raT7bF997B0iYiQiGhHRGBoa6nJ3ADrVVdgj4uPi8qCkDZIur6IpANXrOOy2z7D9g2+uS1osaVdVjQGoVjfvxp8jaYPtb7bzbxHx75V0lcyBAwdK6+vXry+tj4yMtKwdPXq0dN2PPvqotN5uHL34/x9IL7zwQsvanDlzStd9/PHHS+unnnpqJy3VquOwR8SHkv6qwl4A9BBDb0AShB1IgrADSRB2IAnCDiTBR1wHwOrVq0vrL774Yp86yWPt2rWl9Xvvvbe0ftFFF1XZTl9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnHwBLly4trXczzj5v3rzS+oMPPlhab/cR2W6+Svqtt94qrW/YsKHjbeO7
OLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsw+A5cuXl9YPHTrU8bbbjYPPmjWr421365577imtX3rppaX1dl+DXWbFihWl9QsvvLDjbQ8qjuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7AOg3Vj47Nmz+9RJf23fvr20Pj4+3rN9X3DBBaX1k08+8aLR9shu+3nbB23vmrTsLNubbe8tLssnuwZQu+k8jX9B0jXHLFst6Y2IuFjSG8VtAAOsbdgjYoukY8/XXCZpXXF9naQbKu4LQMU6fYPunIgYlaTi8uxWd7S90nbTdnNsbKzD3QHoVs/fjY+IkYhoRERjaGio17sD0EKnYf/E9rAkFZcHq2sJQC90GvbXJd1WXL9N0mvVtAOgV9oOJtp+SdIiSXNt75f0iKQnJb1i+05JH0m6qZdN4vj19ttvt6w9/fTTpet+8cUXVbfzJw899FDPtj2o2oY9Im5pUfpJxb0A6CFOlwWSIOxAEoQdSIKwA0kQdiCJE+9zfKjUli1bSuurVq0qre/evbtl7euvv+6op+m66qqrWta6mWr6eJXvXwwkRdiBJAg7kARhB5Ig7EAShB1IgrADSTDOPgA+/fTT0vorr7xSWt+0aVOV7XzLxo0bS+u2e7bvM888s7S+fv360vqVV17ZsjZz5syOejqecWQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZ++D0dHR0vqiRYtK6x988EGF3Rw/li5dWlpfsmRJnzo5MXBkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGcfABHRVb2Xjh49Wlrv5fevt/u8+gMPPFBaX7hwYZXtHPfa/k/Zft72Qdu7Ji1bY/sPtncUP5zdAAy46fxZfkHSNVMsXxsRC4uf3n1VCoBKtA17RGyRdKgPvQDooW5ecN1v+73iaf6cVneyvdJ203ZzbGysi90B6EanYf+5pIskLZQ0Kulnre4YESMR0YiIxtDQUIe7A9CtjsIeEZ9ExJGIOCrpF5Iur7YtAFXrKOy2hyfdXC5pV6v7AhgMbcfZbb8kaZGkubb3S3pE0iLbCyWFpH2S7ulhj8e94eHh0vq2bdtK66+++mppffHixS1rp5xySum6vfbcc8+1rD3yyCN97ARtwx4Rt0yxuPX/IICBxOmyQBKEHUiCsANJEHYgCcIOJOF+fnyy0WhEs9ns2/5Qv6+++qplbdasWV1tu93vUsaPuDYaDTWbzSnn0ebIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ8FXS6Knt27fX3QIKHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2afpyJEjLWs7d+4sXfeyyy4rrc+cObOjngbB5s2bS+s33XRTnzpBOxzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtkLe/fuLa2vWbOmZe3ll18uXffQoUOl9TrH2b/88svS+tatW0vrN998c2n9888//949feP0008vrZ922mkdbzujtkd22+fb/p3tPbZ3236gWH6W7c229xaXc3rfLoBOTedp/GFJqyLiUkl/Lek+2wskrZb0RkRcLOmN4jaAAdU27BExGhHbi+ufSdoj6VxJyyStK+62TtINvWoSQPe+1xt0tudL+pGkdySdExGj0sQfBElnt1hnpe2m7ebY2Fh33QLo2LTDbnuWpF9L+mlE/HG660XESEQ0IqIxNDTUSY8AKjCtsNueqYmg/zIiflMs/sT2cFEflnSwNy0CqELboTfblvScpD0R8dSk0uuSbpP0ZHH5Wk867JPbb7+9tP7OO+90vO21a9eW1mfPnt3xtru1cePG0vqbb75ZWp/49ejMjTfeWFpftWpVaf2SSy7peN8ZTWec/QpJt0raaXtHsexhTYT8Fdt3SvpIEh9cBgZY27BHxNuSWv35/km17QDoFU6XBZIg7EAShB1IgrADSRB2IAk+4toHjz32WN0t9My8efNK67feemvL2qOPPlq67skn8+tZ
JY7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEA5mFdl8H/cwzz7SsPfXUUy1rdVuwYEFpvd1n6RcvXlxav/vuu0vrw8PDpXX0D0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfbCeeedV1p/4oknWtauvvrq0nXvuuuu0vr4+HhpfcWKFaX166+/vmVt0aJFpevOmjWrtI4TB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUhiOvOzny9pvaQ/l3RU0khEPG17jaS7JY0Vd304Ijb1qtG6lX2H+XXXXVe67oEDB6puB/jepnNSzWFJqyJiu+0fSHrX9uaitjYi/ql37QGoynTmZx+VNFpc/8z2Hknn9roxANX6Xq/Zbc+X9CNJ7xSL7rf9nu3nbc9psc5K203bzbGxsanuAqAPph1227Mk/VrSTyPij5J+LukiSQs1ceT/2VTrRcRIRDQiojE0NFRBywA6Ma2w256piaD/MiJ+I0kR8UlEHImIo5J+Ieny3rUJoFttw27bkp6TtCcinpq0fPLXhi6XtKv69gBUZTrvxl8h6VZJO23vKJY9LOkW2wslhaR9ku7pSYcAKjGdd+PfluQpSifsmDpwIuIMOiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKOiP7tzB6T9L+TFs2VVD5fcX0GtbdB7Uuit05V2duFETHl97/1Nezf2bndjIhGbQ2UGNTeBrUvid461a/eeBoPJEHYgSTqDvtIzfsvM6i9DWpfEr11qi+91fqaHUD/1H1kB9AnhB1Iopaw277G9n/bft/26jp6aMX2Pts7be+w3ay5l+dtH7S9a9Kys2xvtr23uJxyjr2aeltj+w/FY7fD9pKaejvf9u9s77G92/YDxfJaH7uSvvryuPX9NbvtGZL+R9LfStovaZukWyLiv/raSAu290lqRETtJ2DYvlrS55LWR8RfFsv+UdKhiHiy+EM5JyL+fkB6WyPp87qn8S5mKxqePM24pBsk3a4aH7uSvv5OfXjc6jiyXy7p/Yj4MCK+lvQrSctq6GPgRcQWSYeOWbxM0rri+jpN/LL0XYveBkJEjEbE9uL6Z5K+mWa81seupK++qCPs50r6/aTb+zVY872HpN/aftf2yrqbmcI5ETEqTfzySDq75n6O1XYa7346ZprxgXnsOpn+vFt1hH2qqaQGafzvioj4saRrJd1XPF3F9ExrGu9+mWKa8YHQ6fTn3aoj7PslnT/p9nmSPq6hjylFxMfF5UFJGzR4U1F/8s0MusXlwZr7+ZNBmsZ7qmnGNQCPXZ3Tn9cR9m2SLrb9Q9unSLpZ0us19PEdts8o3jiR7TMkLdbgTUX9uqTbiuu3SXqtxl6+ZVCm8W41zbhqfuxqn/48Ivr+I2mJJt6R/0DSP9TRQ4u+/kLSfxY/u+vuTdJLmnha93+aeEZ0p6Q/k/SGpL3F5VkD1Nu/Stop6T1NBGu4pt6u1MRLw/ck7Sh+ltT92JX01ZfHjdNlgSQ4gw5IgrADSRB2IAnCDiRB2IEkCDuQBGEHkvh/S1oWzvuBodgAAAAASUVORK5CYII=\n", "text/plain": [ "