
Deep Learning with Integer Activations

Simon Ramstedt, 2018-02-10

A couple of weeks ago I got to present Temporally Efficient Deep Learning with Spikes by O’Connor et al., 2018 in a reading group at the lab. I liked the modular way in which it presents its method. It has little boxes like this

that describe the stateful modules that make up the algorithms. Here, I want to look in detail at the mathematical assumptions that have to be made for the method to be valid. While most (but not all) of the math here can also be found scattered throughout the paper, I am trying to present it in a more linear, proof-like manner.

Why spiking neural networks?

Spiking neural networks are interesting in two ways. 1) The brain uses spikes and we want to understand how it works. 2) Spikes are binary and therefore are cheaper to communicate and store and also have the potential to reduce the costly weight multiplication in neural networks to a cheap sum of integer weights.

In the brain, each neuron has on average 7000 connections to other neurons. Therefore it makes sense to trade computation and bandwidth in the connections for computation on the neuron level, i.e. spend computation on encoding the activations.

To recap, a typical, non-spiking neural activation is computed as

$$x' = h(z) \quad \text{with} \quad z = w\,x = \left(\sum_i w_{ij}\,x_i\right)_j$$

where $x$ are the activations of the previous layer (or the network inputs) and $h$ is a non-linear function, e.g. $h(z) = \max(z, 0)$.

Usually floating point numbers are used to represent weights and activations. Floating point numbers are divided into an exponent and a mantissa. To multiply two of them we have to integer-add the exponents and integer-multiply the mantissas. This is implemented in hardware, but it still requires a lot of chip space and energy. However, if $x$ were binary (i.e. a spike), we could compute $z$ with a sparse sum. We could even use integers to represent $w$, and then the multiplication would be as simple and computationally cheap as it gets.
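As a toy illustration of why binary activations make this cheap, here is a small NumPy sketch (the shapes and values are made up): with a binary $x$, the multiply-accumulate collapses into selecting and summing weight rows.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.integers(-3, 4, size=(5, 4))   # integer weights, shape (inputs, outputs)
x = np.array([1, 0, 1, 1, 0])          # binary "spike" activations

# dense multiply-accumulate (what we'd do with float activations)
z_dense = x @ w

# with binary x, z is just a sum over the rows where x spiked -- no multiplies
z_sparse = w[x == 1].sum(axis=0)

assert np.array_equal(z_dense, z_sparse)
```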

In their paper, O’Connor et al., 2018 introduce an encoding scheme that uses “integer spikes” in the forward pass, the backward pass, and the weight updates, without loss in accuracy compared to non-spiking networks.


One thing about the paper I found somewhat misleading is that the authors call their method spiking, when they actually use “integer spikes”. Using integers instead of binary values to communicate activations is still much cheaper than floats but it requires integer multiplications to compute the inner product with the weights. So a more appropriate name would have been “Temporally Efficient Deep Learning with Integer Activations”. Nevertheless the paper is very insightful and it would probably be possible to tweak the method in certain ways to allow it to work with only binary spikes.


Below we see the dataflow from one neuron to another neuron in a standard neural network. In the next section we will focus on the axon part, i.e., communicating the activations and trying to find a bandwidth saving encoding.

$$\cdots \underbrace{\overset{z}{⟶} h}_\text{neuron a} \ \boxed{\underbrace{\overset{x}{⟶}}_\text{axon}} \ \underbrace{w}_\text{synapses} \ \underbrace{\overset{z}{⟶} h}_\text{neuron b} \overset{x}{⟶} \cdots$$

Predictive Coding $\quad x → \text{enc} → a → \text{dec} → \hat x$

In predictive coding the sender and receiver share a model for the temporal evolution of the signal between them. Instead of communicating the original signal, only the model error is communicated and therefore only the model error is affected by channel noise which results in a higher signal-to-noise ratio.

Predictive coding is usually not used for neuron-to-neuron communication because the channel is not noisy (we usually use float32 to communicate the activations). Since we want to save bandwidth however, we will have to quantize the signal and therefore introduce quantization noise (see next section).

The neuron-to-neuron communication in a standard neural network without predictive coding can be framed as predictive coding with the model $x_t = 0 + a_t$ and the error $a_t = x_t$ that has to be communicated. Another very simple model would be to assume the signal stays constant, i.e. $x_t = x_{t-1} + a_t$; then we would only transmit activation changes.
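The constant-signal model is easy to sketch in plain Python (the signal values are made up): only the changes are transmitted, and the receiver integrates them back up.

```python
# "constant signal" model: only changes a_t = x_t - x_{t-1} are transmitted
x = [0.0, 0.0, 1.0, 1.0, 1.0, 0.5]                        # made-up signal
a = [x[0]] + [x[t] - x[t - 1] for t in range(1, len(x))]  # mostly zeros

# the receiver integrates the changes back up
x_hat, acc = [], 0.0
for a_t in a:
    acc += a_t
    x_hat.append(acc)

assert x_hat == x
```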

O’Connor et al. use a similar decaying model:

$$x_t = \frac{k_d}{k_p + k_d} x_{t-1} + \frac{1}{k_p + k_d} a_t$$

Note that the error $a_t$ is scaled by the factor $\frac{1}{k_p + k_d}$. We can rewrite the model equation as an encoder-decoder pair:


$$\begin{aligned} \text{enc:} & \quad a_t = k_p x_t + k_d (x_t - x_{t-1}) \\ \text{dec:} & \quad \hat x_t = x_t = \frac{a_t + k_d x_{t-1}}{k_p + k_d} \end{aligned}$$


We can also unroll the $x_{t-1}$ in this expression (useful for later):

$$\begin{aligned} x_t &= \frac{a_t}{k_p + k_d} + x_{t-1} \frac{k_d}{k_p + k_d} \\ &= \frac{a_t}{k_p + k_d} + \left(\frac{a_{t-1}}{k_p + k_d} + x_{t-2} \frac{k_d}{k_p + k_d}\right) \frac{k_d}{k_p + k_d} \\ &= \frac{a_t}{k_p + k_d} + \frac{k_d\, a_{t-1}}{(k_p + k_d)^2} + x_{t-2} \left(\frac{k_d}{k_p + k_d}\right)^2 \\ &= \frac{1}{k_p + k_d} \sum_{i=0}^{t} \left(\frac{k_d}{k_p + k_d}\right)^{t-i} a_i \end{aligned}$$
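The encoder-decoder pair is short enough to sketch directly in NumPy (the function names and the values of $k_p$, $k_d$ are my own arbitrary choices). Without quantization, the round-trip is lossless:

```python
import numpy as np

kp, kd = 0.05, 1.0   # arbitrary demo values

def encode(x):
    # a_t = kp * x_t + kd * (x_t - x_{t-1}), with x_{-1} = 0
    x_prev = np.concatenate([[0.0], x[:-1]])
    return kp * x + kd * (x - x_prev)

def decode(a):
    # x_hat_t = (a_t + kd * x_hat_{t-1}) / (kp + kd)
    x_hat, prev = np.empty_like(a), 0.0
    for t, a_t in enumerate(a):
        prev = (a_t + kd * prev) / (kp + kd)
        x_hat[t] = prev
    return x_hat

x = np.sin(np.linspace(0, 4, 100))
# without quantization, the encoder-decoder pair reconstructs x exactly
assert np.allclose(decode(encode(x)), x)
```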

Sigma-Delta modulation $\quad a → Q → s → Q^{-1} → \hat a$

Sigma-Delta modulation is a quantization scheme and a form of noise shaping for converting high bit-count, low frequency signals into low bit-count, high frequency signals. Let’s look at how that works:

Because quantization $s = \operatorname{round}(a)$ loses information, we store the “leftover” $\phi = a - s$ and add it at the next timestep: $s = \operatorname{round}(\phi + a)$.


So, starting with $\phi_0 = 0$, we have

$$\begin{aligned} & s_t = \operatorname{round}(\phi_t + a_t) \\ & \phi_{t+1} = (\phi_t + a_t) - s_t \end{aligned}$$
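This quantizer fits in a few lines of Python (the helper name `sigma_delta` is my own). For a constant signal in $[0, 1]$ the spikes come out binary, and their running mean recovers the signal:

```python
import numpy as np

def sigma_delta(a):
    # quantize a real-valued stream into integer "spikes",
    # carrying the rounding leftover phi to the next timestep
    s = np.empty_like(a, dtype=int)
    phi = 0.0
    for t, a_t in enumerate(a):
        s[t] = round(phi + a_t)
        phi = (phi + a_t) - s[t]
    return s

s = sigma_delta(np.full(1000, 0.3))   # constant signal a_t = 0.3
assert set(s.tolist()) <= {0, 1}      # binary spikes since a_t is in [0, 1]
assert abs(s.mean() - 0.3) < 0.01     # the running mean recovers a_t
```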


Note: In general we round to the nearest integer. To ensure that we get binary spikes, i.e. $s \in \{0, 1\}$, we need $\phi_t + a_t \in [-0.5, 1.5]$, and because $\phi_t \in [-0.5, 0.5]$ we want $a_t \in [0, 1]$, which we can ensure by increasing the temporal resolution and tweaking $k_p$ and $k_d$ (see previous section).

But how can we reconstruct $a$ from this? To get a relation between $s$ and $a$ we can unroll the expression for $\phi_{t+1}$ for $n$ steps:

$$\phi_{t+1} = (\phi_{t-1} + a_{t-1} - s_{t-1}) + a_t - s_t = \dots = \phi_{t-n+1} + \sum_{i=t-n+1}^{t} a_i - \sum_{i=t-n+1}^{t} s_i$$

This gives us a relation between $\sum a$ and $\sum s$, which is a good starting point.

$$\sum_{i=t-n+1}^{t} a_i = \sum_{i=t-n+1}^{t} s_i + \phi_{t+1} - \phi_{t-n+1}$$

To get $a_t$ we have to assume $a_t = \text{constant}$ over a series of timesteps $\{t-n+1, \dots, t\}$. Then we can write

$$a_t = \tfrac{1}{n} \sum_{i=t-n+1}^{t} a_i = \tfrac{1}{n} \left( \sum_{i=t-n+1}^{t} s_i + \phi_{t+1} - \phi_{t-n+1} \right) = \underbrace{\tfrac{1}{n} \sum_{i=t-n+1}^{t} s_i}_\text{we can access} + \underbrace{\tfrac{\phi_{t+1} - \phi_{t-n+1}}{n}}_\text{error term}$$

For $n \to \infty$ we therefore have $a_t = \lim_{n \to \infty} \frac{1}{n} \sum_{i=t-n+1}^{t} s_i$. Since $\phi_t \in [-0.5, 0.5]$ and $\mathbb{E}[\phi_t] = 0$, we can assume the error term is small even for small $n$. The scale of the sum, on the other hand, is (up to the error term) proportional to $a$. That means the signal-to-noise ratio of the reconstruction depends heavily on which scaling constant $\frac{1}{k_p + k_d}$ we use for $a$.

Furthermore, the requirement for $a_t$ to be constant across many timesteps is not a real limitation. We can just increase the time resolution and increase $n$ proportionally to make the error term small. So if $x$ changes too quickly, we can just make our timesteps smaller.

To “decode” the quantization we therefore have to average the quantized signal. Conveniently the decoding scheme from the previous section already does this implicitly (approximately):

$$\begin{aligned} {\color{blue} x_t} &= c \sum_{i=0}^{t} \left(\frac{k_d}{k_p + k_d}\right)^{t-i} a_i \approx c \sum_{i=t-n}^{t} a_i = c \sum_{i=t-n}^{t} \frac{1}{n} \sum_{j=t-n}^{t} s_j \\ &= c \sum_{i=t-n}^{t} s_i \approx c \sum_{i=0}^{t} \left(\frac{k_d}{k_p + k_d}\right)^{t-i} s_i =: {\color{orangered}{\hat x_t}} \end{aligned}$$

Therefore we don’t need a decoder $Q^{-1}$ for the quantization, so we end up with the following pipeline.

$$\overset{\color{blue} x}{⟶} \text{enc} ⟶ \text{Q} \overset{s}{⟶} \text{dec} \overset{\color{orangered}{\hat x}}{⟶}$$

Below we can see what the combined signals look like for different encoding parameters.
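Since those plots aren’t reproduced here, a rough NumPy sketch of the full enc → Q → dec pipeline gives a feel for it (the function names and parameter values are my own choices): the reconstruction tracks a slowly varying signal, with quantization noise bounded by roughly $\frac{1}{k_p + k_d}$.

```python
import numpy as np

kp, kd = 0.05, 10.0   # arbitrary demo values

def enc(x):
    # a_t = kp * x_t + kd * (x_t - x_{t-1}), with x_{-1} = 0
    x_prev = np.concatenate([[0.0], x[:-1]])
    return kp * x + kd * (x - x_prev)

def quantize(a):
    # sigma-delta: round, carry the leftover phi to the next step
    s, phi = np.empty_like(a), 0.0
    for t, a_t in enumerate(a):
        s[t] = np.round(phi + a_t)
        phi = phi + a_t - s[t]
    return s

def dec(s):
    # x_hat_t = (s_t + kd * x_hat_{t-1}) / (kp + kd)
    x_hat, prev = np.empty_like(s), 0.0
    for t, s_t in enumerate(s):
        prev = (s_t + kd * prev) / (kp + kd)
        x_hat[t] = prev
    return x_hat

x = 0.5 + 0.4 * np.sin(np.linspace(0, 6, 2000))   # slowly varying signal
x_hat = dec(quantize(enc(x)))
# reconstruction error stays within ~1/(kp + kd) of the true signal
assert np.abs(x - x_hat).max() < 0.15
```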

Integer weight multiplication

So far we have established more efficient communication between the neurons, but we have not yet incorporated the weight multiplication.

$$\cdots \underbrace{⟶ h \overset{\color{blue} x}{→} \text{enc} → \text{Q}}_\text{neuron a} \ \underbrace{\overset{s}{⟶}}_\text{axon} \ \underbrace{\text{dec} \overset{\color{orangered}{\hat x}}{→} w}_\text{synapses} \ \underbrace{\overset{z}{⟶} h \cdots}_\text{neuron b} \quad \text{(not what we want)}$$

So we have

$$z_t = w_t\, \hat x_t = w_t\, c \sum_{i=0}^{t} \left(\frac{k_d}{k_p + k_d}\right)^{t-i} s_i$$

Considering that $\hat x_t$ is just a weighted sum, if we assume $w_t = \text{constant}$, we can pull it inside the sum:

$$z_t ≈ c \sum_{i=0}^{t} \left(\frac{k_d}{k_p + k_d}\right)^{t-i} s_i w_t =: \color{red}{\hat z_t}$$

Because $s_i$ is integer, we have achieved our goal of replacing the floating point multiplication with a cheaper sparse integer multiplication! The approximation error we make with the assumption $w_t = \text{constant}$ depends on how fast the terms inside the sum decay, i.e. how large $k_p$ is. Below is the final pipeline and a plot of the reconstruction for different $k_p$.
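A small NumPy sketch of pulling $w$ inside the decaying sum (the names, weights, and spike streams are made-up placeholders): the recursive update on $\hat z$ gives the same result as decoding $\hat x$ first and then multiplying, as long as $w$ stays constant.

```python
import numpy as np

kp, kd = 0.05, 10.0                    # arbitrary demo values
c, r = 1 / (kp + kd), kd / (kp + kd)

rng = np.random.default_rng(0)
w = rng.standard_normal((3, 2))        # fixed weights, shape (inputs, outputs)
s = rng.integers(0, 2, size=(50, 3))   # stream of (binary) spikes per input

# explicit version: decode x_hat first, then float-multiply by w every step
x_hat, z_slow = np.zeros(3), []
for s_t in s:
    x_hat = r * x_hat + c * s_t
    z_slow.append(x_hat @ w)

# spiking version: pull w inside the decaying sum; the only operation
# involving s_t is the sparse integer product s_t @ w
z_hat, z_fast = np.zeros(2), []
for s_t in s:
    z_hat = r * z_hat + c * (s_t @ w)
    z_fast.append(z_hat.copy())

assert np.allclose(z_slow, z_fast)     # identical while w stays constant
```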


$$\cdots \underbrace{⟶ h \overset{\color{blue} x}{→} \text{enc} → \text{Q}}_\text{neuron a} \ \underbrace{\overset{s}{⟶}}_\text{axon} \ \underbrace{\color{orange}{w_t}}_\text{synapses} \ \underbrace{⟶ \text{dec} \overset{\color{red}{\hat z}}{→} h \cdots}_\text{neuron b} \quad (\checkmark)$$


Learning the weights

To learn the weights we can apply the same coding scheme for backpropagation (by making the same assumptions). The symmetric backward pass through the transposed weights is not really biologically plausible but there is orthogonal work on biologically plausible backpropagation.

$$\cdots \underbrace{⟶ h \overset{\color{blue} x}{→} \text{enc} → \text{Q}}_\text{neuron a} \ \underbrace{\overset{\bar x = s}{⟶}}_\text{axon} \ \underbrace{w_t}_\text{synapses} \ \underbrace{⟶ \text{dec} \overset{\hat z}{→} h \cdots}_\text{neuron b} \tag{forward}$$

$$\cdots \underbrace{⟵ h' ← \text{dec}}_\text{neuron a} \ \underbrace{⟵}_\text{axon} \ \underbrace{w^T}_\text{synapses} \ \underbrace{\overset{\bar e}{⟵} \text{Q} ← \text{enc} \overset{e}{←} h'}_\text{neuron b} \overset{∇_{x'} L}{⟵} \cdots \tag{backward}$$

This leads to an efficient backward pass, but in order to update the weights with gradient descent we need to compute the outer product between the activations and the next pre-activation gradients: $∇_w L = x ⊗ e$ (where $e = ∇_z L$), neither of which we have access to.

The simplest solution would be to decode $\hat x = \text{dec}(s)$ and $\hat e = \text{dec}(\bar e)$ before the outer product.

$$\widehat{∇_w L}_\text{recon} = \hat x ⊗ \hat e$$

However, then we still have an expensive floating point multiplication. Instead we can use the fact that the result of the decoder $s → \text{dec} → \hat x$

$$\text{dec:} \quad \hat x_t = \frac{s_t + k_d \hat x_{t-1}}{k_p + k_d}$$

decays exponentially in the absence of spikes (i.e. $s_t = 0$). Therefore we can calculate the sum over time between two spikes (pre-synaptic or post-synaptic) analytically as a geometric series. It is fine to sum the gradients over time and apply them as an update later, because that is what SGD does anyway.

$$\begin{aligned} \sum_{i=t-n}^{t} \hat x_i \hat e_i &= \sum_{i=t-n}^{t} \left(\tfrac{k_d}{k_p + k_d}\right)^{2(i-(t-n))} \hat x_{t-n}\, \hat e_{t-n} \\ &= \hat x_{t-n} \hat e_{t-n} \sum_{j=0}^{n} \Big( \underbrace{\left(\tfrac{k_d}{k_p + k_d}\right)^{2}}_{=r} \Big)^{j} \\ &= \hat x_{t-n} \hat e_{t-n}\, \tfrac{1 - r^{n+1}}{1-r} \end{aligned}$$

Here, $t-n$ is the time at which the last spike occurred. If another spike occurs for either $\bar x$ or $\bar e$, we just add that sum to the corresponding weight (multiplied by the learning rate). This is called “past updates” in the paper.
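The geometric-series shortcut is easy to verify numerically; a tiny Python sketch (all values made up for illustration):

```python
kp, kd = 0.05, 2.0                # arbitrary demo values
r = (kd / (kp + kd)) ** 2         # per-step decay of the product x_hat * e_hat

# n+1 quiet steps after the last spike, summed explicitly ...
x0, e0, n = 0.8, -0.3, 7          # made-up values at the last spike time
explicit = sum(x0 * e0 * r**j for j in range(n + 1))

# ... equal the closed-form geometric series used for the "past updates"
closed_form = x0 * e0 * (1 - r ** (n + 1)) / (1 - r)
assert abs(explicit - closed_form) < 1e-12
```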

Summary

We looked at the forward pass, backward pass and weight updates from the paper and justified every step mathematically (most of which can also be found in the paper). This revealed the assumptions that had to be made and the requirements on the hyperparameters, as well as possible extension points for the method.