Gated Recurrent Unit (GRU) is a type of recurrent neural network (RNN) designed to capture long-term dependencies in sequential data efficiently. It is an extension of traditional RNNs and shares similarities with LSTM (Long Short-Term Memory) networks.
In this tutorial, we'll briefly learn about the GRU model and how to implement sequential data prediction with GRU in PyTorch. We'll cover the following topics:
- Introduction to GRU
- Data preparation
- Model definition and training
- Prediction
- Conclusion
Let's get started.
Introduction to GRU
The key idea behind GRU is to address the vanishing gradient problem and improve the ability of RNNs to retain information over long sequences. GRU achieves this by introducing gating mechanisms that regulate the flow of information within the network.
A typical GRU unit consists of the following components:
- Update Gate determines how much of the past information to keep and how much new information to let through. It is calculated using the input at the current time step and the previous hidden state.
- Reset Gate controls which parts of the past hidden state should be ignored. It is calculated in a similar manner to the update gate.
- Candidate Activation computes a new candidate activation based on the current input and the previous hidden state, considering the reset gate.
- Hidden State combines the previous hidden state and the candidate activation, with the update gate weighting how much of each is kept, to produce the current hidden state.
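The four components above can be written as equations. The form below follows the convention used in PyTorch's nn.GRU documentation, where $\sigma$ is the sigmoid function, $\odot$ is element-wise multiplication, $r_t$ is the reset gate, $z_t$ the update gate and $n_t$ the candidate activation (some other sources swap the roles of $z_t$ and $1 - z_t$):

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\big(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\big) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$

In this convention, an update gate value close to 1 copies the previous hidden state forward almost unchanged, which is what lets information survive over many time steps.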
The update and reset gates allow the model to selectively update or ignore information from previous time steps, addressing the vanishing gradient problem and facilitating the capture of long-range dependencies.
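To make the gating concrete, here is a minimal NumPy sketch of a single GRU time step, assuming the PyTorch gate convention. The function names `gru_cell_step` and `random_params` and the parameter dictionary layout are hypothetical, chosen only for illustration; nn.GRU stores its weights differently (stacked into `weight_ih_l0` and `weight_hh_l0`).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell_step(x_t, h_prev, params):
    """One GRU time step (PyTorch gate convention).

    x_t: (input_size,), h_prev: (hidden_size,)
    params: W_i* of shape (hidden, input), W_h* of shape (hidden, hidden),
            b_i* and b_h* of shape (hidden,) for gates r, z, n.
    """
    p = params
    r = sigmoid(p["W_ir"] @ x_t + p["b_ir"] + p["W_hr"] @ h_prev + p["b_hr"])  # reset gate
    z = sigmoid(p["W_iz"] @ x_t + p["b_iz"] + p["W_hz"] @ h_prev + p["b_hz"])  # update gate
    n = np.tanh(p["W_in"] @ x_t + p["b_in"] + r * (p["W_hn"] @ h_prev + p["b_hn"]))  # candidate
    return (1.0 - z) * n + z * h_prev  # blend candidate and previous state

def random_params(input_size, hidden_size, seed=0):
    """Small random weights and zero biases for the three gate blocks."""
    rng = np.random.default_rng(seed)
    params = {}
    for g in ("r", "z", "n"):
        params[f"W_i{g}"] = rng.normal(scale=0.1, size=(hidden_size, input_size))
        params[f"W_h{g}"] = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
        params[f"b_i{g}"] = np.zeros(hidden_size)
        params[f"b_h{g}"] = np.zeros(hidden_size)
    return params

# Run the cell over a short random sequence, starting from a zero hidden state
params = random_params(input_size=1, hidden_size=4)
h = np.zeros(4)
for x in np.random.default_rng(1).normal(size=(10, 1)):
    h = gru_cell_step(x, h, params)
print(h.shape)  # (4,)
```

Setting a large positive update-gate bias (`b_iz`) drives $z_t$ toward 1, and the step then returns essentially the previous hidden state, which is the "keep the past" path described above.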
Data preparation
Let's implement sequence data prediction with a GRU model in PyTorch. We start by loading the necessary libraries for this tutorial.
We use simple sequential data in this tutorial. The code below shows how to generate the sequence data and visualize it in a graph. Here, we use 720 samples as training data and 130 samples as test data to forecast.
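The exact series used in the original post is not shown here, so the sketch below assumes a noisy sine wave with a slight trend (850 points, split into 720 training and 130 test samples). The names `generate_series` and `plot_series` are hypothetical; matplotlib is imported inside `plot_series` so the rest of the sketch runs without it.

```python
import numpy as np

def generate_series(n_samples=850, seed=0):
    """Hypothetical toy series: a noisy sine wave with a slow upward trend."""
    rng = np.random.default_rng(seed)
    t = np.arange(n_samples)
    series = np.sin(t * 0.1) * 5 + t * 0.01 + rng.normal(scale=0.5, size=n_samples)
    return series.astype(np.float32)

def plot_series(train_data, test_data, forecast_start):
    """Plot the training and test parts of the series in different colors."""
    import matplotlib.pyplot as plt
    plt.plot(range(forecast_start), train_data, label="train")
    plt.plot(range(forecast_start, forecast_start + len(test_data)), test_data, label="test")
    plt.legend()
    plt.show()

data = generate_series()
forecast_start = 720                 # first 720 points are used for training
train_data = data[:forecast_start]
test_data = data[forecast_start:]    # remaining 130 points are used for testing
print(len(train_data), len(test_data))  # 720 130
```

Calling `plot_series(train_data, test_data, forecast_start)` draws the graph.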
Next, we convert the data into training sequences of a given length and their labels. The function below creates the sequences and labels.
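A minimal sliding-window version of such a function might look like the following sketch; the name `create_sequences` is hypothetical. Each window of `seq_length` consecutive values becomes an input, and the value immediately after it becomes the label.

```python
import numpy as np

def create_sequences(data, seq_length):
    """Split a 1-D series into sliding windows and next-step labels."""
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        xs.append(data[i:i + seq_length])  # window of seq_length values
        ys.append(data[i + seq_length])    # the value right after the window
    return np.array(xs), np.array(ys)

X, y = create_sequences(np.arange(10, dtype=np.float32), seq_length=3)
print(X[0], y[0])        # [0. 1. 2.] 3.0
print(X.shape, y.shape)  # (7, 3) (7,)
```

A series of length N yields N - seq_length windows, so a longer window leaves fewer training examples.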
We can split the data into train and test parts using the forecast_start variable, then generate the sequences and their labels. The np.reshape() function reshapes the data into the three-dimensional (samples, sequence_length, features) shape expected by the GRU input. The train and test sets are converted to PyTorch tensors, and a DataLoader object is created from those tensors.
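These steps can be sketched as follows. The toy sine series, the window length of 20 and the batch size of 32 are assumptions, not values from the original post. Labels are given a trailing dimension so they match the (batch, 1) output of the model.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def create_sequences(data, seq_length):
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        xs.append(data[i:i + seq_length])
        ys.append(data[i + seq_length])
    return np.array(xs), np.array(ys)

# Hypothetical toy series of 850 points, split at forecast_start
t = np.arange(850)
data = (np.sin(t * 0.1) * 5).astype(np.float32)
forecast_start = 720
seq_length = 20  # assumed window length

train, test = data[:forecast_start], data[forecast_start:]
X_train, y_train = create_sequences(train, seq_length)
X_test, y_test = create_sequences(test, seq_length)

# nn.GRU with batch_first=True expects (batch, seq_len, input_size)
X_train = np.reshape(X_train, (X_train.shape[0], seq_length, 1))
X_test = np.reshape(X_test, (X_test.shape[0], seq_length, 1))

X_train_t = torch.from_numpy(X_train)
y_train_t = torch.from_numpy(y_train).unsqueeze(1)  # (N, 1) to match the model output
X_test_t = torch.from_numpy(X_test)
y_test_t = torch.from_numpy(y_test).unsqueeze(1)

train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=32, shuffle=True)
print(X_train_t.shape, y_train_t.shape)  # torch.Size([700, 20, 1]) torch.Size([700, 1])
```

Here the test windows are built only from the test part, so the first `seq_length` test points have no prediction; prepending the last `seq_length` training points to the test part is a common alternative.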
Model definition and training
We define a GRU model using PyTorch's nn.Module class. In the __init__ method, we initialize the input, hidden, and output sizes of the GRU model. The nn.GRU module constructs the GRU layer with the specified input and hidden sizes, where batch_first=True indicates that input tensors have the shape (batch_size, sequence_length, input_size) and output tensors have the shape (batch_size, sequence_length, hidden_size). Additionally, we define a fully connected linear layer with nn.Linear, which maps the hidden state output of the GRU to the desired output size.
In the forward method, we implement the forward pass through the GRU layer, which produces an output tensor out. We then apply the fully connected layer to the output at the last time step (out[:, -1, :]) to produce the final output of the model.
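A class matching this description might look like the sketch below. The hidden size of 32 in the usage lines is an assumed value, and the initial hidden state is left to nn.GRU's default of zeros rather than passed explicitly.

```python
import torch
import torch.nn as nn

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # batch_first=True: input (batch, seq_len, input_size), output (batch, seq_len, hidden_size)
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        # maps the last hidden state to the prediction
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)           # initial hidden state defaults to zeros
        return self.fc(out[:, -1, :])  # keep only the last time step

model = GRUModel(input_size=1, hidden_size=32, output_size=1)
x = torch.randn(8, 20, 1)
print(model(x).shape)  # torch.Size([8, 1])
```

For a single-layer, unidirectional GRU, `out[:, -1, :]` is the same as the final hidden state `h_n[0]` returned as the second output of nn.GRU, so either can be fed to the linear layer.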
We define the hyperparameters for our model and initialize it using the GRUModel class above. We use nn.MSELoss() as the loss function and the Adam optimizer.
Next, we train the model by iterating over the epochs, printing the loss every 10 epochs.
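The hyperparameter setup and training loop can be sketched as below. The learning rate, hidden size, window length and the clean sine training data are assumptions for illustration, so the printed losses will not match the output shown in this post. The printed value is the loss of the last batch in each epoch.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])

# Hypothetical training data: 720 windows of a sine wave -> next value
t = np.arange(740, dtype=np.float32)
series = np.sin(t * 0.1)
seq_length = 20
X = np.stack([series[i:i + seq_length] for i in range(len(series) - seq_length)])[..., None]
y = series[seq_length:][:, None]
train_loader = DataLoader(TensorDataset(torch.from_numpy(X), torch.from_numpy(y)),
                          batch_size=32, shuffle=True)

# Hyperparameters (assumed values)
input_size, hidden_size, output_size = 1, 32, 1
learning_rate, num_epochs = 0.01, 100

model = GRUModel(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

losses = []
for epoch in range(num_epochs):
    model.train()
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()                        # clear gradients from the previous step
        loss = criterion(model(batch_x), batch_y)
        loss.backward()                              # backpropagation through time
        optimizer.step()
    losses.append(loss.item())
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
```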
Now, we can start training the model.
Epoch [20/100], Loss: 4.0839
Epoch [30/100], Loss: 1.6807
Epoch [40/100], Loss: 0.5536
Epoch [50/100], Loss: 0.2236
Epoch [60/100], Loss: 0.1506
Epoch [70/100], Loss: 0.1338
Epoch [80/100], Loss: 0.1286
Epoch [90/100], Loss: 0.1256
Epoch [100/100], Loss: 0.1231
Prediction
After training, we can predict the test data with the trained model and visualize the results in a graph.
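A minimal prediction step might look like the following sketch. To keep it standalone, an untrained model and random test windows stand in for the trained model and the test tensors from the earlier steps. The helper names `predict` and `plot_predictions` are hypothetical, and matplotlib is imported only when plotting.

```python
import numpy as np
import torch
import torch.nn as nn

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])

def predict(model, X_test):
    """One-step-ahead predictions for each test window, without tracking gradients."""
    model.eval()  # switch off training-only behavior such as dropout
    with torch.no_grad():
        return model(X_test).squeeze(1).numpy()

def plot_predictions(y_true, y_pred):
    import matplotlib.pyplot as plt
    plt.plot(y_true, label="actual")
    plt.plot(y_pred, label="predicted")
    plt.legend()
    plt.show()

# Stand-ins for the trained model and the test windows
model = GRUModel(1, 32, 1)
X_test = torch.randn(110, 20, 1)
y_pred = predict(model, X_test)
print(y_pred.shape)  # (110,)
```

With real data, `plot_predictions(y_test, y_pred)` overlays the forecast on the actual test values.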
Conclusion
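GRU simplifies the architecture of traditional LSTM networks by combining the forget and input gates into a single update gate and merging the cell state into the hidden state. With fewer gates, a GRU has fewer parameters than an LSTM of the same size, which makes it computationally efficient at capturing temporal dependencies.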
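The efficiency claim can be made concrete by counting parameters. In PyTorch, each gate block holds an input weight matrix, a hidden weight matrix and two bias vectors; nn.GRU has three such blocks and nn.LSTM has four. The sizes below (input 1, hidden 32) are arbitrary example values.

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

input_size, hidden_size = 1, 32
gru = nn.GRU(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)

# One gate block: W_i (hidden x input) + W_h (hidden x hidden) + two bias vectors
per_gate = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
print(count_params(gru), 3 * per_gate)   # 3360 3360  (reset, update, candidate)
print(count_params(lstm), 4 * per_gate)  # 4480 4480  (input, forget, cell, output)
```

So for the same hidden size, a GRU layer needs about three quarters of the parameters and matrix multiplications of an LSTM layer.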
In this tutorial, we learned about GRU networks and how to predict sequence data with a GRU model in PyTorch. We covered an overview of GRU, data preparation, GRU model definition and training, and prediction of test data. I hope this tutorial helps you understand GRU and its application to sequential data.