A coding implementation to train safety-critical reinforcement learning agents offline using d3rlpy and conservative Q-learning with fixed historical data


In this tutorial, we build a safety-critical reinforcement learning pipeline that learns from fixed, offline data instead of live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and then train both a behavior-cloning baseline and a conservative Q-learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where unsupervised exploration is not an option.

!pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"


import os
import time
import random
import inspect
import numpy as np
import matplotlib.pyplot as plt


import gymnasium as gym
from gymnasium import spaces


import torch
import d3rlpy




SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)




def pick_device():
   if torch.cuda.is_available():
       return "cuda:0"
   return "cpu"




DEVICE = pick_device()
print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)




def make_config(cls, **kwargs):
   sig = inspect.signature(cls.__init__)
   allowed = set(sig.parameters.keys())
   allowed.discard("self")
   filtered = {k: v for k, v in kwargs.items() if k in allowed}
   return cls(**filtered)

We set up the environment by installing dependencies, importing libraries, and fixing random seeds for reproducibility. We detect and configure the compute device to ensure consistent behavior across systems. We also define a utility that safely creates configuration objects across different d3rlpy versions.
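The signature-filtering trick in `make_config` can be exercised without d3rlpy at all. The sketch below uses a toy `DemoConfig` class of our own (not a d3rlpy class) to show how unsupported keyword arguments are silently dropped, which is what makes the helper robust across library versions:

```python
import inspect

def make_config(cls, **kwargs):
    # Keep only the kwargs that the target __init__ actually accepts.
    sig = inspect.signature(cls.__init__)
    allowed = set(sig.parameters.keys())
    allowed.discard("self")
    filtered = {k: v for k, v in kwargs.items() if k in allowed}
    return cls(**filtered)

class DemoConfig:
    # Toy stand-in for a d3rlpy config class with a narrow signature.
    def __init__(self, learning_rate=1e-3, batch_size=32):
        self.learning_rate = learning_rate
        self.batch_size = batch_size

# "conservative_weight" is not in DemoConfig's signature, so it is ignored
# instead of raising a TypeError.
cfg = make_config(DemoConfig, learning_rate=3e-4, batch_size=256,
                  conservative_weight=6.0)
print(cfg.learning_rate, cfg.batch_size)
```

This is why the d3rlpy config calls later in the tutorial can pass parameters that only exist in some library versions.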

class SafetyCriticalGridWorld(gym.Env):
    metadata = {"render_modes": []}


   def __init__(
       self,
       size=15,
       max_steps=80,
       hazard_coords=None,
       start=(0, 0),
       goal=None,
       slip_prob=0.05,
       seed=0,
   ):
       super().__init__()
       self.size = int(size)
       self.max_steps = int(max_steps)
       self.start = tuple(start)
       self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
       self.slip_prob = float(slip_prob)


       if hazard_coords is None:
           hz = set()
           rng = np.random.default_rng(seed)
           for _ in range(max(1, self.size // 2)):
               x = rng.integers(2, self.size - 2)
               y = rng.integers(2, self.size - 2)
               hz.add((int(x), int(y)))
           self.hazards = hz
       else:
           self.hazards = set(tuple(x) for x in hazard_coords)


       self.action_space = spaces.Discrete(4)
       self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)


       self._rng = np.random.default_rng(seed)
       self._pos = None
       self._t = 0


    def reset(self, *, seed=None, options=None):
        if seed is not None:
            self._rng = np.random.default_rng(seed)
        self._pos = [int(self.start[0]), int(self.start[1])]
        self._t = 0
        obs = np.array(self._pos, dtype=np.float32)
        return obs, {}


    def _clip(self):
        self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
        self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))


    def step(self, action):
        self._t += 1


        a = int(action)
        if self._rng.random() < self.slip_prob:
            a = int(self._rng.integers(0, 4))


        if a == 0:
            self._pos[1] += 1
        elif a == 1:
            self._pos[0] += 1
        elif a == 2:
            self._pos[1] -= 1
        elif a == 3:
            self._pos[0] -= 1


        self._clip()


        x, y = int(self._pos[0]), int(self._pos[1])
        terminated = False
        truncated = self._t >= self.max_steps
       terminated = False
       truncated = self._t >= self.max_steps


       reward = -1.0


       if (x, y) in self.hazards:
           reward = -100.0
           terminated = True


       if (x, y) == self.goal:
           reward = +50.0
           terminated = True


       obs = np.array((x, y), dtype=np.float32)
       return obs, float(reward), terminated, truncated, {}

We define a safety-critical gridworld environment with hazards, terminal conditions, and stochastic (slippery) transitions. We encode penalties for unsafe states and rewards for successful task completion. We ensure that the environment constrains movement to reflect real-world safety constraints.
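Before generating data, it helps to sanity-check the reward rule in isolation. The sketch below reimplements just the per-step reward logic as a pure function with the same constants used in `step` above (step cost -1, hazard penalty -100, goal bonus +50); `grid_reward` is our own illustrative helper, not part of the environment class:

```python
def grid_reward(pos, hazards, goal):
    # Mirrors the reward rule in SafetyCriticalGridWorld.step:
    # -1 per step, -100 on a hazard, +50 at the goal
    # (in the env the goal check runs last, so the goal wins if both apply).
    if pos == goal:
        return 50.0, True
    if pos in hazards:
        return -100.0, True
    return -1.0, False

hazards = {(3, 3), (5, 7)}
goal = (14, 14)
print(grid_reward((3, 3), hazards, goal))    # hazard: episode terminates
print(grid_reward((14, 14), hazards, goal))  # goal: episode terminates
print(grid_reward((0, 1), hazards, goal))    # ordinary step cost
```

Checks like this catch sign errors before they silently corrupt an entire offline dataset.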

def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
    x, y = int(obs[0]), int(obs[1])
   gx, gy = env.goal


    preferred = []
   if gx > x:
       preferred.append(1)
   elif gx < x:
       preferred.append(3)
   if gy > y:
       preferred.append(0)
   elif gy < y:
       preferred.append(2)


    if len(preferred) == 0:
        preferred = [int(env._rng.integers(0, 4))]


   if env._rng.random() < epsilon:
       return int(env._rng.integers(0, 4))


    candidates = []
   for a in preferred:
       nx, ny = x, y
       if a == 0:
           ny += 1
       elif a == 1:
           nx += 1
       elif a == 2:
           ny -= 1
       elif a == 3:
           nx -= 1
       nx = int(np.clip(nx, 0, env.size - 1))
       ny = int(np.clip(ny, 0, env.size - 1))
       if (nx, ny) not in env.hazards:
           candidates.append(a)


    if len(candidates) == 0:
        return preferred[0]
   return int(random.choice(candidates))




def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
    episodes = []
    for i in range(n_episodes):
        obs, _ = env.reset(seed=int(seed + i))
        obs_list = []
        act_list = []
        rew_list = []
        done_list = []


        done = False
        while not done:
            a = safe_behavior_policy(obs, env, epsilon=epsilon)
            nxt, r, terminated, truncated, _ = env.step(a)
            done = bool(terminated or truncated)


            obs_list.append(np.array(obs, dtype=np.float32))
            act_list.append(np.array([a], dtype=np.int64))
            rew_list.append(np.array([r], dtype=np.float32))
            done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))


           obs = nxt


       episodes.append(
           {
               "observations": np.stack(obs_list, axis=0),
               "actions": np.stack(act_list, axis=0),
               "rewards": np.stack(rew_list, axis=0),
               "terminals": np.stack(done_list, axis=0),
           }
       )
   return episodes




def build_mdpdataset(episodes):
    obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
    acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
    rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
    terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)


   if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
       return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)


   raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

We design a constrained behavior policy that generates offline data without risky exploration. We use this policy to collect trajectories and structure them into episodes. We then convert these episodes into a format compatible with d3rlpy's offline learning API.
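The episode-to-flat-array conversion can be verified on toy data without d3rlpy installed. The sketch below uses two hand-made episodes in the same dict layout as `generate_offline_episodes` and mirrors the concatenation in `build_mdpdataset`, checking that each episode contributes exactly one terminal flag:

```python
import numpy as np

# Two tiny hand-made episodes in the layout produced above.
episodes = [
    {"observations": np.zeros((3, 2), np.float32),
     "actions": np.array([[0], [1], [2]], np.int64),
     "rewards": np.array([[-1.0], [-1.0], [50.0]], np.float32),
     "terminals": np.array([[0.0], [0.0], [1.0]], np.float32)},
    {"observations": np.ones((2, 2), np.float32),
     "actions": np.array([[3], [0]], np.int64),
     "rewards": np.array([[-1.0], [-100.0]], np.float32),
     "terminals": np.array([[0.0], [1.0]], np.float32)},
]

# Same concatenation build_mdpdataset performs before handing off to d3rlpy.
obs = np.concatenate([ep["observations"] for ep in episodes], axis=0)
terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0)
print(obs.shape)         # 3 + 2 = 5 transitions, 2-dim observations
print(int(terms.sum()))  # one terminal flag per episode
```

If the terminal count ever diverges from the episode count, d3rlpy would fuse episodes together and corrupt the Bellman targets.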

def _get_episodes_from_dataset(dataset):
   if hasattr(dataset, "episodes") and dataset.episodes is not None:
       return dataset.episodes
   if hasattr(dataset, "get_episodes"):
       return dataset.get_episodes()
   raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")




def _iter_all_observations(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       if obs is None:
           continue
       for o in obs:
           yield o




def _iter_all_transitions(dataset):
   for ep in _get_episodes_from_dataset(dataset):
       obs = getattr(ep, "observations", None)
       acts = getattr(ep, "actions", None)
       rews = getattr(ep, "rewards", None)
       if obs is None or acts is None:
           continue
        n = min(len(obs), len(acts))
        for i in range(n):
            o = obs[i]
            a = acts[i]
            r = rews[i] if rews is not None and i < len(rews) else None
            yield o, a, r




def visualize_dataset(dataset, env, title="Offline Dataset"):
    state_visits = np.zeros((env.size, env.size), dtype=np.float32)
    for obs in _iter_all_observations(dataset):
        x, y = int(obs[0]), int(obs[1])
        x = int(np.clip(x, 0, env.size - 1))
        y = int(np.clip(y, 0, env.size - 1))
        state_visits[y, x] += 1


    plt.figure(figsize=(6, 5))
    plt.imshow(state_visits, origin="lower")
    plt.colorbar(label="Visits")
    plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
    plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
    if len(env.hazards) > 0:
        hz = np.array(list(env.hazards), dtype=np.int32)
        plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
   plt.title(f"{title} — State visitation")
   plt.xlabel("x")
   plt.ylabel("y")
   plt.legend()
   plt.show()


    rewards = []
   for _, _, r in _iter_all_transitions(dataset):
       if r is not None:
           rewards.append(float(r))
   if len(rewards) > 0:
       plt.figure(figsize=(6, 4))
       plt.hist(rewards, bins=60)
       plt.title(f"{title} — Reward distribution")
       plt.xlabel("reward")
       plt.ylabel("count")
       plt.show()

We implement dataset utilities that correctly iterate through episodes instead of assuming flat arrays. We visualize state visitation to understand coverage and data bias in the offline dataset. We also analyze the reward distribution to observe the learning signals available to the agent.
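The visitation grid is just a 2-D histogram, so the per-observation loop above can also be written with `np.add.at`, which accumulates repeated indices correctly where plain fancy-index assignment would not. A minimal sketch on hand-made observations:

```python
import numpy as np

size = 15
# Toy (x, y) observations; (1, 1) is visited twice.
obs = np.array([[0, 0], [1, 0], [1, 1], [1, 1], [14, 14]], np.int64)

visits = np.zeros((size, size), np.float32)
# Accumulate counts at (row=y, col=x), matching state_visits[y, x] above.
np.add.at(visits, (obs[:, 1], obs[:, 0]), 1.0)

print(int(visits[1, 1]))  # 2: repeated indices are accumulated
print(int(visits.sum()))  # 5: one count per observation
```

Note that `visits[obs[:, 1], obs[:, 0]] += 1.0` would count the duplicate (1, 1) visit only once, which is exactly the pitfall `np.add.at` avoids.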

def rollout_eval(env, algo, n_episodes=25, seed=0):
    returns = []
    lengths = []
   hazard_hits = 0
   goal_hits = 0


   for i in range(n_episodes):
       obs, _ = env.reset(seed=seed + i)
       done = False
       total = 0.0
       steps = 0
       while not done:
            a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
           obs, r, terminated, truncated, _ = env.step(a)
           total += float(r)
           steps += 1
           done = bool(terminated or truncated)
           if terminated:
                x, y = int(obs[0]), int(obs[1])
               if (x, y) in env.hazards:
                   hazard_hits += 1
               if (x, y) == env.goal:
                   goal_hits += 1


       returns.append(total)
       lengths.append(steps)


   return {
       "return_mean": float(np.mean(returns)),
       "return_std": float(np.std(returns)),
       "len_mean": float(np.mean(lengths)),
       "hazard_rate": float(hazard_hits / max(1, n_episodes)),
       "goal_rate": float(goal_hits / max(1, n_episodes)),
       "returns": np.asarray(returns, dtype=np.float32),
   }




def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
   rng = np.random.default_rng(seed)
    obs_all = []
    act_all = []
    for o, a, _ in _iter_all_transitions(dataset):
        obs_all.append(np.asarray(o, dtype=np.float32))
        act_all.append(int(np.asarray(a).reshape(-1)[0]))
       if len(obs_all) >= 80_000:
           break


   obs_all = np.stack(obs_all, axis=0)
   act_all = np.asarray(act_all, dtype=np.int64)


   idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
    obs_probe = obs_all[idx]
    act_probe_data = act_all[idx]
   act_probe_pi = algo.predict(obs_probe).astype(np.int64)


   mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
   return float(mismatch.mean())




def create_discrete_bc(device):
   if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
       cls = d3rlpy.algos.DiscreteBCConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           batch_size=256,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteBC"):
       return d3rlpy.algos.DiscreteBC()
   raise RuntimeError("DiscreteBC not available in this d3rlpy version.")




def create_discrete_cql(device, conservative_weight=6.0):
   if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
       cls = d3rlpy.algos.DiscreteCQLConfig
       cfg = make_config(
           cls,
           learning_rate=3e-4,
           actor_learning_rate=3e-4,
           critic_learning_rate=3e-4,
           temp_learning_rate=3e-4,
           alpha_learning_rate=3e-4,
           batch_size=256,
           conservative_weight=float(conservative_weight),
           n_action_samples=10,
           rollout_interval=0,
       )
       return cfg.create(device=device)
   if hasattr(d3rlpy.algos, "DiscreteCQL"):
       algo = d3rlpy.algos.DiscreteCQL()
       if hasattr(algo, "conservative_weight"):
           try:
               algo.conservative_weight = float(conservative_weight)
           except Exception:
               pass
       return algo
   raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")

We define controlled evaluation routines to measure policy performance without uncontrolled exploration. We compute return and safety metrics, including hazard and goal rates. We also introduce a mismatch diagnostic that measures how often learned actions deviate from the dataset's behavior.

def main():
   env = SafetyCriticalGridWorld(
       size=15,
       max_steps=80,
       slip_prob=0.05,
       seed=SEED,
   )


   raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
   dataset = build_mdpdataset(raw_eps)


   print("dataset built:", type(dataset).__name__)
   visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")


   bc = create_discrete_bc(DEVICE)
   cql = create_discrete_cql(DEVICE, conservative_weight=6.0)


    print("\nTraining Discrete BC (offline)...")
   t0 = time.time()
   bc.fit(
       dataset,
       n_steps=25_000,
       n_steps_per_epoch=2_500,
       experiment_name="grid_bc_offline",
   )
   print("BC train sec:", round(time.time() - t0, 2))


    print("\nTraining Discrete CQL (offline)...")
   t0 = time.time()
   cql.fit(
       dataset,
       n_steps=80_000,
       n_steps_per_epoch=8_000,
       experiment_name="grid_cql_offline",
   )
   print("CQL train sec:", round(time.time() - t0, 2))


    print("\nControlled online evaluation (small number of rollouts):")
   bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
   cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)


   print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
   print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})


    print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
   bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
   cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
   print("BC mismatch rate :", bc_mismatch)
   print("CQL mismatch rate:", cql_mismatch)


   plt.figure(figsize=(6, 4))
    labels = ["BC", "CQL"]
    means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
    stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
   plt.bar(labels, means, yerr=stds)
   plt.ylabel("Return")
   plt.title("Online Rollout Return (Controlled)")
   plt.show()


   plt.figure(figsize=(6, 4))
    plt.plot(np.sort(bc_metrics["returns"]), label="BC")
    plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
   plt.xlabel("Episode (sorted)")
   plt.ylabel("Return")
   plt.title("Return Distribution (Sorted)")
   plt.legend()
   plt.show()


   out_dir = "/content/offline_rl_artifacts"
   os.makedirs(out_dir, exist_ok=True)
   bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
   cql_path = os.path.join(out_dir, "grid_cql_policy.pt")


   if hasattr(bc, "save_policy"):
       bc.save_policy(bc_path)
       print("Saved BC policy:", bc_path)
   if hasattr(cql, "save_policy"):
       cql.save_policy(cql_path)
       print("Saved CQL policy:", cql_path)


    print("\nDone.")




if __name__ == "__main__":
   main()

We train both behavior-cloning and conservative Q-learning agents entirely on offline data. We compare their performance using controlled rollouts and diagnostic metrics. We finalize the workflow by saving the trained policies and summarizing the safety-aware learning results.

In conclusion, we demonstrated that conservative Q-learning yields a more reliable policy than simple imitation when learning from historical data in safety-sensitive environments. By comparing offline training outcomes, controlled online evaluations, and action-distribution mismatches, we illustrated how conservatism helps reduce risky, out-of-distribution behavior. Overall, we presented a complete, reproducible offline RL workflow that can be extended to more complex domains such as robotics, healthcare, or finance without compromising safety.



