{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Line-search methods" ] }, { "cell_type": "code", "execution_count": 411, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pylab as pl\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['figure.dpi'] = 100 # resolution of the figures\n", "import time\n", "\n", "v = 2 # variant: index of the objective function defined below\n", "Maxiter = 200 # Maximum number of iterations\n", "x0 = 2 # Initialization\n", "a = 0 # Lower bound for the plot interval\n", "b = 4 # Upper bound for the plot interval\n", "Tol = 1e-15 # Tolerance\n", "\n", "InitStep = 1 # initial step\n", "m1 = 0.1 # line-search constants, with m1 < 0.5 < m2\n", "m2 = 0.7" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Various objective functions" ] }, { "cell_type": "code", "execution_count": 412, "metadata": {}, "outputs": [], "source": [ "def fun(x,v): # function definition\n", "    if v==0:\n", "        return 3*(1.5*x-1)**2\n", "    if v==1:\n", "        return x**3-5*x+1\n", "    if v==2:\n", "        return x**4-2*x**3-5*x\n", "    if v==3:\n", "        return np.cos(5*x)-8*x+2.5*x**2\n", "    if v==4:\n", "        return np.cos(5*x)-8*x+1.5*x**2\n", "\n", "def der(x,v): # first derivative\n", "    if v==0:\n", "        return 3*2*(1.5*x-1)*1.5\n", "    if v==1:\n", "        return 3*x**2-5\n", "    if v==2:\n", "        return 4*x**3-6*x**2-5\n", "    if v==3:\n", "        return -5*np.sin(5*x)-8+2.5*2*x\n", "    if v==4:\n", "        return -5*np.sin(5*x)-8+1.5*2*x\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Algorithm: Generic Line Search\n", "\n", "
\n", "\n", " **Initialization:** Start with $t_l=t_r=0$ and pick an initial $t>0$.\n", "\n", " **Iterate:** \n", " - **Step 1**\n", "     - **if** (a) then exit: you found a good $t$\n", "     - **if** (b) then $t_r = t$: you found a new upper bound for $t$\n", "     - **if** (c) then $t_l = t$: you found a new lower bound for $t$\n", " - **Step 2**\n", "     - **if** no valid $t_r$ exists yet ($t_r=0$) then choose a new $t>t_l$\n", "     - **else** choose a new $t \\in (t_l,t_r)$\n", "
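The scheme above can be sketched in a few lines of Python. This is only a minimal sketch under stated assumptions, not the implementation used in this notebook: the names `generic_line_search`, `classify`, `toy_classify`, `t_init` and `max_iter` are hypothetical, and the doubling rule (extrapolation) and midpoint rule (interpolation) are just one possible choice for Step 2.

```python
def generic_line_search(classify, t_init=1.0, max_iter=100):
    # classify(t) is a hypothetical callback returning 'good', 'too_big' or
    # 'too_small', i.e. which of the cases (a), (b), (c) holds at the trial step t.
    t_l, t_r = 0.0, 0.0          # t_r == 0 means no upper bound has been found yet
    t = t_init
    for _ in range(max_iter):
        case = classify(t)
        if case == 'good':       # case (a): accept t
            return t
        if case == 'too_big':    # case (b): t becomes the new upper bound
            t_r = t
        else:                    # case (c): t becomes the new lower bound
            t_l = t
        if t_r == 0.0:
            t = 2.0 * t_l            # extrapolate (assumed rule: doubling)
        else:
            t = 0.5 * (t_l + t_r)    # interpolate (assumed rule: midpoint)
    return t                     # best effort after max_iter trials


# Toy usage: every step in [1.5, 3] counts as good.
def toy_classify(t):
    if t > 3.0:
        return 'too_big'
    if t < 1.5:
        return 'too_small'
    return 'good'


step = generic_line_search(toy_classify, t_init=1.0)
print(step)
```

Starting from $t=1$ the toy test reports a step that is too small, so the sketch doubles $t$ once and then accepts it. Each concrete line search below differs only in how the three cases are tested.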
\n", "\n", "\n", "\n", "\n", "\n", "## Goldstein-Price\n", "\n", "\n", "
\n", "\n", "Recall the three conditions for the **Goldstein-Price line-search**:\n", "\n", "$\\star$ $m_1,m_2 \\in (0,1)$ are chosen constants such that $m_1<0.5$ and $m_2>0.5$.\n", " \n", "(a) $m_2 q'(0) \\leq \\frac{q(t)-q(0)}{t} \\leq m_1 q'(0)$ (then we have a good $t$)\n", "\n", "(b) $m_1q'(0)< \\frac{q(t)-q(0)}{t}$ (then $t$ is too big)\n", "\n", "(c) $\\frac{q(t)-q(0)}{t} < m_2 q'(0)$ (then $t$ is too small)\n", "\n" ] }, { "cell_type": "code", "execution_count": 413, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 0\n", "0 4\n", "2.0 4\n", "2.0 3.0\n", "2.5\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def LinesearchGP(q,dq,m1,m2):\n", "    tl = 0 # lower bound for t\n", "    tr = 0 # upper bound for t (0 means none found yet)\n", "    t = b # initial trial step\n", "    qp = dq(0,v) # slope q'(0)\n", "    q0 = q(0,v)\n", "    xs = np.linspace(a,b)\n", "    ys = q0+m1*qp*xs # line of slope m1*q'(0) through (0,q(0))\n", "    ys2 = q0+m2*qp*xs # line of slope m2*q'(0) through (0,q(0))\n", "    plt.plot(xs,ys,label=\"m1\")\n", "    plt.plot(xs,ys2,label=\"m2\")\n", "    qt = q(t,v)\n", "    plt.plot(t,qt,'.g')\n", "    while True:\n", "        qt = q(t,v)\n", "        plt.plot(t,qt,'xr')\n", "        print(tl,\" \",tr)\n", "        if ((qt-q0)/t<=(m1*qp)) and ((qt-q0)/t>=(m2*qp)):\n", "            break # (a): we found a good step\n", "        if ((qt-q0)/t>(m1*qp)):\n", "            # (b): step too big\n", "            tr = t\n", "        if ((qt-q0)/t<(m2*qp)):\n", "            # (c): step too small\n", "            tl = t\n", "        if(tr==0):\n", "            t = 2*tl\n", "        else:\n", "            t = 0.5*(tl+tr)\n", "        if tr!=0 and (tr-tl)<1e-15: # bracket collapsed; skip this test while no upper bound exists\n", "            break\n", "    print(t)\n", "    plt.plot(t,q(t,v),'xb')\n", "    return t\n", "\n", "uplim = max(fun(a,v),fun(b,v))+1 # set limits for the plot window\n", "dnlim = -3\n", "t1 = np.linspace(a,b,100) # Create a discretization to be used with the plots\n", "plt.figure(1)\n", "plt.plot(t1,fun(t1,v),'k') # Plot the function to be optimized on the interval [a,b]\n", "\n", "LinesearchGP(fun,der,m1,m2)\n", "plt.title('Goldstein-Price')\n", "plt.legend()\n", "plt.show() \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Armijo\n", "\n", "
\n", "\n", "Recall the conditions for the **Armijo line-search**:\n", "\n", "$\\star$ $m_1 \\in (0,1)$ is a chosen constant such that $m_1<0.5$; Armijo does not use $m_2$.\n", "\n", "(a) $\\frac{q(t)-q(0)}{t} \\leq m_1 q'(0)$ (then we have a good $t$)\n", "\n", "(b) $m_1q'(0)< \\frac{q(t)-q(0)}{t}$ (then $t$ is too big)\n", "\n", "(c) Never: there is no test for a step that is too small, so $t_l=0$ throughout.\n", "
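Because case (c) never occurs, the Armijo search reduces to backtracking from an initial step. The following is a minimal sketch of that idea, stripped of all plotting. The names `armijo_step`, `dq0`, `t_init` and `max_iter` are hypothetical, halving is an assumed update rule, and `q` is variant `v = 2` of the objective functions defined earlier, for which $q'(0)=-5$.

```python
def armijo_step(q, dq0, m1=0.1, t_init=4.0, max_iter=60):
    # q: one-dimensional function t -> q(t); dq0: the slope q'(0), assumed negative.
    q0 = q(0.0)
    t = t_init
    for _ in range(max_iter):
        if (q(t) - q0) / t <= m1 * dq0:   # condition (a): sufficient decrease
            return t
        t_r = t                           # condition (b): t is too big
        t = 0.5 * t_r                     # midpoint of (t_l, t_r) with t_l = 0
    return t                              # best effort after max_iter halvings


def q(t):
    # variant v = 2 of the notebook's objective functions
    return t**4 - 2 * t**3 - 5 * t


step = armijo_step(q, dq0=-5.0)
print(step)
```

With these values the first trial $t=4$ fails the decrease test and the halved step $t=2$ passes it, which is the same sequence of trial steps as in the output of the notebook cell further down.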
" ] }, { "cell_type": "code", "execution_count": 414, "metadata": {}, "outputs": [], "source": [ "def LinesearchArmijo(q,dq,m1):\n", "    tl = 0 # stays 0: Armijo has no lower-bound test\n", "    tr = 0 # upper bound for t (0 means none found yet)\n", "    t = b # initial trial step\n", "    qp = dq(0,v) # slope q'(0)\n", "    q0 = q(0,v)\n", "    xs = np.linspace(a,b)\n", "    ys = q0+m1*qp*xs # line of slope m1*q'(0) through (0,q(0))\n", "    plt.plot(xs,ys,label=\"m1\")\n", "    qt = q(t,v)\n", "    plt.plot(t,qt,'.g')\n", "    while True:\n", "        qt = q(t,v)\n", "        plt.plot(t,qt,'xr')\n", "        print(tl,\" \",tr)\n", "        if ((qt-q0)/t<=(m1*qp)):\n", "            break # (a): we found a good step\n", "        if ((qt-q0)/t>(m1*qp)):\n", "            # (b): step too big\n", "            tr = t\n", "        if(tr==0):\n", "            t = 2*tl\n", "        else:\n", "            t = 0.5*(tl+tr)\n", "        if (tr-tl)<1e-15:\n", "            break\n", "    plt.plot(t,q(t,v),'xb')\n", "    return t" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "\n" ] }, { "cell_type": "code", "execution_count": 415, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 0\n", "0 4\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "uplim = max(fun(a,v),fun(b,v))+1 # set limits for the plot window\n", "dnlim = -3\n", "t1 = np.linspace(a,b,100) # Create a discretization to be used with the plots\n", "plt.figure(1)\n", "plt.plot(t1,fun(t1,v),'k') # Plot the function to be optimized on the interval [a,b]\n", "\n", "LinesearchArmijo(fun,der,m1)\n", "plt.title('Armijo')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wolfe\n", "\n", "
\n", "\n", "Recall the three conditions for the **Wolfe line-search**\n", "\n", "$\\star$ $m_1,m_2 \\in (0,1)$ are chosen constants such that $m_1<0.5$ and $m_2>0.5$.\n", "\n", "(a) $\\frac{q(t)-q(0)}{t} \\leq m_1 q'(0)$ and $q'(t) \\geq m_2 q'(0)$ (then we have a good $t$)\n", "\t\n", "(b) $\\frac{q(t)-q(0)}{t}> m_1q'(0) $ (then $t$ is too big)\n", "\t\n", "(c) $\\frac{q(t)-q(0)}{t}\\leq m_1q'(0)$ and $q'(t) < m_2 q'(0)$ (then $t$ is too small)\n", "
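The three Wolfe cases can be tested in a fixed order: first the decrease test, then the slope test. The sketch below does this without any plotting. It is an illustration under assumptions rather than the notebook's implementation: the names `wolfe_case`, `wolfe_step`, `t_init` and `max_iter` are hypothetical, doubling and bisection are assumed update rules, and `q`, `dq` are variant `v = 2` of the functions `fun` and `der`.

```python
def wolfe_case(q, dq, t, m1=0.1, m2=0.7):
    # Returns 'a', 'b' or 'c': which of the Wolfe conditions holds at t > 0.
    slope = (q(t) - q(0.0)) / t
    if slope > m1 * dq(0.0):
        return 'b'               # not enough decrease: t is too big
    if dq(t) < m2 * dq(0.0):
        return 'c'               # q still decreases steeply at t: t is too small
    return 'a'                   # both Wolfe conditions hold


def wolfe_step(q, dq, t_init=1.0, max_iter=60):
    t_l, t_r, t = 0.0, 0.0, t_init   # t_r == 0 means no upper bound yet
    for _ in range(max_iter):
        case = wolfe_case(q, dq, t)
        if case == 'a':
            return t
        if case == 'b':
            t_r = t
        else:
            t_l = t
        # assumed rules: double while no upper bound is known, otherwise bisect
        t = 2.0 * t_l if t_r == 0.0 else 0.5 * (t_l + t_r)
    return t


def q(t):
    return t**4 - 2 * t**3 - 5 * t


def dq(t):
    return 4 * t**3 - 6 * t**2 - 5


step = wolfe_step(q, dq)
print(step)
```

Here $t=1$ falls in case (c), since $q'(1)=-7 < m_2 q'(0)=-3.5$, so the step is doubled; at $t=2$ both conditions hold. Note that a stopping test on $t_r-t_l$ should only be applied once an upper bound exists, otherwise the search would stop during the doubling phase.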
" ] }, { "cell_type": "code", "execution_count": 417, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 0\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def LinesearchWolfe(q,dq,m1,m2):\n", "    tl = 0 # lower bound for t\n", "    tr = 0 # upper bound for t (0 means none found yet)\n", "    t = 1 # initial trial step\n", "    qp = dq(0,v) # slope q'(0)\n", "    q0 = q(0,v)\n", "    xs = np.linspace(a,b)\n", "    ys = q0+m1*qp*xs # line of slope m1*q'(0) through (0,q(0))\n", "    plt.plot(xs,ys,label=\"m1\")\n", "    qt = q(t,v)\n", "    plt.plot(t,qt,'.g')\n", "    while True:\n", "        qt = q(t,v)\n", "        plt.plot(t,qt,'xr')\n", "        print(tl,\" \",tr)\n", "        if ((qt-q0)/t<=(m1*qp)) and dq(t,v)>=m2*qp:\n", "            step=t # (a): we found a good step\n", "            break\n", "        if ((qt-q0)/t>(m1*qp)):\n", "            # (b): step too big\n", "            tr = t\n", "        if ((qt-q0)/t<=(m1*qp)) and dq(t,v)