API reference for utility functions and classes in Conformal Q-Learning
This page provides details on utility functions and helper classes used in the Conformal Q-Learning implementation.
The ReplayBuffer stores transitions and samples uniform random mini-batches of experience for offline learning.
import random

class ReplayBuffer:
    def __init__(self):
        self.buffer = []

    def store(self, transition):
        # Append a single transition to the buffer
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Uniform random mini-batch, drawn without replacement
        return random.sample(self.buffer, batch_size)
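A minimal usage sketch follows; the (s, a, r, s_, d) transition layout and the dimensions are assumptions for illustration, since the buffer itself is agnostic to the transition format.

import numpy as np

buffer = ReplayBuffer()
# Hypothetical transition layout (s, a, r, s_, d); dimensions are arbitrary examples
for _ in range(1000):
    s, s_ = np.random.randn(3), np.random.randn(3)
    a, r, d = np.random.randn(1), 0.0, False
    buffer.store((s, a, r, s_, d))
batch = buffer.sample(batch_size=256)  # list of 256 transition tuples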
The Actor network parameterizes the Gaussian policy in the SAC algorithm: for a given state it outputs the mean and a clamped log standard deviation of the action distribution.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, min_log_std=-10, max_log_std=2):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu_head = nn.Linear(256, action_dim)
        self.log_std_head = nn.Linear(256, action_dim)
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = self.mu_head(x)
        # Log standard deviation, clamped to [min_log_std, max_log_std]
        log_std = F.relu(self.log_std_head(x))
        log_std = torch.clamp(log_std, self.min_log_std, self.max_log_std)
        return mu, log_std
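Sampling a concrete action from (mu, log_std) happens in the agent (see select_action below), not in the Actor itself. The sketch that follows shows one common SAC-style pattern; the tanh squashing and the helper name sample_action are assumptions, not part of the documented API.

from torch.distributions import Normal

# Hypothetical helper: draw a squashed action from the Actor's Gaussian output.
def sample_action(actor, state):
    state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    mu, log_std = actor(state)
    dist = Normal(mu, log_std.exp())
    z = dist.rsample()           # reparameterized sample
    action = torch.tanh(z)       # squash into [-1, 1] (assumed action bounds)
    return action.squeeze(0).detach().numpy()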
The Critic network estimates the state-value function V(s); it takes only the state as input (the action-value networks such as Q_net1 are separate modules that also receive the action).
class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Scalar state value V(s) per input state
        return self.fc3(x)
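A short usage sketch; the state dimension and batch size below are arbitrary examples.

state_dim = 3                       # arbitrary example dimension
value_net = Critic(state_dim)
states = torch.randn(32, state_dim)
with torch.no_grad():
    values = value_net(states)      # shape (32, 1): one V(s) per state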
compute_conformal_interval estimates the conformal interval width q_alpha used for uncertainty estimation: it samples up to 1000 calibration transitions, computes residuals between the Q-network estimates and an exponential moving average of the bootstrapped targets, and returns the (1 - alpha) empirical quantile of those residuals.
def compute_conformal_interval(self, alpha=0.1):
    if len(self.calibration_set) == 0:
        return 0.0
    with torch.no_grad():
        # Subsample the calibration set to bound the cost of the quantile computation
        max_samples = 1000
        sample_indices = np.random.choice(len(self.calibration_set),
                                          min(len(self.calibration_set), max_samples),
                                          replace=False)
        sampled = [self.calibration_set[i] for i in sample_indices]
        bn_s = torch.tensor(np.array([t.s for t in sampled]), dtype=torch.float32, device=device)
        bn_a = torch.tensor(np.array([t.a for t in sampled]), dtype=torch.float32, device=device)
        bn_r = torch.tensor(np.array([t.r for t in sampled]), dtype=torch.float32, device=device).view(-1, 1)
        bn_s_ = torch.tensor(np.array([t.s_ for t in sampled]), dtype=torch.float32, device=device)
        bn_d = torch.tensor(np.array([t.d for t in sampled]), dtype=torch.float32, device=device).view(-1, 1)
        # Q-estimates on the calibration batch
        q_values = self.Q_net1(bn_s, bn_a).squeeze()
        gamma = self.config.get('gamma', 0.99)
        # Bootstrapped targets; bn_r, bn_d, and the value output all have shape (batch, 1)
        y_values = bn_r + (1 - bn_d) * gamma * self.Target_value_net(bn_s_)
        # Track the mean target with an exponential moving average
        self.q_target_ema = 0.95 * self.q_target_ema + 0.05 * y_values.mean().item()
        # Nonconformity scores: absolute deviation of Q-estimates from the EMA target
        residuals = torch.abs(q_values - self.q_target_ema)
        # The (1 - alpha) empirical quantile of the residuals is the interval width
        q_alpha = torch.quantile(residuals, 1 - alpha).item()
        return q_alpha
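How q_alpha enters the learning update is described on the Conformal Q-Learning page. The sketch below shows one plausible pattern, subtracting the interval width from the bootstrapped target to obtain a pessimistic estimate; it is an illustration, not necessarily the exact rule used by the implementation.

# Hypothetical illustration: use q_alpha as a pessimistic offset on bootstrapped targets.
def conservative_target(reward, done, next_value, q_alpha, gamma=0.99):
    # reward, done, next_value: tensors of shape (batch, 1)
    target = reward + (1.0 - done) * gamma * next_value
    return target - q_alpha   # shrink targets by the conformal interval width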
evaluate runs the current policy for a fixed number of episodes and returns the average undiscounted episode return.
def evaluate(self, eval_episodes: int = 5) -> float:
    total_reward = 0.0
    for _ in range(eval_episodes):
        state = self.env.reset()
        done = False
        ep_reward = 0.0
        while not done:
            # Query the agent's policy for an action and step the environment
            action = self.agent.select_action(state)
            state, reward, done, _ = self.env.step(action.astype(np.float32))
            ep_reward += reward
        total_reward += ep_reward
    avg_reward = total_reward / eval_episodes
    return avg_reward
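A hypothetical call site; the trainer object name is an assumption.

# Hypothetical usage: periodically evaluate and report the average return
avg_return = trainer.evaluate(eval_episodes=10)
print(f"Average return over 10 evaluation episodes: {avg_return:.2f}")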
The implementation uses TensorBoardX for logging training progress and results.
from tensorboardX import SummaryWriter
self.writer = SummaryWriter('./exp-SAC_dual_Q_network')
# Example usage in the update method
self.writer.add_scalar('Loss/V_loss', V_loss.item(), self.num_training)
self.writer.add_scalar('Loss/Q1_loss', Q1_loss.item(), self.num_training)
self.writer.add_scalar('Loss/Q2_loss', Q2_loss.item(), self.num_training)
self.writer.add_scalar('Loss/policy_loss', pi_loss.item(), self.num_training)
self.writer.add_scalar('Uncertainty Update', self.q_alpha, self.num_training)
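The logged scalars can be inspected by pointing TensorBoard at the log directory created above:

tensorboard --logdir ./exp-SAC_dual_Q_network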
For more information on the core algorithm, see the Conformal Q-Learning page.