{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Modules in Pyro\n", "\n", "\n", "This tutorial introduces [PyroModule](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroModule), Pyro's Bayesian extension of PyTorch's [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) class. Before starting you should understand the basics of Pyro [models and inference](http://pyro.ai/examples/intro_long.html), understand the two primitives [pyro.sample()](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.sample) and [pyro.param()](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.param), and understand the basics of Pyro's [effect handlers](http://pyro.ai/examples/effect_handlers.html) (e.g. by browsing [minipyro.py](https://github.com/pyro-ppl/pyro/blob/dev/pyro/contrib/minipyro.py)).\n", "\n", "#### Summary:\n", "\n", "- [PyroModule](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroModule)s are like [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)s but allow Pyro effects for sampling and constraints.\n", "- `PyroModule` is a mixin subclass of `nn.Module` that overrides attribute access (e.g. 
`.__getattr__()`).\n", "- There are three different ways to create a `PyroModule`:\n", " - create a new subclass: `class MyModule(PyroModule): ...`,\n", " - Pyro-ize an existing class: `MyModule = PyroModule[OtherModule]`, or\n", " - Pyro-ize an existing `nn.Module` instance in-place: `to_pyro_module_(my_module)`.\n", "- Usual `nn.Parameter` attributes of a `PyroModule` become Pyro parameters.\n", "- Parameters of a `PyroModule` synchronize with Pyro's global param store.\n", "- You can add constrained parameters by creating [PyroParam](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroParam) objects.\n", "- You can add stochastic attributes by creating [PyroSample](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroSample) objects.\n", "- Parameters and stochastic attributes are named automatically (no string required).\n", "- `PyroSample` attributes are sampled once per `.__call__()` of the outermost `PyroModule`.\n", "- To enable Pyro effects on methods other than `.__call__()`, decorate them with @[pyro_method](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.pyro_method).\n", "- A `PyroModule` model may contain `nn.Module` attributes.\n", "- An `nn.Module` model may contain at most one `PyroModule` attribute (see [naming section](#Caution-avoiding-duplicate-names)).\n", "- An `nn.Module` may contain both a `PyroModule` model and `PyroModule` guide (e.g. 
[Predictive](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.predictive.Predictive)).\n", "\n", "#### Table of Contents\n", "\n", "- [How PyroModule works](#How-PyroModule-works)\n", "- [How to create a PyroModule](#How-to-create-a-PyroModule)\n", "- [How effects work](#How-effects-work)\n", "- [How to constrain parameters](#How-to-constrain-parameters)\n", "- [How to make a PyroModule Bayesian](#How-to-make-a-PyroModule-Bayesian)\n", "- [Caution: accessing attributes inside plates](#⚠-Caution:-accessing-attributes-inside-plates)\n", "- [How to create a complex nested PyroModule](#How-to-create-a-complex-nested-PyroModule)\n", "- [How naming works](#How-naming-works)\n", "- [Caution: avoiding duplicate names](#⚠-Caution:-avoiding-duplicate-names)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "from torch.distributions import constraints\n", "from pyro.nn import PyroModule, PyroParam, PyroSample\n", "from pyro.nn.module import to_pyro_module_\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.infer.autoguide import AutoNormal\n", "from pyro.optim import Adam\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.9.0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How `PyroModule` works \n", "\n", "[PyroModule](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroModule) aims to combine Pyro's primitives and effect handlers with PyTorch's [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) idiom, thereby enabling Bayesian treatment of existing `nn.Module`s and enabling model serving via [jit.trace_module](https://pytorch.org/docs/stable/jit.html#torch.jit.trace_module). 
Before you start using `PyroModule`s it will help to understand how they work, so you can avoid pitfalls.\n", "\n", "`PyroModule` is a subclass of `nn.Module`. `PyroModule` enables Pyro effects by inserting effect handling logic on module attribute access, overriding the `.__getattr__()`, `.__setattr__()`, and `.__delattr__()` methods. Additionally, because some effects (like sampling) apply only once per model invocation, `PyroModule` overrides the `.__call__()` method to ensure samples are generated at most once per `.__call__()` invocation (note `nn.Module` subclasses typically implement a `.forward()` method that is called by `.__call__()`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to create a `PyroModule` \n", "\n", "There are three ways to create a `PyroModule`. Let's start with a `nn.Module` that is not a `PyroModule`:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class Linear(nn.Module):\n", " def __init__(self, in_size, out_size):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.randn(in_size, out_size))\n", " self.bias = nn.Parameter(torch.randn(out_size))\n", " \n", " def forward(self, input_):\n", " return self.bias + input_ @ self.weight\n", " \n", "linear = Linear(5, 2)\n", "assert isinstance(linear, nn.Module)\n", "assert not isinstance(linear, PyroModule)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first way to create a `PyroModule` is to create a subclass of `PyroModule`. 
You can update any `nn.Module` you've written to be a PyroModule, e.g.\n", "```diff\n", "- class Linear(nn.Module):\n", "+ class Linear(PyroModule):\n", " def __init__(self, in_size, out_size):\n", " super().__init__()\n", " self.weight = ...\n", " self.bias = ...\n", " ...\n", "```\n", "Alternatively if you want to use third-party code like the `Linear` above you can subclass it, using `PyroModule` as a mixin class" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class PyroLinear(Linear, PyroModule):\n", " pass\n", "\n", "linear = PyroLinear(5, 2)\n", "assert isinstance(linear, nn.Module)\n", "assert isinstance(linear, Linear)\n", "assert isinstance(linear, PyroModule)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The second way to create a `PyroModule` is to use bracket syntax `PyroModule[-]` to automatically denote a trivial mixin class as above.\n", "```diff\n", "- linear = Linear(5, 2)\n", "+ linear = PyroModule[Linear](5, 2)\n", "```\n", "In our case we can write" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "linear = PyroModule[Linear](5, 2)\n", "assert isinstance(linear, nn.Module)\n", "assert isinstance(linear, Linear)\n", "assert isinstance(linear, PyroModule)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The one difference between manual subclassing and using `PyroModule[-]` is that `PyroModule[-]` also ensures all `nn.Module` superclasses also become `PyroModule`s, which is important for class hierarchies in library code. 
For example, since `nn.GRU` is a subclass of `nn.RNNBase`, `PyroModule[nn.GRU]` will also be a subclass of `PyroModule[nn.RNNBase]`.\n", "\n", "The third way to create a `PyroModule` is to change the type of an existing `nn.Module` instance in-place using [to_pyro_module_()](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.to_pyro_module_). This is useful if you're using a third-party module factory helper or updating an existing script, e.g." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "linear = Linear(5, 2)\n", "assert isinstance(linear, nn.Module)\n", "assert not isinstance(linear, PyroModule)\n", "\n", "to_pyro_module_(linear) # this operates in-place\n", "assert isinstance(linear, nn.Module)\n", "assert isinstance(linear, Linear)\n", "assert isinstance(linear, PyroModule)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How effects work \n", "\n", "So far we've created `PyroModule`s but haven't made use of Pyro effects. But already the `nn.Parameter` attributes of our `PyroModule`s act like [pyro.param](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.param) statements: they synchronize with Pyro's param store, and they can be recorded in traces."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear\n", "[]\n", "[]\n", "PyroLinear\n", "['bias', 'weight']\n", "['bias', 'weight']\n" ] } ], "source": [ "pyro.clear_param_store()\n", "\n", "# This is not traced:\n", "linear = Linear(5, 2)\n", "with poutine.trace() as tr:\n", " linear(example_input)\n", "print(type(linear).__name__)\n", "print(list(tr.trace.nodes.keys()))\n", "print(list(pyro.get_param_store().keys()))\n", "\n", "# Now this is traced:\n", "to_pyro_module_(linear)\n", "with poutine.trace() as tr:\n", " linear(example_input)\n", "print(type(linear).__name__)\n", "print(list(tr.trace.nodes.keys()))\n", "print(list(pyro.get_param_store().keys()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to constrain parameters \n", "\n", "Pyro parameters allow constraints, and often we want our `nn.Module` parameters to obey constraints. You can constrain a `PyroModule`'s parameters by replacing `nn.Parameter` with a [PyroParam](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroParam) attribute. 
For example to ensure the `.bias` attribute is positive, we can set it to" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params before: ['weight', 'bias']\n", "params after: ['weight', 'bias_unconstrained']\n", "bias: tensor([0.9777, 0.8773], grad_fn=)\n" ] } ], "source": [ "print(\"params before:\", [name for name, _ in linear.named_parameters()])\n", "\n", "linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)\n", "print(\"params after:\", [name for name, _ in linear.named_parameters()])\n", "print(\"bias:\", linear.bias)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now PyTorch will optimize the `.bias_unconstrained` parameter, and each time we access the `.bias` attribute it will read and transform the `.bias_unconstrained` parameter (similar to a Python `@property`).\n", "\n", "\n", "If you know the constraint beforehand, you can build it into the module constructor, e.g.\n", "```diff\n", " class Linear(PyroModule):\n", " def __init__(self, in_size, out_size):\n", " super().__init__()\n", " self.weight = ...\n", "- self.bias = nn.Parameter(torch.randn(out_size))\n", "+ self.bias = PyroParam(torch.randn(out_size).exp(),\n", "+ constraint=constraints.positive)\n", " ...\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to make a `PyroModule` Bayesian \n", "\n", "So far our `Linear` module is still deterministic. To make it randomized and Bayesian, we'll replace `nn.Parameter` and `PyroParam` attributes with [PyroSample](http://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroSample) attributes, specifying a prior. 
Let's put a simple prior over the weights, taking care to expand its shape to `[5,2]` and declare event dimensions with [.to_event()](http://docs.pyro.ai/en/stable/distributions.html#pyro.distributions.torch_distribution.TorchDistributionMixin.to_event) (as explained in the [tensor shapes tutorial](https://pyro.ai/examples/tensor_shapes.html))." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params before: ['weight', 'bias_unconstrained']\n", "params after: ['bias_unconstrained']\n", "weight: tensor([[-0.8668, -0.0150],\n", " [ 3.4642, 1.9076],\n", " [ 0.4717, 1.0565],\n", " [-1.2032, 1.0821],\n", " [-0.1712, 0.4711]])\n", "weight: tensor([[-1.2577, -0.5242],\n", " [-0.7785, -1.0806],\n", " [ 0.6239, -0.4884],\n", " [-0.2580, -1.2288],\n", " [-0.7540, -1.9375]])\n" ] } ], "source": [ "print(\"params before:\", [name for name, _ in linear.named_parameters()])\n", "\n", "linear.weight = PyroSample(dist.Normal(0, 1).expand([5, 2]).to_event(2))\n", "print(\"params after:\", [name for name, _ in linear.named_parameters()])\n", "print(\"weight:\", linear.weight)\n", "print(\"weight:\", linear.weight)\n", "\n", "example_input = torch.randn(100, 5)\n", "example_output = linear(example_input)\n", "assert example_output.shape == (100, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the `.weight` parameter now disappears, and each time we call `linear()` a new weight is sampled from the prior. 
In fact, the weight is sampled when the `Linear.forward()` accesses the `.weight` attribute: this attribute now has the special behavior of sampling from the prior.\n", "\n", "We can see all the Pyro effects that appear in the trace:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param bias tensor([0.9777, 0.8773], grad_fn=)\n", "sample weight tensor([[ 1.8043, 1.5494],\n", " [ 0.0128, 1.4100],\n", " [-0.2155, 0.6375],\n", " [ 1.1202, 1.9672],\n", " [-0.1576, -0.6957]])\n" ] } ], "source": [ "with poutine.trace() as tr:\n", " linear(example_input)\n", "for site in tr.trace.nodes.values():\n", " print(site[\"type\"], site[\"name\"], site[\"value\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we've modified a third-party module to be Bayesian\n", "```py\n", "linear = Linear(...)\n", "to_pyro_module_(linear)\n", "linear.bias = PyroParam(...)\n", "linear.weight = PyroSample(...)\n", "```\n", "If you are creating a model from scratch, you could instead define a new class" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class BayesianLinear(PyroModule):\n", " def __init__(self, in_size, out_size):\n", " super().__init__()\n", " self.bias = PyroSample(\n", " prior=dist.LogNormal(0, 1).expand([out_size]).to_event(1))\n", " self.weight = PyroSample(\n", " prior=dist.Normal(0, 1).expand([in_size, out_size]).to_event(2))\n", "\n", " def forward(self, input):\n", " return self.bias + input @ self.weight # this line samples bias and weight" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that samples are drawn at most once per `.__call__()` invocation, for example\n", "```py\n", "class BayesianLinear(PyroModule):\n", " ...\n", " def forward(self, input):\n", " weight1 = self.weight # Draws a sample.\n", " weight2 = self.weight # Reads previous sample.\n", " assert weight2 is weight1 # All accesses 
should agree.\n", " ...\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ⚠ Caution: accessing attributes inside plates \n", "\n", "Because `PyroSample` and `PyroParam` attributes are modified by Pyro effects, we need to take care where parameters are accessed. For example [pyro.plate](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.plate) contexts can change the shape of sample and param sites. Consider two versions of a model with one latent variable and a batched observation statement. The only difference between these two models is where the `.loc` attribute is accessed." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class NormalModel(PyroModule):\n", " def __init__(self):\n", " super().__init__()\n", " self.loc = PyroSample(dist.Normal(0, 1))\n", "\n", "class GlobalModel(NormalModel):\n", " def forward(self, data):\n", " # If .loc is accessed (for the first time) outside the plate,\n", " # then it will have empty shape ().\n", " loc = self.loc\n", " assert loc.shape == ()\n", " with pyro.plate(\"data\", len(data)):\n", " pyro.sample(\"obs\", dist.Normal(loc, 1), obs=data)\n", " \n", "class LocalModel(NormalModel):\n", " def forward(self, data):\n", " with pyro.plate(\"data\", len(data)):\n", " # If .loc is accessed (for the first time) inside the plate,\n", " # then it will be expanded by the plate to shape (plate.size,).\n", " loc = self.loc\n", " assert loc.shape == (len(data),)\n", " pyro.sample(\"obs\", dist.Normal(loc, 1), obs=data)\n", "\n", "data = torch.randn(10)\n", "LocalModel()(data)\n", "GlobalModel()(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to create a complex nested `PyroModule` \n", "\n", "To perform inference with the above `BayesianLinear` module we'll need to wrap it in a probabilistic model with a likelihood; that wrapper will also be a `PyroModule`."
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class Model(PyroModule):\n", " def __init__(self, in_size, out_size):\n", " super().__init__()\n", " self.linear = BayesianLinear(in_size, out_size) # this is a PyroModule\n", " self.obs_scale = PyroSample(dist.LogNormal(0, 1))\n", "\n", " def forward(self, input, output=None):\n", " obs_loc = self.linear(input) # this samples linear.bias and linear.weight\n", " obs_scale = self.obs_scale # this samples self.obs_scale\n", " with pyro.plate(\"instances\", len(input)):\n", " return pyro.sample(\"obs\", dist.Normal(obs_loc, obs_scale).to_event(1),\n", " obs=output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Whereas a usual `nn.Module` can be trained with a simple PyTorch optimizer, a Pyro model requires probabilistic inference, e.g. using [SVI](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.svi.SVI) and an [AutoNormal](http://docs.pyro.ai/en/stable/infer.autoguide.html#pyro.infer.autoguide.AutoNormal) guide. See the [bayesian regression tutorial](http://pyro.ai/examples/bayesian_regression.html) for details." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step 0 loss = 7.186\n", "step 100 loss = 2.185\n", "step 200 loss = 1.87\n", "step 300 loss = 1.739\n", "step 400 loss = 1.691\n", "step 500 loss = 1.673\n", "CPU times: user 2.35 s, sys: 24.8 ms, total: 2.38 s\n", "Wall time: 2.39 s\n" ] } ], "source": [ "%%time\n", "pyro.clear_param_store()\n", "pyro.set_rng_seed(1)\n", "\n", "model = Model(5, 2)\n", "x = torch.randn(100, 5)\n", "y = model(x)\n", "\n", "guide = AutoNormal(model)\n", "svi = SVI(model, guide, Adam({\"lr\": 0.01}), Trace_ELBO())\n", "for step in range(2 if smoke_test else 501):\n", " loss = svi.step(x, y) / y.numel()\n", " if step % 100 == 0:\n", " print(\"step {} loss = {:0.4g}\".format(step, loss))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`PyroSample` statements may also depend on other sample statements or parameters. In this case the `prior` can be a callable depending on `self`, rather than a constant distribution. For example consider the hierarchical model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.5387)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Model(PyroModule):\n", " def __init__(self):\n", " super().__init__()\n", " self.dof = PyroSample(dist.Gamma(3, 1))\n", " self.loc = PyroSample(dist.Normal(0, 1))\n", " self.scale = PyroSample(lambda self: dist.InverseGamma(self.dof, 1))\n", " self.x = PyroSample(lambda self: dist.Normal(self.loc, self.scale))\n", " \n", " def forward(self):\n", " return self.x\n", " \n", "Model()()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How naming works \n", "\n", "In the above code we saw a `BayesianLinear` model embedded inside another `Model`. Both were `PyroModule`s. 
Whereas simple [pyro.sample](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.sample) statements require name strings, `PyroModule` attributes handle naming automatically. Let's see how that works with the above `model` and `guide` (since `AutoNormal` is also a `PyroModule`).\n", "\n", "Let's trace executions of the model and the guide." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sample linear.bias torch.Size([2])\n", "sample linear.weight torch.Size([5, 2])\n", "sample obs_scale torch.Size([])\n", "sample instances torch.Size([100])\n", "sample obs torch.Size([100, 2])\n" ] } ], "source": [ "with poutine.trace() as tr:\n", " model(x)\n", "for site in tr.trace.nodes.values():\n", " print(site[\"type\"], site[\"name\"], site[\"value\"].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Observe that `model.linear.bias` corresponds to the `linear.bias` name, and similarly for the `model.linear.weight` and `model.obs_scale` attributes. The \"instances\" site corresponds to the plate, and the \"obs\" site corresponds to the likelihood. 
Next examine the guide:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param AutoNormal.locs.linear.bias torch.Size([2])\n", "param AutoNormal.scales.linear.bias torch.Size([2])\n", "sample linear.bias_unconstrained torch.Size([2])\n", "sample linear.bias torch.Size([2])\n", "param AutoNormal.locs.linear.weight torch.Size([5, 2])\n", "param AutoNormal.scales.linear.weight torch.Size([5, 2])\n", "sample linear.weight_unconstrained torch.Size([5, 2])\n", "sample linear.weight torch.Size([5, 2])\n", "param AutoNormal.locs.obs_scale torch.Size([])\n", "param AutoNormal.scales.obs_scale torch.Size([])\n", "sample obs_scale_unconstrained torch.Size([])\n", "sample obs_scale torch.Size([])\n" ] } ], "source": [ "with poutine.trace() as tr:\n", " guide(x)\n", "for site in tr.trace.nodes.values():\n", " print(site[\"type\"], site[\"name\"], site[\"value\"].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see the guide learns posteriors over three random variables: `linear.bias`, `linear.weight`, and `obs_scale`. For each of these, the guide learns a `(loc,scale)` pair of parameters, which are stored internally in nested `PyroModule`s:\n", "```python\n", "class AutoNormal(...):\n", " def __init__(self, ...):\n", " self.locs = PyroModule()\n", " self.scales = PyroModule()\n", " ...\n", "```\n", "Finally, `AutoNormal` contains a `pyro.sample` statement for each unconstrained latent site followed by a [pyro.deterministic](http://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.deterministic) statement to map the unconstrained sample to a constrained posterior sample." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ⚠ Caution: avoiding duplicate names \n", "\n", "`PyroModule`s name their attributes automatically, even for attributes nested deeply in other `PyroModule`s. 
However, care must be taken when mixing usual `nn.Module`s with `PyroModule`s, because `nn.Module`s do not support automatic site naming.\n", "\n", "Within a single model (or guide):\n", "\n", "If there is only a single `PyroModule`, then you are safe.\n", "```diff\n", " class Model(nn.Module): # not a PyroModule\n", " def __init__(self):\n", " self.x = PyroModule()\n", "- self.y = PyroModule() # Could lead to name conflict.\n", "+ self.y = nn.Module() # Has no Pyro names, so avoids conflict.\n", "```\n", "If there are only two `PyroModule`s then one must be an attribute of the other.\n", "```diff\n", "class Model(PyroModule):\n", " def __init__(self):\n", " self.x = PyroModule() # ok\n", "```\n", "If you have two `PyroModule`s that are not attributes of each other, then they must be connected by attribute links through other `PyroModule`s. These can be sibling links\n", "```diff\n", "- class Model(nn.Module): # Could lead to name conflict.\n", "+ class Model(PyroModule): # Ensures names are unique.\n", " def __init__(self):\n", " self.x = PyroModule()\n", " self.y = PyroModule()\n", "```\n", "or ancestor links\n", "```diff\n", " class Model(PyroModule):\n", " def __init__(self):\n", "- self.x = nn.Module() # Could lead to name conflict.\n", "+ self.x = PyroModule() # Ensures y is connected to root Model.\n", " self.x.y = PyroModule()\n", "```\n", "\n", "Sometimes you may want to store a `(model,guide)` pair in a single `nn.Module`, e.g. to serve them from C++. 
In this case it is safe to make them attributes of a container `nn.Module`, but that container should *not* be a `PyroModule`.\n", "```python\n", "class Container(nn.Module): # This cannot be a PyroModule.\n", " def __init__(self, model, guide): # These may be PyroModules.\n", " super().__init__()\n", " self.model = model\n", " self.guide = guide\n", " # This is a typical trace-replay pattern seen in model serving.\n", " def forward(self, data):\n", " tr = poutine.trace(self.guide).get_trace(data)\n", " return poutine.replay(self.model, tr)(data)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }