Mastering Memory Tasks with World Models

ICLR 2024 (oral, top-1.2%)
¹Mila, ²Université de Montréal, ³Polytechnique Montréal, ⁴Dalhousie University, ⁵CIFAR AI Chair
Equal Contribution
Work done during his postdoc at Mila and Université de Montréal
Recall to Imagine

Recall to Imagine (R2I) is a generalist, computationally efficient, model-based agent that shines in memory-intensive reinforcement learning (RL) tasks, even surpassing human performance in the complex Memory Maze domain. Our open-source codebase is implemented entirely in JAX, with streamlined, highly efficient RL pipelines and state space model implementations.

Temporal Reasoning in Agents

Sequential decision-making problems require temporal reasoning, and memory lies at the core of this process. Memory enables humans and machines alike to retain valuable information from past events and use it to make informed decisions in the present.

Take, for instance, the Memory Length task in the Behavior Suite (BSuite) benchmark. In this environment, the goal is to output, at the end of the episode, an action dictated by the initial observation; the episode length, i.e., the number of steps the information must be remembered, is an environment parameter. Thus, the agent must carry the information from the initial observation throughout the entire episode.
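To make the structure of this task concrete, here is a minimal sketch of a simplified version of it, assuming a binary context shown only in the first observation and a ±1 reward for the final action. The class `ToyMemoryLength` and the helper `run_episode` are hypothetical illustrations, not the BSuite API.

```python
import random


class ToyMemoryLength:
    """Hypothetical, simplified stand-in for BSuite's Memory Length task.

    A binary context is visible only in the first observation. After
    `memory_length` steps the episode ends, and the final action is rewarded
    +1 if it matches the context and -1 otherwise (a toy convention).
    """

    def __init__(self, memory_length: int, seed: int = 0):
        self.memory_length = memory_length
        self.rng = random.Random(seed)

    def reset(self) -> int:
        self.t = 0
        self.context = self.rng.randint(0, 1)
        return self.context  # the only observation that reveals the context

    def step(self, action: int):
        self.t += 1
        done = self.t >= self.memory_length
        reward = 0.0
        if done:
            reward = 1.0 if action == self.context else -1.0
        return -1, reward, done  # intermediate observations carry no information


def run_episode(env, remember: bool) -> float:
    """Runs one episode with a fixed policy and returns the total reward."""
    first_obs = env.reset()
    memory = first_obs if remember else None  # what the agent carries forward
    total, done = 0.0, False
    while not done:
        # With memory, replay the stored context; without it, always answer 0.
        action = memory if remember else 0
        _, reward, done = env.step(action)
        total += reward
    return total


env = ToyMemoryLength(memory_length=100, seed=3)
print(sum(run_episode(env, remember=True) for _ in range(50)) / 50)  # 1.0
print(sum(run_episode(env, remember=False) for _ in range(50)) / 50)
```

The memoryless agent is right only when the context happens to be 0, so its average reward hovers around zero, while the agent that stores the first observation succeeds for any `memory_length`. A learned agent has to discover that storage by itself, through its own memory mechanism.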

Another vital aspect of temporal reasoning is the ability to gauge how our past actions and choices affect the feedback we receive now, be it success or failure. This process is termed credit assignment. To probe it, consider the Discounting Chain task in the Behavior Suite, where the reward caused by the first action is provided only after a certain number of steps, specified by the reward delay parameter.
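The difficulty of this kind of credit assignment can be sketched numerically. The toy below (the hypothetical `ToyDiscountingChain`, not the BSuite implementation) assumes a single rewarding chain that pays 1.0 and a discount factor of 0.99. It shows that the discounted-return gap between the right and the wrong first action equals gamma raised to the reward delay, so the learning signal reaching the first action shrinks geometrically with the delay.

```python
class ToyDiscountingChain:
    """Hypothetical simplification of BSuite's Discounting Chain task.

    The first action selects one of `num_chains` chains. Only `good_chain`
    pays a reward of 1.0, and it arrives `reward_delay` steps after the
    choice; every other reward in the episode is zero.
    """

    def __init__(self, reward_delay: int, num_chains: int = 5, good_chain: int = 2):
        self.reward_delay = reward_delay
        self.num_chains = num_chains
        self.good_chain = good_chain

    def rewards_after(self, first_action: int):
        """Reward sequence that follows the first action."""
        payoff = 1.0 if first_action == self.good_chain else 0.0
        return [0.0] * self.reward_delay + [payoff]


def discounted_return(rewards, gamma: float) -> float:
    """G = sum_t gamma^t * r_t, accumulated backwards (Horner's rule)."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


for delay in (10, 100, 1000):
    env = ToyDiscountingChain(reward_delay=delay)
    good = discounted_return(env.rewards_after(env.good_chain), gamma=0.99)
    bad = discounted_return(env.rewards_after(0), gamma=0.99)
    print(delay, good - bad)
```

With gamma = 0.99, the gap is about 0.904 for a delay of 10, about 0.366 for a delay of 100, and about 4.3e-5 for a delay of 1000, which is why long reward delays make it hard to tell which first action deserved the credit.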

RL Struggles with Memory and Credit Assignment

RL, a reward-maximization paradigm for learning intelligent behaviors, has achieved remarkable results in a diverse set of applications, ranging from real-world challenges to games. Nonetheless, most RL algorithms struggle with tasks that rely heavily on long-term memory or long-horizon credit assignment. For example, DreamerV3 stands out as a scalable and general model-based RL algorithm that masters a wide range of applications with fixed hyperparameters. Yet how does it perform on tasks like Memory Length and Discounting Chain? The figures below reveal a limitation in DreamerV3's memory: it obtains rewards only in episodes of at most 30 steps.

Figure: DreamerV3 in BSuite

Recall to Imagine

World models commonly use RNNs or Transformers as their backbone to give the agent temporal reasoning; however, long-term memory and credit assignment remain frequent challenges. These challenges are attributed to the backbone architecture's inadequate learning of long-range dependencies. Recent research indicates that state space models (SSMs) have the potential to replace these backbones: SSMs capture dependencies in very long sequences more efficiently, with sub-quadratic complexity in sequence length, and can be trained in parallel.
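The two computation modes of an SSM can be illustrated with a minimal sketch, assuming a real-valued, already-discretized diagonal state matrix and a scalar input and output. Practical SSM layers such as S4 or S5 learn continuous-time parameters and discretize them; the functions below are hypothetical and illustrative only. The recurrent form steps through the sequence, while the equivalent convolution form writes every output as a sum over inputs weighted by a precomputed kernel, so all outputs can be computed independently.

```python
def ssm_recurrent(a, b, c, u):
    """Recurrent mode of a diagonal linear SSM with scalar input and output.

    x_k = a * x_{k-1} + b * u_k (elementwise over the state), y_k = c . x_k
    """
    x = [0.0] * len(a)
    ys = []
    for u_k in u:
        x = [a_i * x_i + b_i * u_k for a_i, x_i, b_i in zip(a, x, b)]
        ys.append(sum(c_i * x_i for c_i, x_i in zip(c, x)))
    return ys


def ssm_convolution(a, b, c, u):
    """Convolution mode: y_k = sum_{j<=k} K_{k-j} u_j with K_m = sum_i c_i a_i^m b_i."""
    length = len(u)
    kernel = [
        sum(c_i * a_i**m * b_i for a_i, b_i, c_i in zip(a, b, c))
        for m in range(length)
    ]
    # Each y_k depends only on the inputs and the kernel, not on y_{k-1}.
    return [sum(kernel[k - j] * u[j] for j in range(k + 1)) for k in range(length)]


a, b, c = [0.9, 0.5, -0.3], [1.0, 0.5, 2.0], [0.2, 1.0, 0.7]
u = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]
print(ssm_recurrent(a, b, c, u))
print(ssm_convolution(a, b, c, u))  # same values up to floating-point rounding
```

For clarity the convolution is computed naively here in O(L^2) time; SSM implementations such as S4 evaluate it with an FFT in O(L log L), which is one source of the sub-quadratic complexity.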

To improve temporal coherence, we introduce Recall to Imagine (R2I), which integrates SSMs into DreamerV3's world model, giving rise to what we term the Structured State-Space Model (S3M). S3M is designed with two primary objectives: capturing long-range relations in trajectories and ensuring fast computation in model-based RL. It achieves the desired speed by computing in parallel during training and in recurrent mode at inference time, which enables quick generation of imagined trajectories.
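The recurrent half of this design can be sketched with a toy diagonal SSM: once a context has been processed, imagination needs only the final hidden state, and each imagined step costs time proportional to the state size, regardless of how long the context was. All functions below, including the `policy` that turns the latest output into the next input, are hypothetical stand-ins; R2I's actual world model also has learned encoders, decoders, and latent variables that this sketch omits.

```python
import math


def ssm_step(state, u_k, a, b, c):
    """One recurrent step of a diagonal linear SSM; cost is O(len(state))."""
    new_state = [a_i * x_i + b_i * u_k for a_i, x_i, b_i in zip(a, state, b)]
    y_k = sum(c_i * x_i for c_i, x_i in zip(c, new_state))
    return new_state, y_k


def run(state, inputs, a, b, c):
    """Processes a whole input sequence, returning the final state and all outputs."""
    outputs = []
    for u_k in inputs:
        state, y_k = ssm_step(state, u_k, a, b, c)
        outputs.append(y_k)
    return state, outputs


def imagine(state, last_output, horizon, a, b, c, policy):
    """Rolls forward `horizon` steps from a carried state without revisiting the context."""
    inputs, outputs = [], []
    y = last_output
    for _ in range(horizon):
        u_k = policy(y)  # hypothetical policy: maps the latest output to the next input
        state, y = ssm_step(state, u_k, a, b, c)
        inputs.append(u_k)
        outputs.append(y)
    return inputs, outputs


a, b, c = [0.99, 0.9, 0.5], [1.0, 0.5, 0.2], [0.3, -0.4, 1.0]
context = [math.sin(0.01 * t) for t in range(1000)]
state, ctx_out = run([0.0] * len(a), context, a, b, c)
dream_in, dream_out = imagine(state, ctx_out[-1], 15, a, b, c, policy=math.tanh)
print(dream_out)
```

Re-running the SSM from scratch on the context followed by `dream_in` reproduces `dream_out`, so carrying the state loses nothing; a Transformer, by contrast, would attend over the full context at every imagined step.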

Figure: R2I overview

Does R2I improve memory?

Absolutely! R2I doesn't just improve memory, it shines. It excels in the Memory Length and Discounting Chain tasks, preserving its ability to learn across a significantly wider range of environment difficulties than the baselines.

Figure: BSuite results

But it does not end here: we also evaluate R2I more comprehensively and under more challenging conditions using POPGym, a collection of RL environments designed to assess various challenges of POMDPs, such as navigation, noise robustness, and memory. We select its three most memory-intensive environments: RepeatPrevious, Autoencode, and Concentration. Among POPGym tasks, these require the optimal policy to memorize the largest number of events at each time step. Each POPGym environment has three difficulty levels: Easy, Medium, and Hard. In the memory environments of this study, difficulty grows with the number of actions or observations that the agent must keep track of simultaneously. As illustrated in the following figure, R2I sets a new state of the art on these memory-intensive tasks. Together, the POPGym and BSuite results indicate that R2I significantly pushes the limits of memory in RL.

Figure: POPGym results
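To see why difficulty in these environments grows with the number of items to track, consider a toy task in the spirit of RepeatPrevious, assuming the agent must output the observation it saw `k` steps earlier. This is a simplified illustration, not POPGym's implementation: the alphabet size, `k`, and the helper names are hypothetical. An optimal policy must hold the last `k` observations at once, so the memory it needs grows with `k`.

```python
import random
from collections import deque


def make_observations(length: int, num_symbols: int, seed: int):
    """Random observation sequence (hypothetical stand-in for an episode)."""
    rng = random.Random(seed)
    return [rng.randrange(num_symbols) for _ in range(length)]


class BufferAgent:
    """Remembers the most recent `capacity + 1` observations and repeats the oldest."""

    def __init__(self, capacity: int):
        self.buffer = deque(maxlen=capacity + 1)

    def act(self, obs: int):
        self.buffer.append(obs)
        # Once full, buffer[0] is the observation from `capacity` steps ago.
        return self.buffer[0] if len(self.buffer) == self.buffer.maxlen else None


def accuracy(observations, k: int, agent: BufferAgent) -> float:
    """Fraction of steps t >= k where the agent outputs observations[t - k]."""
    correct = 0
    for t, obs in enumerate(observations):
        action = agent.act(obs)
        if t >= k:
            correct += int(action == observations[t - k])
    return correct / (len(observations) - k)


obs = make_observations(length=500, num_symbols=4, seed=0)
print(accuracy(obs, k=8, agent=BufferAgent(capacity=8)))  # 1.0
print(accuracy(obs, k=8, agent=BufferAgent(capacity=7)))  # well below 1.0
```

An agent whose memory falls even one slot short drops to roughly chance level, which is why raising `k` makes such tasks sharply harder for agents with limited memory.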

Long-term memory of R2I in complex 3D tasks

How does R2I perform in complex environments where success depends on very long-term memory as well as on reasoning, exploration, and other skills? Can R2I's enhanced memory mechanism actually bolster overall performance in these scenarios? Yes, it can. R2I's memory mechanism isn't just keeping up; it's setting the pace, surpassing human performance in most of these tasks!

To investigate this question, we conduct experiments in the Memory Maze domain, which presents randomized 3D mazes in which an egocentric agent is repeatedly tasked with navigating to one of multiple objects. To act quickly and efficiently, the agent must retain the locations of the objects, the layout of the maze walls, and its own position. Each episode lasts up to 4K environment steps. An ideal agent with long-term memory only needs to explore each maze once, which takes less time than an episode lasts; afterwards, it can efficiently follow the shortest path to each requested target. We trained and tested R2I and leading baselines on the four available maze sizes: 9x9, 11x11, 13x13, and 15x15. R2I consistently sets a new state of the art, matching or outperforming the leading baselines in all of these environments. Moreover, it surpasses human-level performance on the 9x9, 11x11, and 13x13 mazes, showcasing its strong memory capabilities in these complex tasks.

Figure: Memory Maze results
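The ideal-agent strategy described above, explore once and then exploit the remembered layout, can be illustrated with breadth-first search over a remembered grid map. This only illustrates what perfect memory makes possible; it is not how R2I works, since R2I learns its behavior with a world model and does not use an explicit map or planner like this one. The maze and the function name below are made up for the example.

```python
from collections import deque


def shortest_path_length(grid, start, goal):
    """Breadth-first search on a grid map ('#' marks walls).

    Returns the minimum number of up/down/left/right moves from `start` to
    `goal`, or None if the goal is unreachable.
    """
    rows, cols = len(grid), len(grid[0])
    dist = {start: 0}
    frontier = deque([start])
    while frontier:
        r, c = frontier.popleft()
        if (r, c) == goal:
            return dist[(r, c)]
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nxt = (r + dr, c + dc)
            if (0 <= nxt[0] < rows and 0 <= nxt[1] < cols
                    and grid[nxt[0]][nxt[1]] != "#" and nxt not in dist):
                dist[nxt] = dist[(r, c)] + 1
                frontier.append(nxt)
    return None


# Hypothetical remembered layout: A is the agent, B a requested target object.
maze = [
    "#######",
    "#A..#.#",
    "#.#.#.#",
    "#.#...#",
    "#...#B#",
    "#######",
]
print(shortest_path_length(maze, (1, 1), (4, 5)))  # 7
```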

Generality of R2I in non-memory domains

To confirm that R2I retains the generality of its predecessor, DreamerV3, we evaluate it on two widely used RL benchmarks: Atari and DMC. Although neither domain requires memory to be solved, evaluating R2I on them is essential because we want our agent to perform well across a wide range of tasks that require different types of control. As illustrated in the figure below, R2I performs on par with DreamerV3 in these domains. This suggests that, for the majority of standard RL tasks, R2I does not sacrifice generality for its improved memory capabilities.

Figure: Atari and DMC results

Computational efficiency of R2I

To make world model learning scalable and parallelizable, we compute the SSMs with a parallel scan. Parallel scan lets us scale the sequence length of a batch across distributed devices, which the convolution mode does not support. During world model training, the SSMs process sequences in parallel, unlike their RNN counterparts. During actor and critic training, the model switches to recurrent mode, so it does not need to attend to the entire context as Transformers do. As a result, R2I is highly computationally efficient, running up to 9 times faster than DreamerV3.
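Why a linear recurrence admits a parallel scan can be shown with a small sketch, assuming a scalar state per channel and a zero initial state. Each step x_k = a_k * x_{k-1} + b_k is an affine map, and composing affine maps is associative, so all prefix compositions can be computed in about log2(L) rounds whose updates are independent of one another. The pure-Python Hillis-Steele scan below is illustrative and does not reproduce R2I's implementation; in JAX, `jax.lax.associative_scan` applies a user-supplied associative function over a sequence in this fashion.

```python
def combine(left, right):
    """Composes two affine maps x -> a*x + b, applying `left` first.

    (a1, b1) then (a2, b2) gives x -> a2*(a1*x + b1) + b2 = (a1*a2)*x + (a2*b1 + b2).
    This operator is associative, which is what makes a parallel scan valid.
    """
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def parallel_scan(elems):
    """Hillis-Steele inclusive scan with `combine`: about log2(L) rounds."""
    elems = list(elems)
    step = 1
    while step < len(elems):
        # Every position reads only values from the previous round, so all
        # updates in a round are independent and could run simultaneously.
        elems = [
            elems[i] if i < step else combine(elems[i - step], elems[i])
            for i in range(len(elems))
        ]
        step *= 2
    return elems


def sequential(a, b):
    """Reference recurrence x_k = a_k * x_{k-1} + b_k starting from a zero state."""
    x, xs = 0.0, []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        xs.append(x)
    return xs


# For a diagonal SSM channel, a_k would be the discretized state coefficient and
# b_k the discretized input term at step k (hypothetical values here).
a = [0.9, 0.5, 1.1, 0.7, 0.95, 0.3, 0.8]
b = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5, -1.0]
states = [b_k for _, b_k in parallel_scan(zip(a, b))]  # x_k = A_k * 0 + B_k
print(states)
print(sequential(a, b))  # same values up to floating-point rounding
```

Hillis-Steele performs O(L log L) total work; work-efficient variants such as Blelloch's scan reduce this to O(L) while keeping a logarithmic number of rounds.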


Conclusion

We introduced R2I, the first model-based approach to RL that uses SSMs. R2I is a general and fast method that demonstrates superior memory capabilities, even surpassing human performance in the complex Memory Maze domain.

To further support the field, we have released and documented the code so that researchers can readily use and adapt R2I. Its ease of use and proven effectiveness make R2I an ideal candidate for benchmarking and for advancing the state of the art in RL research.

BibTeX

@inproceedings{
    samsami2024mastering,
    title={Mastering Memory Tasks with World Models},
    author={Mohammad Reza Samsami and Artem Zholus and Janarthanan Rajendran and Sarath Chandar},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
    url={https://openreview.net/forum?id=1vDArHJ68h}
}