API reference for the main Conformal Q-Learning implementation
The SACAgent class is a high-level wrapper around the Conformal SAC agent and provides an easy-to-use interface for training and evaluation. It implements Conformal Q-Learning, which combines Soft Actor-Critic (SAC) with Conformal Prediction for offline reinforcement learning.
class SACAgent:
    def __init__(self, env_name: str, offline: bool = True, iteration: int = 100000, seed: int = 1, **config):
        # Initialize the Conformal Q-Learning agent
        ...

Constructor parameters:

- env_name (string): Name of the Gym (or D4RL) environment.
- offline (boolean, default: True): If True, use an offline dataset from D4RL.
- iteration (integer, default: 100000): Number of training iterations.
- seed (integer, default: 1): Random seed for reproducibility.
- **config (dict): Additional hyperparameters forwarded to the underlying SAC agent.
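A minimal construction, relying on the defaults above, is sketched below; the dataset name simply reuses the one from the full example at the end of this page.

from conformal_sac.agent_wrapper import SACAgent

# offline=True, iteration=100000 and seed=1 are the defaults;
# any extra keyword arguments are forwarded to the underlying SAC agent via **config.
agent = SACAgent(env_name="halfcheetah-medium-expert")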
Runs the main training loop for the Conformal Q-Learning algorithm.
def train(self):
    # Main training loop
    ...

Evaluates the current policy on the environment.
def evaluate(self, eval_episodes: int = 5) -> float:
    # Evaluate the policy
    ...

The SAC class implements the core Soft Actor-Critic algorithm with Conformal Prediction integration.
class SAC:
    def __init__(self, state_dim, action_dim, config):
        # Initialize the SAC agent
        ...

Key methods:

- select_action(state): Selects an action given a state.
- store(s, a, r, s_, d): Stores a transition in the replay buffer.
- update(): Performs one update step using sampled transitions.
- compute_conformal_interval(alpha=0.1): Computes the conformal interval for uncertainty estimation.
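The sketch below illustrates how these core methods fit together in a plain interaction loop. It is illustrative only: the import path conformal_sac.sac, the environment, the classic Gym step API, and the value returned by compute_conformal_interval are assumptions, and the config keys simply mirror the wrapper example further down.

import gym
from conformal_sac.sac import SAC  # assumed module path

env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

config = {"learning_rate": 3e-4, "gamma": 0.99, "tau": 0.005, "batch_size": 256}
agent = SAC(state_dim, action_dim, config)

s = env.reset()
for step in range(10_000):
    a = agent.select_action(s)        # query the current policy
    s_, r, done, _ = env.step(a)      # classic Gym 4-tuple step API assumed
    agent.store(s, a, r, s_, done)    # add the transition to the replay buffer
    if step >= config["batch_size"]:
        agent.update()                # one gradient update once enough transitions are stored
    s = env.reset() if done else s_

# alpha is the miscoverage level, so alpha=0.1 targets roughly 90% coverage;
# the exact return value of compute_conformal_interval depends on the implementation.
interval = agent.compute_conformal_interval(alpha=0.1)

In practice, the SACAgent wrapper drives this loop (and the offline D4RL data handling) for you, as in the complete example below.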
from conformal_sac.agent_wrapper import SACAgent
agent = SACAgent(
env_name="halfcheetah-medium-expert",
offline=True,
iteration=100000,
seed=42,
learning_rate=3e-4,
gamma=0.99,
tau=0.005,
batch_size=256,
log_interval=2000,
alpha_q=100,
q_alpha_update_freq=50
)
agent.train()
score = agent.evaluate(eval_episodes=5)
print(f"Final evaluation score: {score}")
For more details on configuration options, see the Configuration page.