QSeaBattle: Lin Trainable Assisted Imitation Learning (Bootstrap Tutorial)¶
This notebook demonstrates how we bootstrap trainable assisted players by imitation learning from a known-correct classical assisted strategy.
The goal is not to learn a better strategy yet, but to:
- Verify that the Lin neural architecture can exactly reproduce the classical assisted algorithm.
- Produce a stable initialization for later DIAL / DRU / RL training.
- Validate the measurement–shared resource (SR)–combine decomposition.
If this notebook works, the architecture is correct. If it fails, the bug is architectural, not “learning-related”.
Invariant: At the end of this notebook, the neural assisted players must match the classical assisted players bit-for-bit in sample mode.
from __future__ import annotations
import os
import sys
from pathlib import Path
def change_to_repo_root(marker: str = "src") -> None:
    """Change CWD to the repository root (parent of `src`)."""
    here = Path.cwd()
    for parent in [here] + list(here.parents):
        if (parent / marker).is_dir():
            os.chdir(parent)
            break
change_to_repo_root("src")
sys.path.append("./src")
print("CWD:", Path.cwd())
CWD: c:\Users\nly99857\OneDrive - Philips\SW Projects\QSeaBattle
Imports¶
We reuse the linear trainable assisted stack and utilities for imitation training.
import numpy as np
import tensorflow as tf
from Q_Sea_Battle.game_layout import GameLayout
from Q_Sea_Battle.game_env import GameEnv
from Q_Sea_Battle.tournament import Tournament
from Q_Sea_Battle.lin_trainable_models import LinTrainableAssistedModelA
from Q_Sea_Battle.lin_trainable_models import LinTrainableAssistedModelB
from Q_Sea_Battle.trainable_assisted_players import TrainableAssistedPlayers
from Q_Sea_Battle.lin_trainable_assisted_imitation_utilities import (
    generate_measurement_dataset_a,
    generate_measurement_dataset_b,
    generate_combine_dataset_a,
    generate_combine_dataset_b,
    to_tf_dataset,
    transfer_assisted_model_a_layer_weights,
    transfer_assisted_model_b_layer_weights,
    train_layer,
)
print("TensorFlow:", tf.__version__)
TensorFlow: 2.20.0
Game layout and correlation setting¶
We fix a small field (4×4) and a single communication bit (m=1). This regime is chosen deliberately:
- n² = 16 is the smallest nontrivial power-of-two field
- m = 1 matches the theoretical assisted protocol
- p_high controls the strength of the shared resource (SR) correlation
Why this matters: The Lin architecture relies on a single shared resource (SR) call per decision. Using the smallest valid field makes debugging and verification feasible, while still exercising the full measurement → correlation → parity pipeline.
Note: Changing n² or m here does not merely scale the problem; it changes the required structure of the combine layers.
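As a toy illustration of what a correlation parameter like p_high can mean, the sketch below samples pairs of outcome bits that agree with probability p_high. This is an assumption for illustration only, not the project's SR implementation; the function name `sample_correlated_bits` is hypothetical.

```python
import numpy as np

def sample_correlated_bits(p_high: float, num_samples: int, seed: int = 0) -> np.ndarray:
    """Toy correlated shared resource: with probability `p_high` the two
    players' outcome bits agree, otherwise they disagree. Illustrative only."""
    rng = np.random.default_rng(seed)
    a = rng.integers(0, 2, size=num_samples)      # player A's outcome bit
    agree = rng.random(num_samples) < p_high      # whether B's bit matches A's
    b = np.where(agree, a, 1 - a)                 # flip B's bit when not agreeing
    return np.stack([a, b], axis=1)

pairs = sample_correlated_bits(p_high=1.0, num_samples=1000)
# With p_high = 1.0 the two outcome columns are identical:
print((pairs[:, 0] == pairs[:, 1]).all())  # True
```

At p_high = 1.0 (the setting used in this notebook) the correlation is deterministic, which is what makes a bit-for-bit comparison against the classical players possible at all.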
FIELD_SIZE = 4
COMMS_SIZE = 1
# shared resource (SR) correlation parameter used by the task
P_HIGH = 1.0
# Dataset / training sizes
DATASET_SIZE = 50_000
BATCH_SIZE = 256
EPOCHS_MEAS = 4  # fewer epochs suffice for measurement training, since it is an easier task
EPOCHS_COMB = 25
# DIAL/DRU training settings
SR_MODE_BOOTSTRAP_EVAL = "sample"
SR_MODE_DIAL_TRAIN = "expected"
SR_MODE_DIAL_EVAL = "sample"
SEED = 123
tf.random.set_seed(SEED)
np.random.seed(SEED)
# Folders
data_dir = Path("notebooks/data")
models_dir = Path("notebooks/models")
data_dir.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)
n2 = FIELD_SIZE * FIELD_SIZE
print("n2:", n2, "m:", COMMS_SIZE)
n2: 16 m: 1
Generating imitation targets from the classical assisted players¶
In this section we generate datasets using the deterministic classical assisted strategy.
Why imitation first?
- The classical assisted algorithm is provably correct.
- It defines exact targets for:
  - measurement choices
  - shared resource (SR) outcomes
  - communication bits
  - shoot decisions
- Learning from this data isolates architectural errors from optimisation noise.
Crucially: The neural models are not discovering the strategy here. They are learning to represent it.
Invariant: The generated dataset encodes exactly one shared-resource interaction per decision, and this ordering must be preserved throughout training.
# Build layout for data generation (enemy_probability/channel_noise not used by these generators)
layout = GameLayout(field_size=FIELD_SIZE, comms_size=COMMS_SIZE)
# --- Generate datasets ---
ds_meas_a = generate_measurement_dataset_a(layout, num_samples=DATASET_SIZE, seed=SEED)
ds_comb_a = generate_combine_dataset_a(layout, num_samples=DATASET_SIZE, seed=SEED + 1)
ds_meas_b = generate_measurement_dataset_b(layout, num_samples=DATASET_SIZE, seed=SEED + 2)
ds_comb_b = generate_combine_dataset_b(layout, num_samples=DATASET_SIZE, seed=SEED + 3)
tfds_meas_a = to_tf_dataset(ds_meas_a, x_keys=["field"], y_key="meas_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED)
tfds_comb_a = to_tf_dataset(ds_comb_a, x_keys=["outcomes_a"], y_key="comm_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+1)
tfds_meas_b = to_tf_dataset(ds_meas_b, x_keys=["gun"], y_key="meas_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+2)
tfds_comb_b = to_tf_dataset(ds_comb_b, x_keys=["outcomes_b","comm"], y_key="shoot_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+3)
print("Datasets ready.")
Datasets ready.
Training individual layers by supervised imitation¶
We train the layers in isolation rather than end-to-end.
Why layer-wise training?
- The assisted algorithm has a known internal structure.
- Measurement layers have local, per-cell targets.
- Combine layers implement global parity, which is hard to learn end-to-end.
- Separating them stabilizes training and makes failures diagnosable.
Interpretation: Each layer learns a well-defined subroutine of the classical algorithm.
Measurement target: Player A measures type-1 exactly on cells where field == 1. Player B measures type-1 exactly at the gun position.
These targets are not arbitrary: they are the classical measurement rules expressed in neural form.
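The measurement rules can be sketched directly in numpy. The arrays below are hypothetical examples, not output of the project's dataset generators: Player A's target is the 0/1 indicator of the field itself, and Player B's target is a one-hot vector at the gun index.

```python
import numpy as np

# Illustrative measurement targets for a 4x4 field (n2 = 16); example values only.
n2 = 16
field = np.zeros(n2, dtype=np.float32)
field[[2, 7, 11]] = 1.0                        # ships at cells 2, 7, 11

# Player A: measure type-1 exactly on cells where field == 1.
meas_target_a = (field == 1.0).astype(np.float32)

# Player B: measure type-1 exactly at the gun position (one-hot).
gun_index = 7
meas_target_b = np.zeros(n2, dtype=np.float32)
meas_target_b[gun_index] = 1.0

print(meas_target_a.sum(), meas_target_b.sum())  # 3.0 1.0
```

Because both targets are purely local functions of the input, a linear measurement layer can fit them quickly, which is why EPOCHS_MEAS is small.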
Combine target: The communication bit (and shoot decision) is the parity (XOR) of the shared-resource outcomes.
This is the nontrivial part of the protocol. If this layer fails to learn parity, the architecture will not scale.
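A minimal sketch of the combine target, assuming the SR outcomes arrive as a 0/1 vector: the target bit is the XOR over all outcome bits.

```python
import numpy as np

def parity(outcomes: np.ndarray) -> int:
    """XOR (parity) of a 0/1 outcome vector: the classical combine target."""
    return int(np.bitwise_xor.reduce(outcomes.astype(np.int64)))

print(parity(np.array([1, 0, 1, 1])))  # odd number of ones  -> 1
print(parity(np.array([1, 1, 0, 0])))  # even number of ones -> 0
```

Parity is the classic hard case for shallow networks trained end-to-end (every single input flip flips the label), which is why the combine layers get their own supervised stage and more epochs.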
# --- Train layers ---
from Q_Sea_Battle.lin_teacher_layers import LinMeasurementLayerA
from Q_Sea_Battle.lin_teacher_layers import LinMeasurementLayerB
from Q_Sea_Battle.lin_teacher_layers import LinCombineLayerA
from Q_Sea_Battle.lin_teacher_layers import LinCombineLayerB
# --- Build layers ---
n2 = FIELD_SIZE * FIELD_SIZE
model_a = LinTrainableAssistedModelA(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    sr_mode="sample",  # evaluation mode
    seed=SEED,
    p_high=P_HIGH,
)
meas_layer_a = model_a.measure_layer
comb_layer_a = model_a.combine_layer
model_b = LinTrainableAssistedModelB(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    sr_mode="sample",  # evaluation mode
    seed=SEED,
    p_high=P_HIGH,
)
meas_layer_b = model_b.measure_layer
comb_layer_b = model_b.combine_layer
_ = train_layer(meas_layer_a, tfds_meas_a, loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                epochs=EPOCHS_MEAS, metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
_ = train_layer(comb_layer_a, tfds_comb_a, loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                epochs=EPOCHS_COMB, metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])
_ = train_layer(meas_layer_b, tfds_meas_b, loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                epochs=EPOCHS_MEAS, metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
_ = train_layer(comb_layer_b, tfds_comb_b, loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                epochs=EPOCHS_COMB, metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])
print("Standalone layers trained.")
WARNING:tensorflow:From c:\Users\nly99857\OneDrive - Philips\SW Projects\QSeaBattle\venvs\env_QSeaBattle\Lib\site-packages\keras\src\backend\tensorflow\core.py:232: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead. Epoch 1/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 22ms/step - binary_accuracy: 0.8626 - loss: 0.4615 Epoch 2/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9997 - loss: 0.1260 Epoch 3/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 22ms/step - binary_accuracy: 1.0000 - loss: 0.0403 Epoch 4/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 1.0000 - loss: 0.0184 Epoch 1/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 8s 41ms/step - binary_accuracy: 0.5041 - loss: 0.6941 Epoch 2/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 7s 36ms/step - binary_accuracy: 0.5145 - loss: 0.6925 Epoch 3/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.5412 - loss: 0.6883 Epoch 4/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 9s 45ms/step - binary_accuracy: 0.5802 - loss: 0.6769 Epoch 5/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.6340 - loss: 0.6462 Epoch 6/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.6996 - loss: 0.5887 Epoch 7/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 28ms/step - binary_accuracy: 0.7629 - loss: 0.5112 Epoch 8/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8092 - loss: 0.4398 Epoch 9/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8538 - loss: 0.3696 Epoch 10/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.8926 - loss: 0.3026 Epoch 11/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.9175 - loss: 0.2517 Epoch 12/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9368 - loss: 0.2081 Epoch 13/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9502 - loss: 0.1751 Epoch 14/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9631 - loss: 0.1426 Epoch 15/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - 
binary_accuracy: 0.9693 - loss: 0.1197 Epoch 16/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9736 - loss: 0.1035 Epoch 17/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9786 - loss: 0.0873 Epoch 18/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.9801 - loss: 0.0784 Epoch 19/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9829 - loss: 0.0700 Epoch 20/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9836 - loss: 0.0643 Epoch 21/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9858 - loss: 0.0590 Epoch 22/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 28ms/step - binary_accuracy: 0.9865 - loss: 0.0543 Epoch 23/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9886 - loss: 0.0493 Epoch 24/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 31ms/step - binary_accuracy: 0.9882 - loss: 0.0478 Epoch 25/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9892 - loss: 0.0434 Epoch 1/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.8968 - loss: 0.3909 Epoch 2/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.9442 - loss: 0.1299 Epoch 3/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.9990 - loss: 0.0419 Epoch 4/4 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 1.0000 - loss: 0.0148 Epoch 1/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.5031 - loss: 0.6942 Epoch 2/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.5143 - loss: 0.6929 Epoch 3/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.5190 - loss: 0.6922 Epoch 4/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.5382 - loss: 0.6894 Epoch 5/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.5862 - loss: 0.6776 Epoch 6/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.6343 - loss: 0.6472 Epoch 7/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 
0.6875 - loss: 0.5989 Epoch 8/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.7638 - loss: 0.5178 Epoch 9/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.8164 - loss: 0.4349 Epoch 10/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8534 - loss: 0.3659 Epoch 11/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 28ms/step - binary_accuracy: 0.8822 - loss: 0.3072 Epoch 12/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.9139 - loss: 0.2490 Epoch 13/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9412 - loss: 0.1914 Epoch 14/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9541 - loss: 0.1546 Epoch 15/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9629 - loss: 0.1296 Epoch 16/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9689 - loss: 0.1117 Epoch 17/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9731 - loss: 0.0979 Epoch 18/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9747 - loss: 0.0888 Epoch 19/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9780 - loss: 0.0788 Epoch 20/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9805 - loss: 0.0711 Epoch 21/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9808 - loss: 0.0668 Epoch 22/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9825 - loss: 0.0615 Epoch 23/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9837 - loss: 0.0575 Epoch 24/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9844 - loss: 0.0536 Epoch 25/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9862 - loss: 0.0501 Standalone layers trained.
Assembling Model A and Model B¶
After training the components, we assemble the full LinTrainableAssistedModelA and LinTrainableAssistedModelB.
Why assembly after training?
- It enforces a strict separation between structure and optimisation.
- It guarantees that A and B use the same shared resource (SR).
- It mirrors the exact dataflow of the classical assisted players.
At this point, the models are structurally complete.
Warning: If Model A and Model B are not perfectly aligned on shared layers, any apparent learning success is meaningless.
# --- Install into full models ---
model_a = LinTrainableAssistedModelA(field_size=FIELD_SIZE, comms_size=COMMS_SIZE, sr_mode=SR_MODE_BOOTSTRAP_EVAL, seed=SEED, p_high=P_HIGH)
model_b = LinTrainableAssistedModelB(field_size=FIELD_SIZE, comms_size=COMMS_SIZE, sr_mode=SR_MODE_BOOTSTRAP_EVAL, seed=SEED, p_high=P_HIGH)
# Build models (required before weight transfer)
_ = model_a(tf.zeros((1, n2), tf.float32))
_dummy_gun = tf.zeros((1, n2), tf.float32)
_dummy_comm = tf.zeros((1, COMMS_SIZE), tf.float32)
_dummy_prev_meas_list = [tf.zeros((1, n2), tf.float32)]
_dummy_prev_out_list = [tf.zeros((1, n2), tf.float32)]
_ = model_b([_dummy_gun, _dummy_comm, _dummy_prev_meas_list, _dummy_prev_out_list])
transfer_assisted_model_a_layer_weights(meas_layer_a, comb_layer_a, model_a)
transfer_assisted_model_b_layer_weights(meas_layer_b, comb_layer_b, model_b)
print("Installed weights into model_a/model_b")
print("model_a sr_mode:", model_a.sr_layer.mode)
print("model_b sr_mode:", model_b.sr_layer.mode)
Installed weights into model_a/model_b model_a sr_mode: sample model_b sr_mode: sample
Verification: neural vs classical assisted players¶
This is not a performance benchmark. It is a correctness check.
We compare:
- Classical AssistedPlayers
- Neural TrainableAssistedPlayers (sample mode)
Success criterion: Identical win rates and identical decision statistics within sampling noise.
Failure here means:
- a broken measurement ordering,
- incorrect handling of the shared resource (SR), or
- a mismatch in combine logic.
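The bit-for-bit criterion itself is easy to express with plain numpy. The decision arrays below are hypothetical placeholders for decision logs from the two player stacks; the project's tournament API is not assumed here.

```python
import numpy as np

def agreement_rate(classical: np.ndarray, neural: np.ndarray) -> float:
    """Fraction of decisions on which the two player stacks agree.
    A successful bootstrap requires exactly 1.0 (bit-for-bit)."""
    assert classical.shape == neural.shape
    return float(np.mean(classical == neural))

# Hypothetical 0/1 shoot-decision logs, one entry per game:
classical_decisions = np.array([1, 0, 1, 1, 0])
neural_decisions = np.array([1, 0, 1, 1, 0])
print(agreement_rate(classical_decisions, neural_decisions))  # 1.0
```

Note that 0.99 agreement is a failure under this criterion: the classical algorithm is deterministic given the SR samples, so any disagreement points at one of the structural bugs listed above, not at sampling noise.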
# Force SR sample explicitly for evaluation
model_a.sr_layer.mode = "sample"
model_b.sr_layer.mode = "sample"
players = TrainableAssistedPlayers(layout, model_a=model_a, model_b=model_b)
layout_eval = GameLayout(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    enemy_probability=0.5,
    channel_noise=0.0,
    number_of_games_in_tournament=2_000,
)
env = GameEnv(layout_eval)
t = Tournament(env, players, layout_eval)
log = t.tournament()
mean_reward, std_err = log.outcome()
print(f"Bootstrap tournament over {layout_eval.number_of_games_in_tournament}: {mean_reward:.4f} ± {std_err:.4f}")
Bootstrap tournament over 2000: 1.0000 ± 0.0000
Save weights¶
We save weights (not full .keras serialization) to avoid custom-object config issues.
model_a_path = models_dir / f"lin_model_a_bootstrap_f{FIELD_SIZE}_m{COMMS_SIZE}_p{P_HIGH:.2f}.weights.h5"
model_b_path = models_dir / f"lin_model_b_bootstrap_f{FIELD_SIZE}_m{COMMS_SIZE}_p{P_HIGH:.2f}.weights.h5"
model_a.save_weights(model_a_path)
model_b.save_weights(model_b_path)
print("Saved weights:")
print(" -", model_a_path)
print(" -", model_b_path)
Saved weights: - notebooks\models\lin_model_a_bootstrap_f4_m1_p1.00.weights.h5 - notebooks\models\lin_model_b_bootstrap_f4_m1_p1.00.weights.h5