📚 LSTM Implementation
This document reviews the main themes and key takeaways from a lecture in Deep Learning Systems: Algorithms and Implementation at Carnegie Mellon University, taught by J. Zico Kolter and Tianqi Chen.
Lecture Overview
The lecture focuses on implementing a Long Short-Term Memory (LSTM) network, starting with a single cell and then expanding to a full sequence and finally to batched implementation. The primary goal is to demonstrate the inner workings of LSTMs and show that they aren’t as complicated as they might seem. The lecturer begins by showing how to create an LSTM in PyTorch and then goes on to implement the same functionality using NumPy. This approach allows for a comparison between the high-level PyTorch implementation and a lower-level NumPy version. The lecture then moves to implementing a batched version of the LSTM, and finally discusses the practicalities of training an LSTM using truncated backpropagation through time.
Key Implementation Steps
- Single LSTM Cell Implementation 🧠
- The lecture begins by implementing a single LSTM cell using NumPy.
- This involves defining the sigmoid and tanh non-linearities.
sigmoid(x) = 1 / (1 + exp(-x))
- The LSTM cell takes as input an input vector x, the previous hidden state h, the previous cell state c, the weights W_hh and W_ih, and the bias b.
- It computes the intermediate vectors i, f, g, and o (input gate, forget gate, cell gate, output gate) with a single combined matrix multiplication over the input and the previous hidden state, then splits the resulting vector into four parts, each of hidden dimension d. Here W * v denotes a matrix-vector product:
intermediate = W_hh * h + W_ih * x + b
i, f, g, o = split(intermediate)  # each of size d
- The non-linearities are applied to these vectors:
i = sigmoid(i)
f = sigmoid(f)
g = tanh(g)
o = sigmoid(o)
- The cell state is updated using the forget gate, input gate, and cell gate. Here * denotes element-wise multiplication:
c_out = f * c + i * g
- The output hidden state is computed from the output gate and the tanh of the new cell state, again with element-wise multiplication:
h_out = o * tanh(c_out)
- The function returns the new hidden state and the new cell state.
- The lecturer then tests if the single cell implementation is the same as the PyTorch version, using randomly initialized inputs, and weights copied from the PyTorch implementation. The results are shown to be the same to numerical precision.
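The "combined matrix multiplication, then split" step can be checked in isolation. The following is a minimal NumPy sketch, not the lecture's code: the sizes `d` and `n`, the random weights, and the names `per_gate` and `gates_match` are illustrative assumptions. It shows that splitting the stacked result gives the same four gate pre-activations as multiplying by each gate's block of rows separately.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 3  # hidden size d and input size n (arbitrary illustrative values)

# Stacked weights: rows [0:d] belong to i, [d:2d] to f, [2d:3d] to g, [3d:4d] to o.
W_ih = rng.standard_normal((4 * d, n))
W_hh = rng.standard_normal((4 * d, d))
b = rng.standard_normal(4 * d)
x = rng.standard_normal(n)
h = rng.standard_normal(d)

# One combined product, then a split into the four gate pre-activations.
i, f, g, o = np.split(W_hh @ h + W_ih @ x + b, 4)

# The same quantities computed gate by gate from row blocks of the stacked weights.
per_gate = [
    W_hh[k * d:(k + 1) * d] @ h + W_ih[k * d:(k + 1) * d] @ x + b[k * d:(k + 1) * d]
    for k in range(4)
]

gates_match = all(np.allclose(a, p) for a, p in zip((i, f, g, o), per_gate))
print("combined split matches per-gate products:", gates_match)
```

Stacking the gates this way replaces eight small matrix-vector products with two larger ones, which is usually the cheaper way to do the same arithmetic.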
- Full Sequence LSTM Implementation 🎬
- The next step is to implement an LSTM that operates on a sequence of inputs, rather than a single one.
- This is achieved by iteratively applying the single LSTM cell to each element of the input sequence.
- The implementation initializes an array to store all the hidden states over time and then iterates over the time steps of the input sequence.
- In each step it calls the single LSTM cell, computes the new hidden state and cell state, saves the hidden state, and continues until the end of the sequence.
- The function returns all the hidden states for the sequence, and the final cell state.
- The lecturer then tests if this implementation is the same as PyTorch’s implementation. The results are shown to be the same to numerical precision.
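One consequence of returning both the hidden states and the final cell state is that a sequence can be processed in pieces. The sketch below is a compact NumPy restatement of the cell and sequence loop described above (the sizes and the two-chunk split are illustrative assumptions). It checks that running the whole sequence at once gives the same result as running two chunks and carrying (h, c) across the boundary.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_cell(x, h, c, W_hh, W_ih, b):
    # Combined product, split into gates, then the standard cell update.
    i, f, g, o = np.split(W_hh @ h + W_ih @ x + b, 4)
    c_out = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c_out), c_out

def lstm(X, h, c, W_hh, W_ih, b):
    # Apply the cell at every time step and record each hidden state.
    H = np.zeros((X.shape[0], h.shape[0]))
    for t in range(X.shape[0]):
        h, c = lstm_cell(X[t], h, c, W_hh, W_ih, b)
        H[t] = h
    return H, c

rng = np.random.default_rng(0)
T, n, d = 6, 3, 4  # illustrative sizes
X = rng.standard_normal((T, n))
W_ih = rng.standard_normal((4 * d, n))
W_hh = rng.standard_normal((4 * d, d))
b = rng.standard_normal(4 * d)
h0, c0 = np.zeros(d), np.zeros(d)

# Whole sequence in one call.
H_full, c_full = lstm(X, h0, c0, W_hh, W_ih, b)

# Same sequence in two chunks, carrying the state across the boundary.
H_a, c_a = lstm(X[:3], h0, c0, W_hh, W_ih, b)
H_b, c_b = lstm(X[3:], H_a[-1], c_a, W_hh, W_ih, b)

chunks_match = np.allclose(H_full, np.vstack([H_a, H_b])) and np.allclose(c_full, c_b)
print("chunked run matches full run:", chunks_match)
```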
- Batched LSTM Implementation 🤹
- The lecture also addresses the need for batched processing for efficiency.
- The key difference in the batched implementation is the data layout: the time dimension comes first, not the batch dimension. In a row-major array this makes each time step X[t] a contiguous (batch_size, input_dimension) slice.
- The input tensor has the shape (time_steps, batch_size, input_dimension).
- The LSTM cell implementation is modified to split the intermediate state along the hidden dimension (axis=1), to ensure the split operation is done correctly for batched operations.
- The implementation is otherwise the same as the sequence implementation, except for the shapes: the input is (time_steps, batch_size, input_dimension), the returned hidden states are (time_steps, batch_size, hidden_dimension), and the initial hidden and cell states are (batch_size, hidden_dimension).
- The lecturer tests this batched implementation and shows that it matches PyTorch to numerical precision.
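As a sanity check on the batched version, each batch element should evolve independently of the others. The sketch below is a compact NumPy restatement of the batched cell and loop (the sizes and the names `H_single` and `batch_consistent` are illustrative assumptions). It compares one batched run against running every batch element alone as a batch of size 1.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_cell_batched(x, h, c, W_hh, W_ih, b):
    # x: (batch, input_dim); h and c: (batch, hidden_dim).
    # Row vectors multiply the transposed weights; split along the gate axis.
    i, f, g, o = np.split(h @ W_hh.T + x @ W_ih.T + b, 4, axis=1)
    c_out = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c_out), c_out

def lstm_batched(X, h, c, W_hh, W_ih, b):
    # X is time-major: (time_steps, batch, input_dim).
    H = np.zeros((X.shape[0], X.shape[1], h.shape[1]))
    for t in range(X.shape[0]):
        h, c = lstm_cell_batched(X[t], h, c, W_hh, W_ih, b)
        H[t] = h
    return H, c

rng = np.random.default_rng(0)
T, B, n, d = 5, 3, 2, 4  # illustrative sizes
X = rng.standard_normal((T, B, n))
h0 = rng.standard_normal((B, d))
c0 = rng.standard_normal((B, d))
W_ih = rng.standard_normal((4 * d, n))
W_hh = rng.standard_normal((4 * d, d))
b = rng.standard_normal(4 * d)

H, c = lstm_batched(X, h0, c0, W_hh, W_ih, b)

# Run each batch element on its own and stitch the results back together.
H_single = np.concatenate(
    [lstm_batched(X[:, k:k + 1], h0[k:k + 1], c0[k:k + 1], W_hh, W_ih, b)[0]
     for k in range(B)],
    axis=1,
)

batch_consistent = np.allclose(H, H_single)
print("batched run matches per-example runs:", batch_consistent)
```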
- Training LSTMs 🏋️
- The lecturer then describes how to train an LSTM, first conceptually on the entire sequence, and then explains how to train it efficiently.
- Training an LSTM involves computing a loss between the hidden states and some target values, and using backpropagation to update the weights and biases.
- The lecture notes that to train deep LSTMs, we can use the hidden state output of a previous layer of an LSTM as the input to the next layer.
- In this case we would first compute the entire hidden state of the first layer, before passing it to the second layer.
- Because RNNs, including LSTMs, are often trained on long sequences, the lecture also emphasizes the importance of truncated backpropagation through time.
- This involves splitting the entire sequence into smaller blocks and training the LSTM on each block separately, to limit memory usage.
- The lecturer describes the technique of hidden unit repackaging, where the last hidden state and cell state of the previous block are used as the initial state of the next block.
- Critically, these states are detached from the computation graph before they are used in the subsequent block, so gradients do not propagate between blocks.
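The block-wise training loop with repackaging can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the lecture's training code: the linear `readout` head, the random data and targets, the MSE loss, the SGD optimizer, and all sizes are hypothetical choices made only so the loop runs end to end.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 8, 16, 4  # illustrative sizes
total_steps, block_len = 40, 10

model = nn.LSTM(input_size, hidden_size)
readout = nn.Linear(hidden_size, 1)  # hypothetical head mapping hidden states to targets
params = list(model.parameters()) + list(readout.parameters())
opt = torch.optim.SGD(params, lr=0.1)
loss_fn = nn.MSELoss()

# Random stand-in data, time-major: (time_steps, batch, features).
X = torch.randn(total_steps, batch_size, input_size)
Y = torch.randn(total_steps, batch_size, 1)

h = torch.zeros(1, batch_size, hidden_size)
c = torch.zeros(1, batch_size, hidden_size)

losses = []
for start in range(0, total_steps, block_len):
    x_blk = X[start:start + block_len]
    y_blk = Y[start:start + block_len]

    # Repackage: keep the state values but cut the link to the previous block's graph.
    h, c = h.detach(), c.detach()

    out, (h, c) = model(x_blk, (h, c))
    loss = loss_fn(readout(out), y_blk)

    opt.zero_grad()
    loss.backward()  # gradients flow only within this block
    opt.step()
    losses.append(loss.item())

print("per-block losses:", losses)
```

Without the `detach()` calls, the second block's backward pass would try to reach into the first block's graph, which PyTorch has already freed after the first `backward()`.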
Code Examples (Conceptual - NumPy style)
import numpy as np
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
def tanh(x):
    return np.tanh(x)
def lstm_cell(x, h, c, W_hh, W_ih, b):
    """
    Implements a single LSTM cell.
    Args:
        x: input vector of size (input_dim,)
        h: previous hidden state of size (hidden_dim,)
        c: previous cell state of size (hidden_dim,)
        W_hh: weight matrix for hidden to hidden connections (4*hidden_dim, hidden_dim)
        W_ih: weight matrix for input to hidden connections (4*hidden_dim, input_dim)
        b: bias (4*hidden_dim,)
    Returns:
        h_out: new hidden state
        c_out: new cell state
    """
    intermediate = np.dot(W_hh, h) + np.dot(W_ih, x) + b  # combined matrix multiplication
    i, f, g, o = np.split(intermediate, 4, axis=0)  # split into the four gates
    i = sigmoid(i)
    f = sigmoid(f)
    g = tanh(g)
    o = sigmoid(o)
    c_out = f * c + i * g
    h_out = o * tanh(c_out)
    return h_out, c_out
def lstm(X, h, c, W_hh, W_ih, b):
    """
    Implements an LSTM over a sequence.
    Args:
        X: input sequence of shape (time_steps, input_dim)
        h: initial hidden state of size (hidden_dim,)
        c: initial cell state of size (hidden_dim,)
        W_hh: weight matrix (4*hidden_dim, hidden_dim)
        W_ih: weight matrix (4*hidden_dim, input_dim)
        b: bias (4*hidden_dim,)
    Returns:
        H: all hidden states of the sequence (time_steps, hidden_dim)
        c_out: final cell state of size (hidden_dim,)
    """
    T = X.shape[0]  # number of time steps
    H = np.zeros((T, h.shape[0]))  # stores all the hidden states
    for t in range(T):
        h, c = lstm_cell(X[t], h, c, W_hh, W_ih, b)
        H[t] = h
    return H, c  # all hidden states and the final cell state
def lstm_cell_batched(x, h, c, W_hh, W_ih, b):
    """
    Batched version of the LSTM cell.
    Args:
        x: input of shape (batch_size, input_dim)
        h: previous hidden state (batch_size, hidden_dim)
        c: previous cell state (batch_size, hidden_dim)
        W_hh: weight matrix (4*hidden_dim, hidden_dim)
        W_ih: weight matrix (4*hidden_dim, input_dim)
        b: bias (4*hidden_dim,)
    Returns:
        h_out: new hidden state (batch_size, hidden_dim)
        c_out: new cell state (batch_size, hidden_dim)
    """
    intermediate = np.dot(h, W_hh.T) + np.dot(x, W_ih.T) + b  # rows are examples, so use transposed weights
    i, f, g, o = np.split(intermediate, 4, axis=1)  # split along the hidden dimension
    i = sigmoid(i)
    f = sigmoid(f)
    g = tanh(g)
    o = sigmoid(o)
    c_out = f * c + i * g
    h_out = o * tanh(c_out)
    return h_out, c_out
def lstm_batched(X, h, c, W_hh, W_ih, b):
    """
    Implements a batched LSTM over a sequence.
    Args:
        X: input sequence of shape (time_steps, batch_size, input_dim)
        h: initial hidden state of shape (batch_size, hidden_dim)
        c: initial cell state of shape (batch_size, hidden_dim)
        W_hh: weight matrix (4*hidden_dim, hidden_dim)
        W_ih: weight matrix (4*hidden_dim, input_dim)
        b: bias (4*hidden_dim,)
    Returns:
        H: all hidden states of the sequence (time_steps, batch_size, hidden_dim)
        c_out: final cell state of shape (batch_size, hidden_dim)
    """
    T = X.shape[0]  # number of time steps
    batch_size = X.shape[1]
    hidden_dim = h.shape[1]
    H = np.zeros((T, batch_size, hidden_dim))
    for t in range(T):
        h, c = lstm_cell_batched(X[t], h, c, W_hh, W_ih, b)
        H[t] = h
    return H, c
# Example Usage (Conceptual)
input_dim = 20
hidden_dim = 100
batch_size = 128
time_steps = 50
# Random Initializations
X = np.random.randn(time_steps, batch_size, input_dim).astype(np.float32)
h0 = np.random.randn(batch_size, hidden_dim).astype(np.float32)
c0 = np.random.randn(batch_size, hidden_dim).astype(np.float32)
W_ih = np.random.randn(4 * hidden_dim, input_dim).astype(np.float32)
W_hh = np.random.randn(4 * hidden_dim, hidden_dim).astype(np.float32)
b = np.random.randn(4 * hidden_dim).astype(np.float32)
# Run the Batched LSTM
H, C = lstm_batched(X, h0, c0, W_hh, W_ih, b)
print("Shape of all hidden states is:", H.shape)
Key Points
- The lecture provides a step-by-step implementation of LSTMs, starting with the fundamental cell and progressing to batch processing.
- The lecture uses NumPy to show how the LSTM works at a low level.
- The batched implementation and hidden state repackaging are crucial for training on large datasets efficiently.
- The lecturer notes that the specific equations used in an LSTM are not the only way to do things, but have been carefully tuned to achieve good performance.
Code Example (PyTorch)
import torch
import torch.nn as nn
import numpy as np
# --- Single LSTM Cell Example ---
# Define the LSTM cell
input_size = 20
hidden_size = 100
lstm_cell = nn.LSTMCell(input_size, hidden_size)
# Generate random input, hidden, and cell states
x = torch.randn(input_size).float()
h0 = torch.randn(hidden_size).float()
c0 = torch.randn(hidden_size).float()
# Pass the input through the LSTM cell
h1, c1 = lstm_cell(x.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0))) # unsqueeze to make batch dimension explicit
print("Shape of the hidden state is", h1.shape)
print("Shape of the cell state is", c1.shape)
# --- Sequence LSTM Example ---
# Define the LSTM layer
lstm_layer = nn.LSTM(input_size, hidden_size, num_layers = 1)
# Generate a random sequence of inputs
time_steps = 50
batch_size = 128
X = torch.randn(time_steps, batch_size, input_size).float()
H0 = torch.randn(1, batch_size, hidden_size).float()
C0 = torch.randn(1, batch_size, hidden_size).float()
# Pass the input through the LSTM layer
output, (hn, cn) = lstm_layer(X, (H0, C0))
print("Shape of the hidden states for the whole sequence is", output.shape)
print("Shape of the final hidden state is", hn.shape)
print("Shape of the final cell state is", cn.shape)
# Accessing the weights
print("Shape of the weight_ih matrix is:", lstm_cell.weight_ih.shape)
print("Shape of the weight_hh matrix is:", lstm_cell.weight_hh.shape)
print("Shape of the weight_ih matrix of the LSTM layer is:", lstm_layer.weight_ih_l0.shape)
print("Shape of the weight_hh matrix of the LSTM layer is:", lstm_layer.weight_hh_l0.shape)
Key Takeaways
- PyTorch’s nn.LSTMCell provides a basic building block for LSTMs that processes a single time step.
- PyTorch’s nn.LSTM can process an entire sequence of inputs and is the preferred way of implementing LSTMs for most practical uses.
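- PyTorch stacks the weights of the four gates into single matrices, weight_ih of shape (4*hidden_size, input_size) and weight_hh of shape (4*hidden_size, hidden_size), and keeps two bias vectors, bias_ih and bias_hh. This layout may look different from the per-gate equations found in some resources.
- By default (batch_first=False), nn.LSTM expects the time dimension first, (time_steps, batch_size, input_size), and the initial hidden and cell states are passed together as a tuple.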
- The lecture uses the PyTorch implementation as a reference to compare against when implementing LSTMs in NumPy.
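The comparison against PyTorch can be sketched as follows. This is a minimal illustration, not the lecture's exact test: it assumes a single-layer nn.LSTM run with a batch of size 1, copies its parameters into NumPy, and sums the two PyTorch bias vectors into the single bias `b` used by the NumPy cell. The sizes and the name `max_err` are illustrative.

```python
import numpy as np
import torch
import torch.nn as nn

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def lstm_cell(x, h, c, W_hh, W_ih, b):
    # Gate order i, f, g, o matches the row layout of PyTorch's stacked weights.
    i, f, g, o = np.split(W_hh @ h + W_ih @ x + b, 4)
    c_out = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    return sigmoid(o) * np.tanh(c_out), c_out

torch.manual_seed(0)
input_size, hidden_size, T = 20, 100, 50
model = nn.LSTM(input_size, hidden_size, num_layers=1)

# Reference run in PyTorch with a batch of size 1.
X = torch.randn(T, 1, input_size)
h0 = torch.randn(1, 1, hidden_size)
c0 = torch.randn(1, 1, hidden_size)
with torch.no_grad():
    out, (hn, cn) = model(X, (h0, c0))

# Copy the parameters; PyTorch's two biases are summed into one vector.
W_ih = model.weight_ih_l0.detach().numpy()
W_hh = model.weight_hh_l0.detach().numpy()
b = (model.bias_ih_l0 + model.bias_hh_l0).detach().numpy()

# NumPy run over the same inputs and initial state.
h, c = h0[0, 0].numpy(), c0[0, 0].numpy()
H = np.zeros((T, hidden_size), dtype=np.float32)
for t in range(T):
    h, c = lstm_cell(X[t, 0].numpy(), h, c, W_hh, W_ih, b)
    H[t] = h

max_err = float(np.abs(H - out[:, 0].numpy()).max())
print("max abs difference vs. PyTorch:", max_err)
```

Because both runs use float32, the difference is expected to be small but not exactly zero.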
