Composants clés du DQN
Replay buffer
class ReplayBuffer:
def __init__(self, capacity=10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
return random.sample(self.buffer, batch_size)
Q-Network
class DQN(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, action_dim)
)
def forward(self, x): return self.net(x)
Points clés de la boucle d'entraînement
- Réseau cible : copier les poids tous les 1000 pas (évite les oscillations)
- Décroissance d'epsilon : linéaire de 1.0 à 0.01 sur 50K pas
- Taille de batch : 64, mise à jour tous les 4 pas
- Double DQN : réduit le biais de surestimation