Exercise: Sampling from an RNN¶

The goal of sampling from an RNN is to initialize the sequence in some way, feed it into the recurrent computation, and retrieve the next prediction.

To start, we create the initial vectors:

start_index = vectorizer.surname_vocab.start_index
batch_size = 2
# hidden_size = whatever hidden size the model is set to

initial_h = torch.ones(batch_size, hidden_size)
initial_x_index = torch.ones(batch_size).long() * start_index


Then, we need to use these vectors to retrieve the next prediction:

# model is stored in variable called net

x_t = net.emb(initial_x_index)
print(x_t.shape)
h_t = net.rnn._compute_next_hidden(x_t, initial_h)

y_t = net.fc(h_t)


Now that we have a prediction vector, we can create a probability distribution and sample from it. Note we include a temperature hyper parameter for controlling how strongly we sample from the distribution (at high temperatures, everything is uniform, at low temperatures below 1, small differences are magnified). The temperature is always greater than 0.

temperature = 1.0
y_t = F.softmax(y_t / temperature, dim=1)
x_index_t = torch.multinomial(y_t, 1)[:, 0]


Now we can start the cycle over again:

x_t = net.emb(x_index_t)
h_t = net.rnn._compute_next_hidden(x_t, h_t)

y_t = net.fc(h_t)


Write a for loop which repeats this sequence and appends the x_t variable to a list.

Then, we can do the following:

final_x_indices = torch.stack(x_indices).squeeze().permute(1, 0)

# stop here if you don't know what cpu, data, and numpy do. Ask away!
final_x_indices = final_x_indices.cpu().detach().numpy()

# loop over the items in the batch
results = []
for i in range(len(final_x_indices)):
tokens = []
index_vector = final_x_indices[i]
for x_index in index_vector:
if vectorizer.surname_vocab.start_index == x_index:
continue
elif vectorizer.surname_vocab.end_index == x_index:
break
else:
token = vectorizer.surname_vocab.lookup(x_index)
tokens.append(token)

sampled_surname = "".join(tokens)
results.append(sampled_surname)
tokens = []