{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Understanding Hyperbole using RSA\n", "\n", " \"My new kettle cost a million dollars.\"\n", "\n", "Hyperbole -- using an exaggerated utterance to convey strong opinions -- is a common non-literal use of language. Yet non-literal uses of language are impossible under the simplest RSA model. Kao et al. suggested that two ingredients could be added to enable RSA to capture hyperbole. First, the state conveyed by the speaker and reasoned about by the listener should include affective dimensions. Second, the speaker only intends to convey information relevant to a particular topic, such as \"how expensive was it?\" or \"how am I feeling about the price?\"; pragmatic listeners hence jointly reason about this topic and the state." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# first, some imports\n", "import torch\n", "torch.set_default_dtype(torch.float64)  # double precision for numerical stability\n", "\n", "import collections\n", "import argparse\n", "import matplotlib.pyplot as plt\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "\n", "from search_inference import HashingMarginal, memoize, Search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As in the simple RSA example, the inference helper Marginal takes an un-normalized stochastic function, constructs the distribution over execution traces by using Search, and constructs the marginal distribution on return values (via HashingMarginal)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def Marginal(fn):\n", "    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The domain for this example will be states consisting of price (e.g. 
of a tea kettle) and the speaker's emotional arousal (whether the speaker thinks this price is irritatingly expensive). Priors here are adapted from experimental data." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "State = collections.namedtuple(\"State\", [\"price\", \"arousal\"])\n", "\n", "def price_prior():\n", "    values = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]\n", "    probs = torch.tensor([0.4205, 0.3865, 0.0533, 0.0538, 0.0223, 0.0211, 0.0112, 0.0111, 0.0083, 0.0120])\n", "    ix = pyro.sample(\"price\", dist.Categorical(probs=probs))\n", "    return values[ix]\n", "\n", "def arousal_prior(price):\n", "    probs = {\n", "        50: 0.3173,\n", "        51: 0.3173,\n", "        500: 0.7920,\n", "        501: 0.7920,\n", "        1000: 0.8933,\n", "        1001: 0.8933,\n", "        5000: 0.9524,\n", "        5001: 0.9524,\n", "        10000: 0.9864,\n", "        10001: 0.9864\n", "    }\n", "    return pyro.sample(\"arousal\", dist.Bernoulli(probs=probs[price])).item() == 1\n", "\n", "def state_prior():\n", "    price = price_prior()\n", "    state = State(price=price, arousal=arousal_prior(price))\n", "    return state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we define a version of the RSA speaker that only produces *relevant* information for the literal listener. We define relevance with respect to a Question Under Discussion (QUD) -- this can be thought of as defining the speaker's current attention or topic.\n", "\n", "The speaker is defined mathematically by:\n", "\n", "$$P_S(u|s,q) \\propto \\left[ \\sum_{s'} \\delta_{q(s')=q(s)} P_\\text{Lit}(s'|u) p(u) \\right]^\\alpha$$\n", "\n", "To implement this as a probabilistic program, we start with a helper function project, which takes a distribution over some (discrete) domain and a function qud on this domain. It creates the push-forward distribution, using Marginal (as a Python decorator). The speaker's relevant information is then simply information about the state in this projection." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "@Marginal\n", "def project(dist, qud):\n", "    v = pyro.sample(\"proj\", dist)\n", "    return qud_fns[qud](v)\n", "\n", "@Marginal\n", "def literal_listener(utterance):\n", "    state = state_prior()\n", "    pyro.factor(\"literal_meaning\", 0. if meaning(utterance, state.price) else -999999.)\n", "    return state\n", "\n", "@Marginal\n", "def speaker(state, qud):\n", "    alpha = 1.\n", "    qudValue = qud_fns[qud](state)\n", "    with poutine.scale(scale=torch.tensor(alpha)):\n", "        utterance = utterance_prior()\n", "        literal_marginal = literal_listener(utterance)\n", "        projected_literal = project(literal_marginal, qud)\n", "        pyro.sample(\"listener\", projected_literal, obs=qudValue)\n", "    return utterance\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The possible QUDs capture that the speaker may be attending to the price, her affect, or some combination of these. We assume a uniform QUD prior." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# The QUD functions we consider:\n", "qud_fns = {\n", "    \"price\": lambda state: State(price=state.price, arousal=None),\n", "    \"arousal\": lambda state: State(price=None, arousal=state.arousal),\n", "    \"priceArousal\": lambda state: State(price=state.price, arousal=state.arousal),\n", "}\n", "\n", "def qud_prior():\n", "    values = list(qud_fns.keys())\n", "    ix = pyro.sample(\"qud\", dist.Categorical(probs=torch.ones(len(values)) / len(values)))\n", "    return values[ix]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we specify the utterance meanings (standard number word denotations: \"N\" means exactly $N$) and a uniform utterance prior. 
" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def utterance_prior():\n", " utterances = [50, 51, 500, 501, 1000, 1001, 5000, 5001, 10000, 10001]\n", " ix = pyro.sample(\"utterance\", dist.Categorical(probs=torch.ones(len(utterances)) / len(utterances)))\n", " return utterances[ix]\n", "\n", "def meaning(utterance, price):\n", " return utterance == price" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "OK, let's see what number term this speaker will say to express different states and QUDs." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "