Machine Learning · January 18, 2025 · 13 min read

DQN from Scratch: Teaching an Agent to Play Snake

A complete from-scratch DQN implementation in PyTorch — environment, replay buffer, epsilon-greedy exploration, and the training loop that actually converges.

DQN Core Components

Replay Buffer

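import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity=10000):
        # A deque with maxlen evicts the oldest transition once capacity is reached
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement; raises ValueError if fewer than batch_size items are stored
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)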
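
A sampled batch comes back as a list of 5-tuples, while a training step usually wants one column per field. The sketch below shows one way to regroup a batch with `zip(*batch)`, using a plain `deque` in place of the buffer so it runs on its own. The helper name `sample_columns`, the toy transitions, and the `MIN_BUFFER` warm-up threshold are assumptions for illustration, not part of the original code.

```python
import random
from collections import deque

def sample_columns(buffer, batch_size):
    """Sample transitions uniformly and regroup them column-wise."""
    batch = random.sample(buffer, batch_size)
    # zip(*batch) turns a list of 5-tuples into 5 tuples of length batch_size.
    states, actions, rewards, next_states, dones = zip(*batch)
    return states, actions, rewards, next_states, dones

# Toy transitions: each state is a 2-element list, purely for illustration.
buffer = deque(maxlen=100)
for t in range(10):
    buffer.append(([t, t + 1], t % 3, 1.0, [t + 1, t + 2], t == 9))

MIN_BUFFER = 4  # hypothetical warm-up threshold before sampling starts
if len(buffer) >= MIN_BUFFER:
    states, actions, rewards, next_states, dones = sample_columns(buffer, 4)
```

Checking the buffer length before sampling avoids the `ValueError` that `random.sample` raises when asked for more items than are stored.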

Q-Network

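import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Two hidden layers of 256 units; the output layer emits one Q-value per action
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, x):
        # x: (batch, state_dim) float tensor -> (batch, action_dim) Q-values
        return self.net(x)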
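
The network's outputs become actions through epsilon-greedy selection: with probability epsilon the agent acts at random, otherwise it takes the action with the highest predicted Q-value. Here is a minimal sketch of that rule. The function name `select_action` and the sizes `STATE_DIM = 11` and `ACTION_DIM = 3` are hypothetical, and a stand-in `nn.Sequential` with the same layer layout replaces the class above so the block runs on its own.

```python
import random
import torch
import torch.nn as nn

def select_action(q_net, state, epsilon, action_dim):
    """Epsilon-greedy: random action with probability epsilon, else argmax Q."""
    if random.random() < epsilon:
        return random.randrange(action_dim)
    with torch.no_grad():  # acting does not need gradients
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        q_values = q_net(state_t)  # shape: (1, action_dim)
        return int(q_values.argmax(dim=1).item())

# Stand-in network with the same layer layout as the Q-network above.
STATE_DIM, ACTION_DIM = 11, 3  # hypothetical sizes for illustration
q_net = nn.Sequential(
    nn.Linear(STATE_DIM, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, ACTION_DIM),
)

state = [0.0] * STATE_DIM
greedy_action = select_action(q_net, state, epsilon=0.0, action_dim=ACTION_DIM)
random_action = select_action(q_net, state, epsilon=1.0, action_dim=ACTION_DIM)
```

With `epsilon=0.0` the random branch is never taken, so the call is fully greedy. With `epsilon=1.0` it is always taken.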

Training Loop Key Points

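  • Target network: copy the online network's weights into the target network every 1,000 steps. Holding the bootstrap targets fixed between syncs damps oscillation.
  • Epsilon decay: linear from 1.0 to 0.01 over the first 50K steps, then held at 0.01
  • Batch size: 64 transitions per update, with one gradient update every 4 environment steps
  • Double DQN: the online network selects the next action and the target network evaluates it, which reduces overestimation bias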
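
The points above can be sketched as one update step. This is an illustration under stated assumptions, not the article's exact loop: the discount `GAMMA = 0.99`, the Huber loss (`smooth_l1_loss`), the Adam learning rate, the helper names, and the random batch standing in for replay samples are all choices made here. The decay schedule, sync interval, and batch size come from the list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

EPS_START, EPS_END, EPS_DECAY_STEPS = 1.0, 0.01, 50_000
TARGET_SYNC_EVERY = 1000
GAMMA = 0.99  # assumed discount factor; not stated in the article

def epsilon_at(step):
    """Linear decay from EPS_START to EPS_END, then constant."""
    frac = min(step / EPS_DECAY_STEPS, 1.0)
    return EPS_START + frac * (EPS_END - EPS_START)

def double_dqn_loss(online, target, states, actions, rewards, next_states, dones):
    # Q(s, a) for the actions that were actually taken.
    q_sa = online(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Double DQN: the online net picks the action, the target net scores it.
        next_actions = online(next_states).argmax(dim=1, keepdim=True)
        next_q = target(next_states).gather(1, next_actions).squeeze(1)
        # Terminal transitions (dones == 1) contribute the reward only.
        y = rewards + GAMMA * next_q * (1.0 - dones)
    return F.smooth_l1_loss(q_sa, y)  # Huber loss is an assumption here

def make_net(state_dim, action_dim):
    return nn.Sequential(
        nn.Linear(state_dim, 256), nn.ReLU(),
        nn.Linear(256, 256), nn.ReLU(),
        nn.Linear(256, action_dim),
    )

STATE_DIM, ACTION_DIM, BATCH = 11, 3, 64  # state/action sizes are hypothetical
online = make_net(STATE_DIM, ACTION_DIM)
target = make_net(STATE_DIM, ACTION_DIM)
target.load_state_dict(online.state_dict())  # start with identical weights
optimizer = torch.optim.Adam(online.parameters(), lr=1e-3)  # assumed optimizer

# Random batch standing in for transitions sampled from the replay buffer.
states = torch.randn(BATCH, STATE_DIM)
actions = torch.randint(0, ACTION_DIM, (BATCH,))
rewards = torch.randn(BATCH)
next_states = torch.randn(BATCH, STATE_DIM)
dones = torch.zeros(BATCH)

loss = double_dqn_loss(online, target, states, actions, rewards, next_states, dones)
optimizer.zero_grad()
loss.backward()
optimizer.step()

step = 1000  # pretend this is the global step counter
if step % TARGET_SYNC_EVERY == 0:
    target.load_state_dict(online.state_dict())  # periodic hard sync
```

Only `online` receives gradients. `target` changes only at the periodic sync, which is what keeps the regression targets stable between copies.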
Reinforcement Learning, DQN, PyTorch, Game AI, Deep Q-Network

Ossama Elhakki

AI Engineer & ML Systems Builder — Morocco