Q-learning은 가장 기본적인 value-based 방법으로, 주어진 state에서 어떤 action을 선택해야 가장 큰 reward을 받을 수 있을지를 학습한다. 그러나 Q-table을 직접 저장하는 방식은 state-action 공간이 커질수록 비효율적이고, 이미지처럼 고차원 입력을 다룰 수 없다는 근본적인 한계를 가진다. 이러한 한계를 극복하기 위해 나온 것이 바로 Deep Q-Network (DQN) 이다. DQN은 Q-learning에 딥러닝을 접목시켜, 신경망을 통해 Q값을 근사함으로써 복잡한 환경에서도 policy를 학습할 수 있게 만든다.
DQN은 이를 해결하기 위해 Q값을 출력하는 DNN을 도입한다. 즉, 입력으로는 state를, 출력으로는 가능한 각 action에 대한 Q값 $Q(s,a)$을 뽑는 구조다. 이때 Q-learning 업데이트 식은 그대로 유지되는데, 중요한 건 그 Q값을 구하는 함수가 신경망이라는 점이다.
$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha \left[ r_t + \gamma max_a \,Q(s_{t+1},a) - Q(s_t,a_t)\right]$$
단순히 신경망으로 Q값을 근사한다고 해서 학습이 잘 되는 건 아니다. DQN은 아래 두 가지 핵심적인 기법을 통해 학습을 안정화시킨다.
1. Experience Replay
강화학습은 시계열 데이터 기반으로 학습이 진행되는데, 연속된 state들은 서로 강하게 상관되어 있다. 만약 이런 state들을 그대로 학습에 사용하면 신경망이 overfitting되거나, 학습이 불안정해질 수 있다. 이를 방지하기 위해 DQN은 agent가 경험한 transition을 replay buffer에 저장하고, 이 중에서 무작위로 sampling하여 학습한다. 이 과정을 통해 데이터 간 correlation을 줄이고, 학습 효율을 높일 수 있다.
2. Target Network
Q-learning의 목표는 현재 Q값이 TD target 값과 가까워지도록 만드는 것이다. 그런데 이 TD target 역시 학습 중인 신경망에서 나오기 때문에, network가 불안정하게 진동할 수 있다. 이를 해결하기 위해 DQN은 target network를 도입한다. 이 network는 일정 주기마다 main network의 parameter를 복사해 사용하는 것으로, target Q값을 더 안정적으로 계산할 수 있게 해준다.
알고리즘은 아래와 같다.
위 알고리즘을 바탕으로 좀더 복잡한 state-action 공간을 갖는 Cart Pole 예제에 적용해보았다.
https://www.gymlibrary.dev/environments/classic_control/cart_pole/
Cart Pole - Gym Documentation
Previous Acrobot
www.gymlibrary.dev
아래는 이 알고리즘을 기준으로 실제 코드 구현에서 주목할 만한 몇 가지 포인트이다.
Object
- Q-network (Main Network)
- Q is initialized with random weights θ.
알고리즘 초반부에서 Q-network는 임의의 가중치로 초기화된다. 이후 매 step마다 이 Q 네트워크를 통해 행동을 선택하고, 손실 함수를 기반으로 업데이트된다.
- Q is initialized with random weights θ.
- Target Network
- Q̂ is initialized with θ⁻ = θ, and updated every C steps.
타겟 네트워크는 초기에는 Q와 동일하지만, 일정 주기(C 스텝)마다 Q의 가중치를 복사해온다. 학습의 안정성을 확보하기 위한 핵심 장치다.
- Q̂ is initialized with θ⁻ = θ, and updated every C steps.
- Replay Buffer
- D is a replay memory of capacity N. Transitions are stored as (ϕₜ, aₜ, rₜ, ϕₜ₊₁).
매 timestep마다 현재 transition을 replay buffer에 저장하고, 미니배치로 샘플링하여 네트워크를 학습시킨다. 이 과정을 통해 시계열 의존성 문제를 완화하고 데이터 효율을 높인다.
- D is a replay memory of capacity N. Transitions are stored as (ϕₜ, aₜ, rₜ, ϕₜ₊₁).
Logic
- Epsilon-Greedy 정책으로 행동 선택
- With probability ε select a random action, otherwise use argmax_a Q(ϕ(sₜ), a; θ)
exploration과 exploitation을 균형 있게 하기 위해 ε-greedy 정책을 사용한다. 학습 초반에는 탐험 위주로, 이후에는 점점 정책에 의존하게 조정된다.
- With probability ε select a random action, otherwise use argmax_a Q(ϕ(sₜ), a; θ)
- TD Target 계산 & 손실 함수 정의
- yj={rjif terminalrj+γ⋅maxa′Q^(ϕj+1,a′;θ−)otherwisey_j = \begin{cases} r_j & \text{if terminal} \\ r_j + \gamma \cdot \max_{a'} Q̂(ϕ_{j+1}, a'; θ⁻) & \text{otherwise} \end{cases}DQN의 핵심 업데이트 식이다. 현재 Q값이 TD 타겟과 가까워지도록 손실 함수를 최소화한다.
- Gradient Descent로 파라미터 업데이트
- Perform a gradient descent step on: (yj−Q(ϕj,aj;θ))2\left( y_j - Q(ϕ_j, a_j; θ) \right)^2손실 함수에 대해 미분하여 Q-network의 파라미터를 업데이트하는 과정이다.
- Target Network 업데이트
- Every C steps, Q̂ ← Q
일정 주기마다 타겟 네트워크를 최신 Q 네트워크로 갱신한다. 이로써 학습은 더 안정적으로 수렴한다.
- Every C steps, Q̂ ← Q
자세한 내용은 아래 깃허브 링크에서 확인할 수 있다.
Github : https://github.com/seonvin0319/25RL_Study/tree/main/03DQN
'인공지능' 카테고리의 다른 글
[강화학습] TD: SARSA vs Q-learning (1) | 2025.04.08 |
---|---|
[강화학습] Monte Carlo & Temporal Difference (0) | 2025.04.05 |
[강화학습] Dynamic Programming (0) | 2025.04.05 |
[강화학습] Markov Decision Process (0) | 2025.04.03 |
[Graph Neural Network] Traditional Feature-based Method (0) | 2024.08.22 |
댓글