Introducing TorchOpt: A High-Performance Differentiable Optimization Library for PyTorch
Explore TorchOpt, a PyTorch-based library that advances differentiable optimization with its unified programming abstraction, high-performance distributed execution runtime, and support for various differentiation modes.
This post is authored by Bo Liu, a Ph.D. student in the Department of Computer Science at the National University of Singapore and a member of the MetaOPT team, which also includes Jie Ren, Xidong Feng, Xuehai Pan, Luo Mai, and Yaodong Yang.
Introducing TorchOpt
Machine learning (ML) has been transformed by differentiable programming, which enables the automatic computation of derivatives within a high-level language. Its widespread use, from backpropagation in neural networks to Bayesian inference and probabilistic programming, has significantly powered the progress of ML. It has enabled efficient and composable automatic differentiation (AD) tools, paving the way for advances in differentiable optimization [1, 2], simulators [3, 4], engineering [5], and science [6]. The growing number of differentiable optimization algorithms underscores the essential role of differentiable programming.
Enter TorchOpt — an efficient library for differentiable optimization that builds upon PyTorch. TorchOpt is available on GitHub at https://github.com/metaopt/torchopt.
TorchOpt offers:
- Versatility: TorchOpt encompasses three differentiation modes — explicit differentiation, implicit differentiation, and zero-order differentiation, catering to various differentiable optimization needs.
- Flexibility: TorchOpt provides both functional and object-oriented (OO) APIs to suit different user preferences. You can implement differentiable optimization in a style akin to JAX or PyTorch.
- Efficiency: TorchOpt offers CPU/GPU-accelerated differentiable optimizers, an RPC-based distributed training framework, and fast tree operations, dramatically enhancing training efficiency for bi-level optimization problems.
Why TorchOpt?
TorchOpt combines two pivotal components: a unified and expressive programming abstraction for differentiable optimization, and a high-performance distributed execution runtime.
TorchOpt presents an abstraction that promotes the efficient definition and analysis of differentiable optimization programs, accommodating explicit, implicit, and zero-order gradients.
TorchOpt offers a diverse set of low-level, high-level, functional, and Object-Oriented (OO) APIs to enable users to incorporate differentiable optimization within the computational graphs produced by PyTorch. Specifically, TorchOpt supports three differentiation modes for handling differentiable optimization problems:
(i) Explicit gradient for unrolled optimization,
(ii) Implicit gradient for solution-based iterative optimization,
(iii) Zero-order gradient estimation for non-smooth/non-differentiable functions.
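To give a flavor of the functional API mentioned above, the sketch below performs a single Adam step in the Optax-style init/update/apply pattern. This is a minimal illustrative sketch: the init and update methods and torchopt.apply_updates belong to TorchOpt's functional API, but the exact keyword arguments (for example, in-place behavior) should be checked against the TorchOpt documentation.
# Minimal sketch: one functional Adam step on a single scalar parameter.
import torch
import torchopt

param = torch.tensor(3.0, requires_grad=True)    # the parameter to optimize
optimizer = torchopt.adam(lr=0.1)                # functional optimizer (Optax-style)
opt_state = optimizer.init(param)                # initialize the optimizer state

loss = (param - 1.0) ** 2                        # a simple quadratic loss
grad = torch.autograd.grad(loss, param)[0]       # compute the gradient explicitly

updates, opt_state = optimizer.update(grad, opt_state)  # turn the gradient into an update
param = torchopt.apply_updates(param, updates)          # apply the update to the parameter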
TorchOpt provides a high-performance, distributed execution runtime containing several accelerated solutions that support fast differentiation in the different modes on CPU and GPU, as well as distributed training features for multi-node, multi-GPU settings. The figures below compare TorchOpt with other baselines on the CPU/GPU-accelerated operators and distributed training.
For PyTorch researchers and developers, TorchOpt’s features enable efficient declaration and analysis of various differentiable optimization programs, complete parallelization of computation-intensive differentiation operations, and automatic distribution of computation to distributed devices.
Usage Examples
Let’s delve into two specific usage examples of TorchOpt. We’ll guide you through each step, providing visuals or code examples for better comprehension.
A warm-up example for differentiable optimizers
Let us start with a warm-up example: a one-parameter network f(x) = a * x^2, where a is the network parameter and x is a meta-parameter. One inner SGD step with learning rate lr on the inner loss a * x^2 updates the parameter to a' = a - lr * x^2. Evaluating the network again gives the outer loss a' * x^2 = (a_0 - lr * x^2) * x^2, whose gradient with respect to x is - 4 * lr * x^3 + 2 * a_0 * x.
Given this analytical solution, let's validate it using MetaOptimizer in TorchOpt. MetaOptimizer is the main class of our differentiable optimizers. It combines with the functional optimizers torchopt.sgd and torchopt.adam to define our high-level APIs torchopt.MetaSGD and torchopt.MetaAdam.
First, define the network.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def forward(self, x):
        return self.a * (x**2)
Then we declare the network (parameterized by a) and the meta-parameter x. Do not forget to set the flag requires_grad=True for x.
net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)
Next, we declare the meta-optimizer. Here we show two equivalent ways of defining it.
# Low-level API
optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))
# High-level API
optim = torchopt.MetaSGD(net, lr=1.0)
The meta-optimizer takes the network as input and uses the step method to update the network (parameterized by a). Finally, we show how a bi-level process works.
inner_loss = net(x)
optim.step(inner_loss)
outer_loss = net(x)
outer_loss.backward()
# x.grad = - 4 * lr * x^3 + 2 * a_0 * x
# = - 4 * 1 * 2^3 + 2 * 1 * 2
# = -32 + 4
# = -28
print(f'x.grad = {x.grad!r}')
The output is:
x.grad = tensor(-28.)
Implementing Model-Agnostic Meta-Learning (MAML) Using TorchOpt
Let us start with the core idea of the Model-Agnostic Meta-Learning (MAML) algorithm. MAML is an algorithm for meta-learning that is model-agnostic, in the sense that it is compatible with any model trained with gradient descent and applicable to a variety of different learning problems, including classification, regression, and reinforcement learning. The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples.
In the MAML approach, a model is trained on a variety of tasks, and then fine-tuned with one or a few gradient steps computed on a small amount of data from a new task. The key insight of MAML is to train the initial model such that these few steps of fine-tuning lead to good generalization performance on the new task.
The update rule in MAML is defined as theta_i' = theta - alpha * grad(L_i(theta)). Given a learning rate alpha for the fine-tuning step, theta should minimize the meta-objective sum_i(L_i(theta_i')) = sum_i(L_i(theta - alpha * grad(L_i(theta)))), where the sum runs over tasks sampled from the task distribution p(T).
Optimizing this objective is the goal of the meta-training procedure. Given a task i from the distribution of tasks p(T), the model parameters theta are updated using one or more gradient descent steps on the loss L_i of task i, resulting in the task-specific parameters theta_i'. The update rule is written as theta_i' = theta - alpha * grad(L_i(theta)), where alpha is the learning rate and grad denotes the gradient.
After this update for each task in the batch, the model parameters theta are updated using gradient descent on the sum of the losses L_i of all tasks i in the batch, with each loss L_i computed using the task-specific parameters theta_i'. This update rule is written as theta = theta - beta * grad(sum_i(L_i(theta_i'))), where beta is the learning rate and grad denotes the gradient.
Here, alpha and beta are hyperparameters that determine the step sizes of the gradient descent updates. The inner learning rate alpha controls how quickly the model adapts to each individual task, while the outer learning rate beta controls how the shared initialization is updated across the distribution of tasks.
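To make these two update rules concrete, here is a compact sketch of the inner/outer loop on a hypothetical toy regression problem. The model, data, and task sampler below are made up for illustration only; the reinforcement learning example that follows uses the same structure with real components.
# MAML inner/outer loop sketch on a hypothetical toy regression task.
import torch
import torch.nn as nn
import torchopt

model = nn.Linear(1, 1)                                    # theta
inner_opt = torchopt.MetaSGD(model, lr=0.1)                # inner step: theta_i' = theta - alpha * grad(L_i(theta))
outer_opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # outer step with learning rate beta

def sample_task():
    # Hypothetical task: fit y = w * x for a task-specific slope w.
    w = torch.randn(())
    def sample_batch():
        x = torch.randn(8, 1)
        return x, w * x
    return sample_batch

for _ in range(10):                                        # meta-training iterations
    outer_opt.zero_grad()
    model_state = torchopt.extract_state_dict(model)       # snapshot theta
    optim_state = torchopt.extract_state_dict(inner_opt)
    for _ in range(4):                                     # a batch of tasks
        sample_batch = sample_task()
        x_s, y_s = sample_batch()                          # support data for the inner step
        inner_opt.step(((model(x_s) - y_s) ** 2).mean())   # differentiable update -> theta_i'
        x_q, y_q = sample_batch()                          # query data from the same task
        ((model(x_q) - y_q) ** 2).mean().backward()        # accumulates grad of L_i(theta_i') w.r.t. theta
        torchopt.recover_state_dict(model, model_state)    # roll back to theta for the next task
        torchopt.recover_state_dict(inner_opt, optim_state)
    outer_opt.step()                                       # theta <- theta - beta-scaled update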
Now, let's walk through a code example that implements the MAML algorithm for reinforcement learning with TorchOpt.
We start by defining some parameters related to the tasks, trajectories, states, actions, and iterations.
import argparse
from typing import NamedTuple
import gym
import numpy as np
import torch
import torch.optim as optim
import torchopt
from helpers.policy import CategoricalMLPPolicy
TASK_NUM = 40
TRAJ_NUM = 20
TRAJ_LEN = 10
STATE_DIM = 10
ACTION_DIM = 5
GAMMA = 0.99
LAMBDA = 0.95
outer_iters = 500
inner_iters = 1
Next, we define a class named Traj to represent a trajectory, which includes the observed states, the actions taken, the states observed after taking the actions, the rewards obtained, and the gamma values for discounting future rewards.
class Traj(NamedTuple):
    obs: np.ndarray
    acs: np.ndarray
    next_obs: np.ndarray
    rews: np.ndarray
    gammas: np.ndarray
We then define a function sample_traj to generate trajectories given the environment, a task, and the policy. This function simulates the interaction between the policy and the environment for TRAJ_LEN steps in each of TRAJ_NUM episodes.
def sample_traj(env, task, policy):
    env.reset_task(task)
    obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
    acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
    rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
    with torch.no_grad():
        for batch in range(TRAJ_NUM):
            ob = env.reset()
            for step in range(TRAJ_LEN):
                ob_tensor = torch.from_numpy(ob)
                pi, _ = policy(ob_tensor)
                ac_tensor = pi.sample()
                ac = ac_tensor.cpu().numpy()
                next_ob, rew, done, info = env.step(ac)
                obs_buf[step][batch] = ob
                next_obs_buf[step][batch] = next_ob
                acs_buf[step][batch] = ac
                rews_buf[step][batch] = rew
                gammas_buf[step][batch] = (1 - done) * GAMMA
                ob = next_ob
    return Traj(
        obs=obs_buf,
        acs=acs_buf,
        next_obs=next_obs_buf,
        rews=rews_buf,
        gammas=gammas_buf,
    )
The a2c_loss function computes the loss for the Advantage Actor-Critic (A2C) algorithm. A2C is a policy gradient method that uses a value function (the critic) to reduce the variance of the policy gradient (the actor).
def a2c_loss(traj, policy, value_coef):
    lambdas = np.ones_like(traj.gammas) * LAMBDA
    _, next_values = policy(torch.from_numpy(traj.next_obs))
    next_values = torch.squeeze(next_values, -1).detach().numpy()
    # Work backwards to compute `G_{T-1}`, ..., `G_0`.
    returns = []
    g = next_values[-1, :]
    for i in reversed(range(next_values.shape[0])):
        g = traj.rews[i, :] + traj.gammas[i, :] * (
            (1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g
        )
        returns.insert(0, g)
    lambda_returns = torch.from_numpy(np.array(returns))
    pi, values = policy(torch.from_numpy(traj.obs))
    log_probs = pi.log_prob(torch.from_numpy(traj.acs))
    advs = lambda_returns - torch.squeeze(values, -1)
    action_loss = -(advs.detach() * log_probs).mean()
    value_loss = advs.pow(2).mean()
    loss = action_loss + value_coef * value_loss
    return loss
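To see what the backward recursion in a2c_loss computes, here is a tiny standalone check of the lambda-return formula G_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}) on made-up 1-D arrays; the numbers are purely illustrative.
# Standalone check of the lambda-return recursion with made-up numbers.
import numpy as np

rews = np.array([1.0, 0.0, 1.0], dtype=np.float32)         # r_0, r_1, r_2
next_values = np.array([0.5, 0.4, 0.3], dtype=np.float32)  # V(s_1), V(s_2), V(s_3)
gammas = np.full(3, 0.99, dtype=np.float32)                # (1 - done) * GAMMA per step
lam = 0.95

g = next_values[-1]                                        # bootstrap from the last value
returns = []
for t in reversed(range(3)):
    g = rews[t] + gammas[t] * ((1 - lam) * next_values[t] + lam * g)
    returns.insert(0, g)

print(returns)  # lambda-returns G_0, G_1, G_2, the regression targets for the critic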
The evaluate function is used to evaluate the performance of the policy on different tasks. It uses the inner optimizer to fine-tune the policy on each task and then computes the rewards before and after the fine-tuning.
def evaluate(env, seed, task_num, policy):
    pre_reward_ls = []
    post_reward_ls = []
    inner_opt = torchopt.MetaSGD(policy, lr=0.1)
    env = gym.make(
        'TabularMDP-v0',
        num_states=STATE_DIM,
        num_actions=ACTION_DIM,
        max_episode_steps=TRAJ_LEN,
        seed=seed,
    )
    tasks = env.sample_tasks(num_tasks=task_num)
    policy_state_dict = torchopt.extract_state_dict(policy)
    optim_state_dict = torchopt.extract_state_dict(inner_opt)
    for idx in range(task_num):
        for _ in range(inner_iters):
            pre_trajs = sample_traj(env, tasks[idx], policy)
            inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
            inner_opt.step(inner_loss)
        post_trajs = sample_traj(env, tasks[idx], policy)
        # Logging
        pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
        post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
        torchopt.recover_state_dict(policy, policy_state_dict)
        torchopt.recover_state_dict(inner_opt, optim_state_dict)
    return pre_reward_ls, post_reward_ls
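The extract_state_dict and recover_state_dict calls implement a snapshot-and-rollback pattern: the inner step rewrites the policy parameters differentiably, so we save the policy and optimizer states first and restore them before moving on to the next task. The small illustration below reuses the Net class and imports from the warm-up example; it is a sketch of the pattern rather than part of the original script.
# Snapshot/rollback with TorchOpt state dicts, shown on the warm-up Net.
net = Net()
x = torch.tensor(2.0)
inner_opt = torchopt.MetaSGD(net, lr=1.0)

net_state = torchopt.extract_state_dict(net)        # snapshot of the current parameters
opt_state = torchopt.extract_state_dict(inner_opt)  # snapshot of the optimizer state

inner_opt.step(net(x))                              # differentiable update: a <- a - lr * x^2
print(net.a)                                        # updated value, tensor(-3., ...)

torchopt.recover_state_dict(net, net_state)         # restore the parameters
torchopt.recover_state_dict(inner_opt, opt_state)   # restore the optimizer
print(net.a)                                        # back to the original value 1.0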
In the main function, we initialize the environment, policy, and optimizers. The policy is a simple MLP that outputs a categorical distribution over the actions. The inner optimizer is used to update the policy parameters during the fine-tuning phase, and the outer optimizer is used to update the policy parameters during the meta-training phase. The performance is evaluated by the rewards before and after the fine-tuning. The training process is logged and printed for each outer iteration.
def main(args):
    # init training
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Env
    env = gym.make(
        'TabularMDP-v0',
        num_states=STATE_DIM,
        num_actions=ACTION_DIM,
        max_episode_steps=TRAJ_LEN,
        seed=args.seed,
    )
    # Policy
    policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
    inner_opt = torchopt.MetaSGD(policy, lr=0.1)
    outer_opt = optim.Adam(policy.parameters(), lr=1e-3)
    train_pre_reward = []
    train_post_reward = []
    test_pre_reward = []
    test_post_reward = []

    for i in range(outer_iters):
        tasks = env.sample_tasks(num_tasks=TASK_NUM)
        train_pre_reward_ls = []
        train_post_reward_ls = []
        outer_opt.zero_grad()
        policy_state_dict = torchopt.extract_state_dict(policy)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss)
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)
            # Logging
            train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
            train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
        outer_opt.step()

        test_pre_reward_ls, test_post_reward_ls = evaluate(env, args.seed, TASK_NUM, policy)

        train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
        train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
        test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
        test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)

        print('Train_iters', i)
        print('train_pre_reward', sum(train_pre_reward_ls) / TASK_NUM)
        print('train_post_reward', sum(train_post_reward_ls) / TASK_NUM)
        print('test_pre_reward', sum(test_pre_reward_ls) / TASK_NUM)
        print('test_post_reward', sum(test_post_reward_ls) / TASK_NUM)
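The main function expects an args object with a seed attribute (argparse is imported at the top of the script). A minimal entry point could look like the sketch below; the original example script may define additional command-line arguments.
# Hypothetical minimal entry point for running the example.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MAML-RL with TorchOpt')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    args = parser.parse_args()
    main(args)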
In summary, this code example shows how to implement the MAML algorithm for reinforcement learning tasks using TorchOpt. The MAML algorithm is implemented in a flexible way that is compatible with any model trained with gradient descent, making it a powerful tool for meta-learning tasks.
Forward-Looking Statement
TorchOpt is a novel and efficient differentiable optimization library for PyTorch. Our experimental results highlight TorchOpt’s potential as a user-friendly, high-performance, and scalable library for supporting challenging gradient computation with PyTorch. We plan to support more complex differentiation modes and cover more non-trivial gradient computation problems in the future. TorchOpt has already proved useful for meta-gradient research, and we are confident that it can serve as a critical auto-differentiation tool for an even broader range of differentiable optimization problems.
We’re enthusiastic about TorchOpt’s potential and are dedicated to its ongoing development and refinement. We welcome community feedback and contributions to help us make TorchOpt even better. Stay tuned for more updates and features in the coming months!
Acknowledgements
- The JAXopt [7] library, with its well-designed APIs for implicit gradient differentiation, has greatly inspired us. Its approach to hardware-accelerated, batchable, and differentiable optimization solutions has offered us significant insights into managing optimization problems effectively.
- Optax [8], with its focus on functional programming and gradient processing, has been a fundamental basis for our work. The manner in which it combines low-level ingredients into custom optimizers has inspired us in designing our own functional APIs, greatly enhancing the efficiency of our project.
- Betty [9], an automatic differentiation library for generalized meta-learning and multilevel optimization, has also been a valuable reference for us. While not directly integrated into our project, its features have offered us useful insights and contributed to the conceptualization and design of features in our own TorchOpt library.
References
[1] Liu, B., Feng, X., Ren, J., Mai, L., Zhu, R., Zhang, H., … & Yang, Y. (2022). A theoretical understanding of gradient bias in meta-reinforcement learning. Advances in Neural Information Processing Systems, 35, 31059–31072.
[2] Finn, C., Abbeel, P., & Levine, S. (2017, July). Model-agnostic meta-learning for fast adaptation of deep networks. In International conference on machine learning (pp. 1126–1135). PMLR.
[3] Hu, Y., Anderson, L., Li, T. M., Sun, Q., Carr, N., Ragan-Kelley, J., & Durand, F. (2019). Difftaichi: Differentiable programming for physical simulation. arXiv preprint arXiv:1910.00935.
[4] Freeman, C. D., Frey, E., Raichuk, A., Girgin, S., Mordatch, I., & Bachem, O. (2021). Brax — A Differentiable Physics Engine for Large Scale Rigid Body Simulation. arXiv preprint arXiv:2106.13281.
[5] Schoenholz, S., & Cubuk, E. D. (2020). Jax md: a framework for differentiable physics. Advances in Neural Information Processing Systems, 33, 11428–11441.
[6] Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2019). Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations. Journal of Computational physics, 378, 686–707.
[7] Blondel, M., Berthet, Q., Cuturi, M., Frostig, R., Hoyer, S., Llinares-López, F., … & Vert, J. P. (2022). Efficient and modular implicit differentiation. Advances in neural information processing systems, 35, 5230–5242.
[8] Babuschkin, I., Baumli, K., Bell, A., Bhupatiraju, S., Bruce, J., Buchlovsky, P., Budden, D., Cai, T., Clark, A., Danihelka, I., Dedieu, A., Fantacci, C., Godwin, J., Jones, C., Hemsley, R., Hennigan, T., Hessel, M., Hou, S., Kapturowski, S., … Viola, F. (2020). The DeepMind JAX Ecosystem. http://github.com/deepmind.
[9] Choe, S. K., Neiswanger, W., Xie, P., & Xing, E. (2022). Betty: An automatic differentiation library for multilevel optimization. arXiv preprint arXiv:2207.02849.