{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The Rational Speech Act framework\n", "Human language depends on the assumption of *cooperativity*: speakers attempt to provide relevant information to the listener, and listeners can use this assumption to reason *pragmatically* about the likely state of the world given the utterance chosen by the speaker.\n", "\n", "The Rational Speech Act framework formalizes these ideas using probabilistic decision making and reasoning.\n", "\n", "Note: This notebook must be run against Pyro commit 4392d54a220c328ee356600fb69f82166330d3d6 or later." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# first, some imports\n", "import torch\n", "torch.set_default_dtype(torch.float64)  # double precision for numerical stability\n", "\n", "import collections\n", "import argparse\n", "import matplotlib.pyplot as plt\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "\n", "from search_inference import HashingMarginal, memoize, Search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we can define RSA, we specify a helper function that wraps up inference. Marginal takes an un-normalized stochastic function, constructs the distribution over execution traces by using Search, and constructs the marginal distribution on return values (via HashingMarginal)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def Marginal(fn):\n", "    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The RSA model captures recursive social reasoning -- a listener thinks about a speaker who thinks about a listener....\n", "\n", "To start, the literal_listener simply imposes that the utterance is true. 
Mathematically:\n", "$$P_\\text{Lit}(s|u) \\propto {\\mathcal L}(u,s)P(s)$$\n", "\n", "In code:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@Marginal\n", "def literal_listener(utterance):\n", "    state = state_prior()\n", "    pyro.factor(\"literal_meaning\", 0. if meaning(utterance, state) else -999999.)\n", "    return state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, the cooperative speaker chooses an utterance to convey a given state to the literal listener. Mathematically:\n", "\n", "$$P_S(u|s) \\propto [P_\\text{Lit}(s|u) P(u)]^\\alpha$$\n", "\n", "In the code below, the utterance_prior captures the cost of producing an utterance, while the pyro.sample expression captures that the literal listener guesses the right state (obs=state indicates that the sampled value is observed to be the correct state).\n", "\n", "We use poutine.scale to raise the entire execution probability to the power of alpha -- this yields a softmax decision rule with optimality parameter alpha." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "@Marginal\n", "def speaker(state):\n", "    alpha = 1.\n", "    with poutine.scale(scale=torch.tensor(alpha)):\n", "        utterance = utterance_prior()\n", "        pyro.sample(\"listener\", literal_listener(utterance), obs=state)\n", "    return utterance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can define the pragmatic_listener, who infers which state is likely, given that the speaker chose a given utterance. 
Mathematically:\n", "\n", "$$P_L(s|u) \\propto P_S(u|s) P(s)$$\n", "\n", "In code:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "@Marginal\n", "def pragmatic_listener(utterance):\n", "    state = state_prior()\n", "    pyro.sample(\"speaker\", speaker(state), obs=utterance)\n", "    return state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's set up a simple world by filling in the priors. We imagine there are 4 objects, each either blue or red, and the possible utterances are \"none are blue\", \"some are blue\", and \"all are blue\".\n", "\n", "We take the prior probabilities for the number of blue objects and the utterance to be uniform." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "total_number = 4\n", "\n", "def state_prior():\n", "    # uniform prior over 0..total_number blue objects\n", "    n = pyro.sample(\"state\", dist.Categorical(probs=torch.ones(total_number+1) / (total_number+1)))\n", "    return n\n", "\n", "def utterance_prior():\n", "    ix = pyro.sample(\"utt\", dist.Categorical(probs=torch.ones(3) / 3))\n", "    return [\"none\",\"some\",\"all\"][ix]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, the meaning function (notated $\\mathcal L$ above):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "meanings = {\n", "    \"none\": lambda N: N==0,\n", "    \"some\": lambda N: N>0,\n", "    \"all\": lambda N: N==total_number,\n", "}\n", "\n", "def meaning(utterance, state):\n", "    return meanings[utterance](state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's see if it works: how does the pragmatic listener interpret the \"some\" utterance?" 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEZCAYAAAB7HPUdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFvNJREFUeJzt3X20XXV95/H3h8RgxdrycB01PCRIdEyXrdaI7VjtLEEIZYb4AGNwtcUlU6Yd6dTFqi2tFm2oawB1pqstWJgxLWNrQWCtmVTC4APYtWY6aGKlarApMSAk2jEVppYHCYHv/LH31cPtPdxzk/v8e7/WOitn//Zvn/zu95zzOfvsvc/eqSokSW04bL4HIEmaO4a+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSHL53sAEx1zzDG1atWq+R6GJC0qX/jCF/6+qsam6rfgQn/VqlVs3759vochSYtKkq+P0s/NO5LUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGLLgfZ2l2rLr45ll53HsvO3NWHncuzUZtlkJdwNosRa7pS1JDDH1JaoihL0kNMfQlqSEjhX6S9Ul2JtmV5OJJ5v9iki8nuTPJ/0qydmDeb/TL7Uxy+kwOXpI0PVOGfpJlwJXAGcBa4NzBUO99rKpeWlUvA64A/lO/7FpgI/AjwHrgqv7xJEnzYJQ1/ZOBXVW1u6r2A9cBGwY7VNV3BiaPAKq/vwG4rqoeq6p7gF3940mS5sEox+mvBO4fmN4DvGpipyTvAC4CVgCvG1j2jgnLrpxk2QuACwCOP/74UcYtSToIM7Yjt6qurKoXAr8OvGeay15TVeuqat3Y2JRX+5IkHaRRQn8vcNzA9LF92zDXAW84yGUlSbNolNDfBqxJsjrJCrods1sGOyRZMzB5JnB3f38LsDHJ4UlWA2uAzx/6sCVJB2PKbfpVdSDJhcCtwDJgc1XtSLIJ2F5VW4ALk5wKPA48CJzXL7sjyceBu4ADwDuq6olZ+lskSVMY6YRrVbUV2Dqh7ZKB+7/yNMu+H3j/wQ5QkjRz/EWuJDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpISOFfpL1SXYm2ZXk4knmX5TkriRfSvKZJCcMzHsiyZ39bctMDl6SND3Lp+qQZBlwJfB6YA+wLcmWqrproNsXgXVV9UiSXwKuAN7Sz3u0ql42w+OWJB2EUdb0TwZ2VdXuqtoPXAdsGOxQVbdX1SP95B3AsTM7TEnSTBgl9FcC9w9M7+nbhjkfuGVg+plJtie5I8kbDmKMkqQZMuXmnelI8rPAOuCnB5pPqKq9SU4Ebkvy5ar62oTlLgAuADj++ONnckiSpAGjrOnvBY4bmD62b3uKJKcC7wbOqqrHxturam//727gs8DLJy5bVddU1bqqWjc2NjatP0CSNLpRQn8bsCbJ6iQrgI3AU47CSfJy4Gq6wP/WQPuRSQ7v7x8DvBoY3AEsSZpDU27eqaoDSS4EbgWWAZurakeSTcD2qtoCfAB4NnBDEoD7quos4CXA1UmepPuAuWzCUT+SpDk00jb9qtoKbJ3QdsnA/VOHLPeXwEsPZYCSpJnjL3IlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQ
lqSEjhX6S9Ul2JtmV5OJJ5l+U5K4kX0rymSQnDMw7L8nd/e28mRy8JGl6pgz9JMuAK4EzgLXAuUnWTuj2RWBdVf0ocCNwRb/sUcB7gVcBJwPvTXLkzA1fkjQdo6zpnwzsqqrdVbUfuA7YMNihqm6vqkf6yTuAY/v7pwOfqqoHqupB4FPA+pkZuiRpukYJ/ZXA/QPTe/q2Yc4HbpnOskkuSLI9yfZ9+/aNMCRJ0sGY0R25SX4WWAd8YDrLVdU1VbWuqtaNjY3N5JAkSQNGCf29wHED08f2bU+R5FTg3cBZVfXYdJaVJM2NUUJ/G7AmyeokK4CNwJbBDkleDlxNF/jfGph1K3BakiP7Hbin9W2SpHmwfKoOVXUgyYV0Yb0M2FxVO5JsArZX1Ra6zTnPBm5IAnBfVZ1VVQ8kuZTugwNgU1U9MCt/iSRpSlOGPkBVbQW2Tmi7ZOD+qU+z7GZg88EOUJI0c/xFriQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSEjhX6S9Ul2JtmV5OJJ5r82yV8lOZDk7AnznkhyZ3/bMlMDlyRN3/KpOiRZBlwJvB7YA2xLsqWq7hrodh/wNuBXJ3mIR6vqZTMwVknSIZoy9IGTgV1VtRsgyXXABuB7oV9V9/bznpyFMUqSZsgom3dWAvcPTO/p20b1zCTbk9yR5A2TdUhyQd9n+759+6bx0JKk6ZiLHbknVNU64K3A7yZ54cQOVXVNVa2rqnVjY2NzMCRJatMoob8XOG5g+ti+bSRVtbf/dzfwWeDl0xifJGkGjRL624A1SVYnWQFsBEY6CifJkUkO7+8fA7yagX0BkqS5NWXoV9UB4ELgVuCrwMerakeSTUnOAkjyyiR7gHOAq5Ps6Bd/CbA9yV8DtwOXTTjqR5I0h0Y5eoeq2gpsndB2ycD9bXSbfSYu95fASw9xjJKkGeIvciWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0ZKfSTrE+yM8muJBdPMv+1Sf4qyYEkZ0+Yd16Su/vbeTM1cEnS9E0Z+kmWAVcCZwBrgXOTrJ3Q7T7gbcDHJix7FPBe4FXAycB7kxx56MOWJB2MUdb0TwZ2VdXuqtoPXAdsGOxQVfdW1ZeAJycsezrwqap6oKoeBD4FrJ+BcUuSDsIoob8SuH9gek/fNoqRlk1yQZLtSbbv27dvxIeWJE3XgtiRW1XXVNW6qlo3NjY238ORpCVrlNDfCxw3MH1s3zaKQ1lWkjTDRgn9bcCaJKuTrAA2AltGfPxbgdOSHNnvwD2tb5MkzYMpQ7+qDgAX0oX1V4GPV9WOJJuSnAWQ5JVJ9gDnAFcn2dEv+wBwKd0HxzZgU98mSZoHy0fpVFVbga0T2i4ZuL+NbtPNZMtuBjYfwhglSTNkQezIlSTNDUNfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDRrpGriTpqVZdfPOMP+a9l5054485kWv6ktQQQ1+SGmLoS1JDDH1JashIoZ9kfZKdSXYluXiS+Ycnub6f/7kkq/r2VUkeTXJnf/vDmR2+JGk6pjx6J8ky4Erg9cAeYFuSLVV110C384EHq+qkJBuBy4G39PO+VlUvm+FxS5IOwihr+icDu6pqd1XtB64DNkzoswG4tr9/I3BKkszcMCVJM2GU0F8J3D8wvadvm7RPVR0A/gE4up+3OskXk/xFktdM9h8kuSD
J9iTb9+3bN60/QJI0utnekftN4PiqejlwEfCxJM+Z2KmqrqmqdVW1bmxsbJaHJEntGiX09wLHDUwf27dN2ifJcuCHgG9X1WNV9W2AqvoC8DXgRYc6aEnSwRkl9LcBa5KsTrIC2AhsmdBnC3Bef/9s4LaqqiRj/Y5gkpwIrAF2z8zQJUnTNeXRO1V1IMmFwK3AMmBzVe1IsgnYXlVbgI8AH02yC3iA7oMB4LXApiSPA08Cv1hVD8zGHyJJmtpIJ1yrqq3A1gltlwzc/y5wziTL3QTcdIhjlCTNEH+RK0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JashIoZ9kfZKdSXYluXiS+Ycnub6f/7kkqwbm/UbfvjPJ6TM3dEnSdE0Z+kmWAVcCZwBrgXOTrJ3Q7Xzgwao6CfjPwOX9smuBjcCPAOuBq/rHkyTNg1HW9E8GdlXV7qraD1wHbJjQZwNwbX//RuCUJOnbr6uqx6rqHmBX/3iSpHmwfIQ+K4H7B6b3AK8a1qeqDiT5B+Dovv2OCcuunPgfJLkAuKCffCjJzpFGP3eOAf5+vgexEOVyazMZ6zKctRnuEGtzwiidRgn9WVdV1wDXzPc4hkmyvarWzfc4FiJrMznrMpy1GW4uajPK5p29wHED08f2bZP2SbIc+CHg2yMuK0maI6OE/jZgTZLVSVbQ7ZjdMqHPFuC8/v7ZwG1VVX37xv7ontXAGuDzMzN0SdJ0Tbl5p99GfyFwK7AM2FxVO5JsArZX1RbgI8BHk+wCHqD7YKDv93HgLuAA8I6qemKW/pbZtGA3PS0A1mZy1mU4azPcrNcm3Qq5JKkF/iJXkhpi6EtSQwx9SWqIoS9JDTH0h5h4fqH+tBICkvzYhGlfRz1rM5zvqeHmsjYevTOJJP8c2AxcD3yjqm7o2w+rqifndXDzrD+D6keBW4D9VfXBvj3V+Iup/y3Kf8Pa/BO+p4ZL8hK6w97npDaG/hBJXgj8BHAqcHhVvbVv90WaPBdYDfwq8BzgjKp60tpAkhfQ/Qrd2kyQ5CS683b5nuqN/+39CsO/YA5qY+gPSPLzwMPAV6pqZ9/2LLozhx5WVev7tubW3JKs7s+UOrH9JuDZVXV6P91ibd4CPAbsqKq7B9qtTfJ24AlgW1Xd1Z9a/ZnAx4Fljb+nPgjcBvxFVT3cbwr8AWa5Nm5v7CXZTHddgFcAn0hyZpJnVdUjVfUzwKNJPgDQ4Ivzz4APJTlloO0ZAFX1ZmB/36fF2lxNd4bYfwncnuRF4/OsTa6iOy3LScBnkxxTVU9U1cNVdSZtv6d+EHgbcBrwmj5rnhyozXdnqzaGPpDkxcCqqvrpqvpN4D3AfwDOHNihcjFwWJIT52uc8yHJucCPAl+iu07CKQBV9fh48ANvBr6T5KfmaZjzIsm7gJVVdUpVvZPuJ/Tn9PNW9N1arc2vAcdX1c9U1W8BnwCel+T5A91afU8tA/YDX6Q7Bf0b6E9Xn+SH+26/zizVxtAH+k0530iyIcnyqroe+DDwLuAn+27fBI4AXjlPw5wvn6a7atp/oduEcfqE4F8GPAncC7xwvgY5T75Ct+1+3H3ASwGqan+/wlC0WZtPAm8ESPJO4K10J2X8VJLX9H2afE/133YeA/4U+ANgO/CmJDfSfQDALNam6dBP8vYk/6bflvYVYB3wAoCq+u/AHwFXJDmiqr5D9wQdM7CGu2QlOT/J2VW1D/hmVe2lO/riH4EzkoxfAe3EqjoA3ASMDazhLll9bc6i+0AcvMDQ/6EL+XHPq6rHaa82b6yqO/uVgqPpavKiqnoX8PvAtUme0+B76u1Jzhlo+mHgX1XVfwV+HHgd8K1+G/6s1abp0KfbaTu+d/xaulM//0K/uYeq+jDwdWD86IuvAB/
p38hL3UN0Z1UFeKJ/Id4P/Anwd3RrJvfQfUWnqv4W+L3+kppL3UPAs/rXwSMD7UfQXWWNJDcDF0GTtVkB39sB+W3gqqq6r5//R3RX02vxPfUwMHj8/fXAw/0O3QPAB+nOUPxjfe1mpTYL4spZc23gMKj7gM1J7q2qO5JcBPwO8MtJHqf7Sv5oVT06vmxVfXd+Rj03JqnN7qrall5V3ZPkT4A7gVuq6vzxZZd6qE1Sm6+N16bv8iBwRP81/ev9mi3QZG12V9U26DYDDnT9Y+DhqnpovKHB99Q9fW0O0G3y+n9V9ZN937Oq6s7xZWejNs0dsjlwXOzhwFF0R12cCfxOVf1NkiOBtXQXcD+sqj7UL7fkDyl7mtq8r6p29ZvBCvhlYH1/VFMTx1lPVZu+zzPpNvfcWFW/NLjcPA17Toz4ujkeeDewoqrO65dr+T11aVXtTDIGPNhvIh1cbtZq01TojxeyXzO7ie5r5seAtwPPBf6w/0o1cbkW3rhT1ebKqvpq33f5+IvU2ny/Nv1heG+qqmv75azNU2vzqqr6dL+ctYGrq+rLfd85q0dT2/QHPjnfSfeBd0VV7QH+nO6r1x8k+al+jQ343hO3pF+cMFJtPtzX5vCBwLc236/Na6vqH1sKfBi5Nq/pa9NM4MNItfn9JK9O8oy5rEcT2/QHX2Tpfu78YuDEJK+rqtuq6otJdgD30B2jf3uSbf28Jf1VyNoMN83a/GaSnwC+UFWfWeqhNs3avDvJbVibyWrzWwy8p+ZkfEv8fUuSZVX1RP8Va4zuRxHfBd7X37+5qj430H8l3Q6Wx6vqgXkY8pyxNsNZm+GszXCLoTZLPvThe6e3/Z90F21fR/ejq88D/57uEKqbq+p/z98I54+1Gc7aDGdthlvotVnS2/QHDqX7PWBnVW0E/h3wIbon44N0Z0J8/uSPsHRZm+GszXDWZrjFUpsluaY//hVrYPp9wN1V9af99Gl0T8QrgOf2O1eaYG2GszbDWZvhFlttltya/sA2tcOS/Nu+eT/dOWMOA6iqTwJ/C/zA+BMw8Cm9ZFmb4azNcNZmuMVYm6W6pn8YsBX4m+rOfkiSG+jO63078Hrgoap627wNcp5Ym+GszXDWZrjFVpulGvqb6PaGX5rkKLoz1d0O/BxwJN0n7qV93yX/q8BB1mY4azOctRlusdVmSRynP7hNrf/a9CzgqCTvpyv6aXTHxK6fsO1tyf9IxNoMZ22GszbDLfbaLPpt+n0hn0hnbf8p+h66M0R+G/iPVXUS3bGyT7kgwUJ4AmaTtRnO2gxnbYZbCrVZ1Gv6eeoPIW4Fjk5yB3B9ff+kTsckuZbuTHZ3P93jLSXWZjhrM5y1GW6p1GZRr+kPPAG/DdxMdyX5e+jO9X523+1X+r4/B20cUQDW5ulYm+GszXBLpTaLfkdukjcDNwBvrKr/keQFwJvoLojyWeAT1Z/Pe6FsU5sr1mY4azOctRluKdRm0a3pp7sm6/dU1U3ApcDlSU6qqm/QPSnfBH5w4AlY8meEtDbDWZvhrM1wS7E2i2pNf/yTM91xsZfTXa91G915qn8BeDPw89VdnGD8GpxNsDbDWZvhrM1wS7U2i2ZNf+AJGN+J8nfAM4DLgBdU1RV0Fyr4ZJLnjT8BC3Gb2kyzNsNZm+GszXBLujZVteBvwPP7fwO8AfjtfvoW4J39/aP7f//1fI/X2iyMm7WxNtbmn94W/Jp+umtIvivJW4Gr6HaYrEmyE/hkVf1ukiPoLmLxz6rqz/vlFvzfdqiszXDWZjhrM1wLtVnwA62qfXQ/ab4GWFVVH6A7TOrzwKf7bn8MHFVV/3dguQW5E2UmWZvhrM1w1ma4FmqzKHbkJllFd1mxVwC/BuwAzgPW0+01/259/8cR835ui7lkbYazNsNZm+GWfG3me/vSdG7A6cBfA6/rpy8BXjkw/7D5HqO1WXg3a2NtrM33b4vqNAxVdWuS5cBHknwL+FxVbYOFfVzsXLA2w1mb4az
NcEu1Noti885ESV4M/HhV/Vk/vfi+Ys0SazOctRnO2gy31GqzKEN/UBboT50XAmsznLUZztoMtxRqs+hDX5I0ugV/yKYkaeYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDfn/zqkuACNL9hMAAAAASUVORK5CYII=\n", "text/plain": [ "