Conformal Q-Learning

API reference for the main Conformal Q-Learning implementation

The SACAgent class implements Conformal Q-Learning, combining Soft Actor-Critic (SAC) with Conformal Prediction for offline reinforcement learning.

SACAgent Class

The SACAgent class is a high-level wrapper for the Conformal SAC agent, providing an easy-to-use interface for training and evaluation.

Constructor

class SACAgent:
    def __init__(self, env_name: str, offline: bool = True, iteration: int = 100000, seed: int = 1, **config):
        # Initialize the Conformal Q-Learning agent
        ...

Parameters

  • env_name (string)

    Name of the Gym (or D4RL) environment.

  • offline (boolean, default: True)

    If True, use an offline dataset from D4RL.

  • iteration (integer, default: 100000)

    Number of training iterations.

  • seed (integer, default: 1)

    Random seed for reproducibility.

  • **config (dict)

    Additional hyperparameters for the underlying SAC agent.

Methods

train()

Runs the main training loop for the Conformal Q-Learning algorithm.

def train(self):
    # Main training loop
    ...
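
In offline mode, train() typically amounts to repeating agent updates for iteration steps and evaluating every log_interval steps. The schematic sketch below illustrates that loop; the attribute names (self.agent, self.iteration, self.log_interval) are assumptions for illustration, not the documented API.

# Schematic sketch only; attribute names are illustrative assumptions.
def train(self):
    for it in range(self.iteration):
        self.agent.update()                      # one SAC + conformal update on a sampled batch
        if it % self.log_interval == 0:
            score = self.evaluate(eval_episodes=5)
            print(f"iter {it}: eval return = {score:.2f}")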

evaluate(eval_episodes: int = 5) → float

Evaluates the current policy on the environment.

def evaluate(self, eval_episodes: int = 5) -> float:
    # Evaluate the policy
    ...
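
An evaluation method of this shape typically rolls out the current policy for eval_episodes episodes and returns the mean episodic return. The sketch below is illustrative only; it assumes the classic Gym reset/step API and attribute names (self.env, self.agent) that are not part of the documented interface.

# Schematic sketch only; attribute names and the Gym API version are assumptions.
def evaluate(self, eval_episodes: int = 5) -> float:
    total_return = 0.0
    for _ in range(eval_episodes):
        state, done = self.env.reset(), False
        while not done:
            action = self.agent.select_action(state)       # act with the current policy
            state, reward, done, _ = self.env.step(action)
            total_return += reward
    return total_return / eval_episodes                     # mean return over episodes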

SAC Class

The SAC class implements the core Soft Actor-Critic algorithm with Conformal Prediction integration.

Constructor

class SAC:
    def __init__(self, state_dim, action_dim, config):
        # Initialize the SAC agent
        ...
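
Here state_dim and action_dim are the sizes of the environment's observation and action spaces, and config carries the hyperparameters forwarded by SACAgent. A minimal construction sketch follows; the import path, the versioned D4RL environment name, and the config keys shown are assumptions for illustration.

import gym
import d4rl  # noqa: F401 -- registers the D4RL offline environments with gym
from conformal_sac.SAC import SAC  # import path is an assumption about the package layout

env = gym.make("halfcheetah-medium-expert-v2")  # versioned D4RL name assumed for direct gym.make
config = {"learning_rate": 3e-4, "gamma": 0.99, "tau": 0.005, "batch_size": 256}  # illustrative subset

sac = SAC(
    state_dim=env.observation_space.shape[0],   # 17 for HalfCheetah
    action_dim=env.action_space.shape[0],       # 6 for HalfCheetah
    config=config,
)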

Key Methods

  • select_action(state)

    Selects an action given a state.

  • store(s, a, r, s_, d)

    Stores a transition (state, action, reward, next state, done flag) in the replay buffer.

  • update()

    Performs one gradient update step on a batch of transitions sampled from the replay buffer.

  • compute_conformal_interval(alpha=0.1)

    Computes a conformal prediction interval at miscoverage level alpha, used for uncertainty estimation (see the sketch after this list).
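
Continuing with the env and sac instances from the constructor sketch above, the snippet below illustrates how these methods fit together in the offline setting: the D4RL dataset is loaded into the replay buffer with store(), update() performs the gradient steps, compute_conformal_interval() is queried periodically, and select_action() is used at deployment time. The loop structure and the update frequency are illustrative, not the package's exact training schedule.

import numpy as np

data = d4rl.qlearning_dataset(env)   # observations, actions, rewards, next_observations, terminals

# Load the static offline dataset into the replay buffer.
for s, a, r, s_, d in zip(data["observations"], data["actions"], data["rewards"],
                          data["next_observations"], data["terminals"]):
    sac.store(s, a, r, s_, d)

for step in range(100000):
    sac.update()                                              # one gradient step on a sampled batch
    if step % 50 == 0:                                        # cf. q_alpha_update_freq in the example below
        interval = sac.compute_conformal_interval(alpha=0.1)  # interval at 10% miscoverage level

action = sac.select_action(env.reset())                       # act with the learned policy

For reference, split conformal prediction obtains an interval at miscoverage level alpha from the ceil((n+1)(1-alpha))-th smallest calibration nonconformity score. A generic sketch of that quantile computation is shown below; the choice of nonconformity score used inside the agent is an implementation detail not shown here.

def conformal_quantile(scores: np.ndarray, alpha: float = 0.1) -> float:
    # Finite-sample-corrected empirical quantile of the calibration scores.
    n = len(scores)
    k = int(np.ceil((n + 1) * (1 - alpha))) - 1   # 0-indexed order statistic
    return float(np.sort(scores)[min(k, n - 1)])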

Example Usage

from conformal_sac.agent_wrapper import SACAgent

agent = SACAgent(
    env_name="halfcheetah-medium-expert",
    offline=True,
    iteration=100000,
    seed=42,
    learning_rate=3e-4,
    gamma=0.99,
    tau=0.005,
    batch_size=256,
    log_interval=2000,
    alpha_q=100,
    q_alpha_update_freq=50
)

agent.train()
score = agent.evaluate(eval_episodes=5)
print(f"Final evaluation score: {score}")

For more details on configuration options, see the Configuration page.