SEMINAR

Learning to (Learn at Test Time) RNNs with Expressive Hidden States

Yoohwan Lee
2026.04.30
Natural Language Processing
Learning to (Learn at Test Time) RNNs with Expressive Hidden States
VENUE2025 ICML
PAPER LINKICML

Overview

  • sequence modeling에서 long context를 효율적으로 처리하기 위한 새로운 구조 제안
  • 기존 RNN은 fixed hidden state로 인해 long context 표현 한계 존재
  • Transformer는 long context 성능은 좋지만 계산 비용 증가
  • TTT(Test-Time Training) 기반으로 hidden state를 동적으로 업데이트하는 방식 제안

Key Takeaways

Problem Setting

  • sequence modeling의 핵심 문제: 과거 context를 어떻게 저장하고 활용할 것인가
  • RNN
    • hidden state에 정보를 압축하지만 fixed length로 표현력 제한
  • Transformer
    • KV cache로 context 저장하지만 길이에 따라 비용 증가
  • 기존 방법 한계
    • RNN: long context 표현 부족
    • Transformer: 계산 비용 및 확장성 문제
  • 목표: 효율적이면서 표현력 있는 long context modeling

Main Idea

  • hidden state를 고정 벡터가 아닌 학습 가능한 파라미터(가중치)로 확장
  • TTT Layer
    • 기존 sequence layer(RNN, attention)를 대체 가능
    • 동일한 interface 유지하면서 구조 교체 가능
    • test-time에 hidden state 역할을 하는 weight W를 업데이트
  • Self-supervised Update
    • 입력 xt마다 loss 계산 후 Wt 업데이트
    • Wt = Wt-1 - η∇l(Wt-1; xt)
    • 과거 context가 weight에 누적 저장
  • Inner / Outer Loop 구조
    • inner loop: TTT layer parameter W 업데이트 (test-time 포함)
    • outer loop: 나머지 네트워크 파라미터 θ 학습
    • W는 hidden state처럼 동작
  • Self-supervised Task
    • input을 training view / label view / test view로 분리
    • reconstruction 기반 loss로 self-supervised 학습
    • θK, θV, θQ는 outer loop에서 학습
  • Mini-batch TTT
    • sequential dependency 문제 해결을 위해 mini-batch 기반 업데이트
    • 일부 gradient 병렬화 가능
    • online GD → batch GD → mini-batch GD로 확장
  • Dual Formulation
    • 연산을 matmul 형태로 변환하여 GPU 효율 극대화
    • 기존 O(b·d²) 연산을 matmul 기반으로 최적화
    • 실제 구현에서 속도 향상 확인
  • Theoretical View
    • TTT layer는 다양한 sequence 모델을 포함하는 일반화된 형태
    • self-attention, linear attention 등과 이론적으로 연결

Result

  • 다양한 backbone(Mamba, Transformer)에서 TTT layer 적용 가능
  • short context(2k)에서는 기존 모델과 유사 성능
  • longer context(8k, 32k)에서 TTT layer 성능 이점 확인
  • wall-clock time 기준 Transformer 대비 유사하거나 더 빠른 경우 존재
  • FLOPs 증가 대비 실제 실행 시간 증가 제한적

Limitation

  • test-time 업데이트로 인한 추가 연산 존재
  • sequential dependency로 인해 병렬화 한계 존재
  • hyperparameter 및 outer loop 설계에 민감
  • 아직 대규모 모델 및 초장문 context에서 추가 검증 필요