QSeaBattle: Lin Trainable Assisted Imitation Learning (Bootstrap Tutorial)¶

This notebook demonstrates how we bootstrap trainable assisted players by imitation learning from a known-correct classical assisted strategy.

The goal is not to learn a better strategy yet, but to:

  1. Verify that the Lin neural architecture can exactly reproduce the classical assisted algorithm.
  2. Produce a stable initialization for later DIAL / DRU / RL training.
  3. Validate the measurement–shared resource (SR)–combine decomposition.

If this notebook works, the architecture is correct. If it fails, the bug is architectural, not “learning-related”.

Invariant: At the end of this notebook, the neural assisted players must match the classical assisted players bit-for-bit in sample mode.

In [2]:
from __future__ import annotations

import os
import sys
from pathlib import Path

def change_to_repo_root(marker: str = "src") -> None:
    """Change CWD to the repository root (parent of `src`)."""
    here = Path.cwd()
    for parent in [here] + list(here.parents):
        if (parent / marker).is_dir():
            os.chdir(parent)
            break

change_to_repo_root("src")
sys.path.append("./src")

print("CWD:", Path.cwd())
CWD: c:\Users\nly99857\OneDrive - Philips\SW Projects\QSeaBattle

Imports¶

We reuse the linear trainable assisted stack and utilities for imitation training.

In [3]:
import numpy as np
import tensorflow as tf

from Q_Sea_Battle.game_layout import GameLayout
from Q_Sea_Battle.game_env import GameEnv
from Q_Sea_Battle.tournament import Tournament

from Q_Sea_Battle.lin_trainable_models import LinTrainableAssistedModelA
from Q_Sea_Battle.lin_trainable_models import LinTrainableAssistedModelB
from Q_Sea_Battle.trainable_assisted_players import TrainableAssistedPlayers

from Q_Sea_Battle.lin_trainable_assisted_imitation_utilities import (
    generate_measurement_dataset_a,
    generate_measurement_dataset_b,
    generate_combine_dataset_a,
    generate_combine_dataset_b,
    to_tf_dataset,
    transfer_assisted_model_a_layer_weights,
    transfer_assisted_model_b_layer_weights,
    train_layer
)

print("TensorFlow:", tf.__version__)
TensorFlow: 2.20.0

Game layout and correlation setting¶

We fix a small field (4×4) and a single communication bit (m=1). This regime is chosen deliberately:

• n² = 16 is the smallest nontrivial power-of-two field
• m = 1 matches the theoretical assisted protocol
• p_high controls the strength of the shared resource (SR) correlation

Why this matters: The Lin architecture relies on a single shared resource (SR) call per decision. Using the smallest valid field makes debugging and verification feasible, while still exercising the full measurement → correlation → parity pipeline.

Note: Changing n² or m here does not merely scale the problem; it changes the required structure of the combine layers.
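Since the protocol assumes a power-of-two field, a cheap guard before building datasets catches invalid configurations early. A minimal sketch (plain Python; the bit trick `n & (n - 1) == 0` holds exactly for powers of two):

```python
FIELD_SIZE = 4
n2 = FIELD_SIZE * FIELD_SIZE

# n2 is a power of two iff exactly one bit is set.
assert n2 > 0 and n2 & (n2 - 1) == 0, "the assisted protocol expects a power-of-two n^2"
print("n2 =", n2, "-> power of two: OK")
```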

In [4]:
FIELD_SIZE = 4
COMMS_SIZE = 1

# shared resource (SR) correlation parameter used by your task
P_HIGH = 1.0

# Dataset / training sizes
DATASET_SIZE = 50_000
BATCH_SIZE = 256
EPOCHS_MEAS = 4  # measurement training needs fewer epochs, since it is the easier task
EPOCHS_COMB = 25

# DIAL/DRU training settings
SR_MODE_BOOTSTRAP_EVAL = "sample"
SR_MODE_DIAL_TRAIN = "expected"
SR_MODE_DIAL_EVAL = "sample"

SEED = 123
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Folders
data_dir = Path("notebooks/data")
models_dir = Path("notebooks/models")
data_dir.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

n2 = FIELD_SIZE * FIELD_SIZE
print("n2:", n2, "m:", COMMS_SIZE)
n2: 16 m: 1

Generating imitation targets from the classical assisted players¶

In this section we generate datasets using the deterministic classical assisted strategy.

Why imitation first?

• The classical assisted algorithm is provably correct.
• It defines exact targets for:
  – measurement choices
  – shared resource (SR) outcomes
  – communication bits
  – shoot decisions
• Learning from this data isolates architectural errors from optimisation noise.

Crucially: The neural models are not discovering the strategy here. They are learning to represent it.

Invariant: The generated dataset encodes exactly one shared-resource interaction per decision, and this ordering must be preserved throughout training.

In [5]:
# Build layout for data generation (enemy_probability/channel_noise not used by these generators)
layout = GameLayout(field_size=FIELD_SIZE, comms_size=COMMS_SIZE)

# --- Generate datasets ---
ds_meas_a = generate_measurement_dataset_a(layout, num_samples=DATASET_SIZE, seed=SEED)
ds_comb_a = generate_combine_dataset_a(layout, num_samples=DATASET_SIZE, seed=SEED + 1)
ds_meas_b = generate_measurement_dataset_b(layout, num_samples=DATASET_SIZE, seed=SEED + 2)
ds_comb_b = generate_combine_dataset_b(layout, num_samples=DATASET_SIZE, seed=SEED + 3)

tfds_meas_a = to_tf_dataset(ds_meas_a, x_keys=["field"], y_key="meas_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED)
tfds_comb_a = to_tf_dataset(ds_comb_a, x_keys=["outcomes_a"], y_key="comm_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+1)
tfds_meas_b = to_tf_dataset(ds_meas_b, x_keys=["gun"], y_key="meas_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+2)
tfds_comb_b = to_tf_dataset(ds_comb_b, x_keys=["outcomes_b","comm"], y_key="shoot_target", batch_size=BATCH_SIZE, shuffle=True, seed=SEED+3)

print("Datasets ready.")
Datasets ready.

Training individual layers by supervised imitation¶

We train the layers in isolation rather than end-to-end.

Why layer-wise training?

• The assisted algorithm has a known internal structure.
• Measurement layers have local, per-cell targets.
• Combine layers implement global parity, which is hard to learn end-to-end.
• Separating them stabilizes training and makes failures diagnosable.

Interpretation: Each layer learns a well-defined subroutine of the classical algorithm.

Measurement target: Player A measures type-1 exactly on cells where field == 1. Player B measures type-1 exactly at the gun position.

These targets are not arbitrary: they are the classical measurement rules expressed in neural form.

Combine target: The communication bit (and shoot decision) is the parity (XOR) of the shared-resource outcomes.

This is the nontrivial part of the protocol. If this layer fails to learn parity, the architecture will not scale.
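The parity target itself can be stated in a few lines of NumPy. This is an illustrative sketch, not the project's generator code; the array names are hypothetical:

```python
import numpy as np

# Hypothetical batch of shared-resource outcomes: one row per decision,
# one 0/1 entry per SR outcome (n^2 = 16 outcomes, as in this notebook).
rng = np.random.default_rng(0)
outcomes = rng.integers(0, 2, size=(8, 16))

# The communication bit is the XOR (parity) of all outcomes in a row.
comm_target = np.bitwise_xor.reduce(outcomes, axis=1)

# Equivalently: the sum modulo 2.
assert np.array_equal(comm_target, outcomes.sum(axis=1) % 2)
```

Parity is the classic hard case for shallow networks: no single linear threshold separates even from odd sums, which is why this layer gets the longer training schedule (EPOCHS_COMB) above.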

In [6]:
# --- Train layers ---
from Q_Sea_Battle.lin_teacher_layers import (
    LinMeasurementLayerA,
    LinMeasurementLayerB,
    LinCombineLayerA,
    LinCombineLayerB,
)

# --- Build layers ---
n2 = FIELD_SIZE * FIELD_SIZE
model_a = LinTrainableAssistedModelA(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    sr_mode="sample",   # evaluation mode
    seed=SEED,
    p_high=P_HIGH,
)
meas_layer_a = model_a.measure_layer
comb_layer_a = model_a.combine_layer

model_b = LinTrainableAssistedModelB(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    sr_mode="sample",   # evaluation mode
    seed=SEED,
    p_high=P_HIGH,
)
meas_layer_b = model_b.measure_layer
comb_layer_b = model_b.combine_layer

_ = train_layer(meas_layer_a, tfds_meas_a, loss=tf.keras.losses.BinaryCrossentropy(from_logits=False), epochs=EPOCHS_MEAS,
                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
_ = train_layer(comb_layer_a, tfds_comb_a, loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), epochs=EPOCHS_COMB,
                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])

_ = train_layer(meas_layer_b, tfds_meas_b, loss=tf.keras.losses.BinaryCrossentropy(from_logits=False), epochs=EPOCHS_MEAS,
                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)])
_ = train_layer(comb_layer_b, tfds_comb_b, loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), epochs=EPOCHS_COMB,
                metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)])

print("Standalone layers trained.")
WARNING:tensorflow:From c:\Users\nly99857\OneDrive - Philips\SW Projects\QSeaBattle\venvs\env_QSeaBattle\Lib\site-packages\keras\src\backend\tensorflow\core.py:232: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

Epoch 1/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 22ms/step - binary_accuracy: 0.8626 - loss: 0.4615
Epoch 2/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9997 - loss: 0.1260
Epoch 3/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 22ms/step - binary_accuracy: 1.0000 - loss: 0.0403
Epoch 4/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 1.0000 - loss: 0.0184
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 8s 41ms/step - binary_accuracy: 0.5041 - loss: 0.6941
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 7s 36ms/step - binary_accuracy: 0.5145 - loss: 0.6925
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.5412 - loss: 0.6883
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 9s 45ms/step - binary_accuracy: 0.5802 - loss: 0.6769
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.6340 - loss: 0.6462
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.6996 - loss: 0.5887
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 28ms/step - binary_accuracy: 0.7629 - loss: 0.5112
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8092 - loss: 0.4398
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8538 - loss: 0.3696
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.8926 - loss: 0.3026
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 27ms/step - binary_accuracy: 0.9175 - loss: 0.2517
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9368 - loss: 0.2081
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9502 - loss: 0.1751
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9631 - loss: 0.1426
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9693 - loss: 0.1197
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9736 - loss: 0.1035
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9786 - loss: 0.0873
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.9801 - loss: 0.0784
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9829 - loss: 0.0700
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9836 - loss: 0.0643
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 23ms/step - binary_accuracy: 0.9858 - loss: 0.0590
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 28ms/step - binary_accuracy: 0.9865 - loss: 0.0543
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9886 - loss: 0.0493
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 31ms/step - binary_accuracy: 0.9882 - loss: 0.0478
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9892 - loss: 0.0434
Epoch 1/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.8968 - loss: 0.3909
Epoch 2/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.9442 - loss: 0.1299
Epoch 3/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 0.9990 - loss: 0.0419
Epoch 4/4
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - binary_accuracy: 1.0000 - loss: 0.0148
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.5031 - loss: 0.6942
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.5143 - loss: 0.6929
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.5190 - loss: 0.6922
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.5382 - loss: 0.6894
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.5862 - loss: 0.6776
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.6343 - loss: 0.6472
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.6875 - loss: 0.5989
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.7638 - loss: 0.5178
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.8164 - loss: 0.4349
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.8534 - loss: 0.3659
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 28ms/step - binary_accuracy: 0.8822 - loss: 0.3072
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 25ms/step - binary_accuracy: 0.9139 - loss: 0.2490
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9412 - loss: 0.1914
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - binary_accuracy: 0.9541 - loss: 0.1546
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9629 - loss: 0.1296
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9689 - loss: 0.1117
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9731 - loss: 0.0979
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9747 - loss: 0.0888
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9780 - loss: 0.0788
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9805 - loss: 0.0711
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9808 - loss: 0.0668
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9825 - loss: 0.0615
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9837 - loss: 0.0575
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9844 - loss: 0.0536
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 5s 24ms/step - binary_accuracy: 0.9862 - loss: 0.0501
Standalone layers trained.

Assembling Model A and Model B¶

After training the components, we assemble the full LinTrainableAssistedModelA and LinTrainableAssistedModelB.

Why assembly after training?

• It enforces a strict separation between structure and optimisation.
• It guarantees that A and B use the same shared resource (SR).
• It mirrors the exact dataflow of the classical assisted players.

At this point, the models are structurally complete.

Warning: If Model A and Model B are not perfectly aligned on shared layers, any apparent learning success is meaningless.

In [ ]:
# --- Install into full models ---
model_a = LinTrainableAssistedModelA(field_size=FIELD_SIZE, comms_size=COMMS_SIZE, sr_mode=SR_MODE_BOOTSTRAP_EVAL, seed=SEED, p_high=P_HIGH)
model_b = LinTrainableAssistedModelB(field_size=FIELD_SIZE, comms_size=COMMS_SIZE, sr_mode=SR_MODE_BOOTSTRAP_EVAL, seed=SEED, p_high=P_HIGH)

# Build models (required before weight transfer)
_ = model_a(tf.zeros((1, n2), tf.float32))
_dummy_gun = tf.zeros((1, n2), tf.float32)
_dummy_comm = tf.zeros((1, COMMS_SIZE), tf.float32)
_dummy_prev_meas_list = [tf.zeros((1, n2), tf.float32)]
_dummy_prev_out_list  = [tf.zeros((1, n2), tf.float32)]
_ = model_b([_dummy_gun, _dummy_comm, _dummy_prev_meas_list, _dummy_prev_out_list])

transfer_assisted_model_a_layer_weights(meas_layer_a, comb_layer_a, model_a)
transfer_assisted_model_b_layer_weights(meas_layer_b, comb_layer_b, model_b)

print("Installed weights into model_a/model_b")
print("model_a sr_mode:", model_a.sr_layer.mode)
print("model_b sr_mode:", model_b.sr_layer.mode)
Installed weights into model_a/model_b
model_a sr_mode: sample
model_b sr_mode: sample
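Given the warning above about alignment, a cheap sanity check after installation is to compare the standalone layers' weights against the copies inside the assembled models. A minimal sketch, assuming each side exposes its weights as a list of NumPy arrays (e.g. via Keras `get_weights()`):

```python
import numpy as np

def weights_identical(weights_src, weights_dst):
    """True iff two lists of weight arrays match exactly in shape and values."""
    if len(weights_src) != len(weights_dst):
        return False
    return all(
        a.shape == b.shape and np.array_equal(a, b)
        for a, b in zip(weights_src, weights_dst)
    )

# Usage with this notebook's names, after the transfer step (illustrative):
# assert weights_identical(meas_layer_a.get_weights(),
#                          model_a.measure_layer.get_weights())
```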

Verification: neural vs classical assisted players¶

This is not a performance benchmark. It is a correctness check.

We compare:

• Classical AssistedPlayers
• Neural TrainableAssistedPlayers (sample mode)

Success criterion: Identical win rates and identical decision statistics within sampling noise.

Failure here means one of:

• a broken measurement ordering
• incorrect handling of the shared resource (SR)
• a mismatch in combine logic
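"Within sampling noise" can be made precise with a standard two-proportion comparison of the empirical win rates. A minimal sketch (the z = 3 tolerance is an editorial choice, not part of the protocol):

```python
import math

def within_sampling_noise(p1, n1, p2, n2, z=3.0):
    """Are two empirical win rates compatible, given their tournament sizes?"""
    # Standard error of the difference between two independent proportions.
    se = math.sqrt(p1 * (1 - p1) / n1 + p2 * (1 - p2) / n2)
    return abs(p1 - p2) <= z * se

# With p_high = 1.0 both players win every game, so the comparison is exact:
# within_sampling_noise(1.0, 2000, 1.0, 2000) -> True
```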

In [ ]:
# Force SR sample explicitly for evaluation
model_a.sr_layer.mode = "sample"
model_b.sr_layer.mode = "sample"


players = TrainableAssistedPlayers(layout, model_a=model_a, model_b=model_b)

layout_eval = GameLayout(
    field_size=FIELD_SIZE,
    comms_size=COMMS_SIZE,
    enemy_probability=0.5,
    channel_noise=0.0,
    number_of_games_in_tournament=2_000,
)
env = GameEnv(layout_eval)
t = Tournament(env, players, layout_eval)
log = t.tournament()
mean_reward, std_err = log.outcome()
print(f"Bootstrap tournament over {layout_eval.number_of_games_in_tournament} games: {mean_reward:.4f} ± {std_err:.4f}")
Bootstrap tournament over 2000 games: 1.0000 ± 0.0000

Save weights¶

We save weights (not full .keras serialization) to avoid custom-object config issues.

In [9]:
model_a_path = models_dir / f"lin_model_a_bootstrap_f{FIELD_SIZE}_m{COMMS_SIZE}_p{P_HIGH:.2f}.weights.h5"
model_b_path = models_dir / f"lin_model_b_bootstrap_f{FIELD_SIZE}_m{COMMS_SIZE}_p{P_HIGH:.2f}.weights.h5"

model_a.save_weights(model_a_path)
model_b.save_weights(model_b_path)

print("Saved weights:")
print(" -", model_a_path)
print(" -", model_b_path)
Saved weights:
 - notebooks\models\lin_model_a_bootstrap_f4_m1_p1.00.weights.h5
 - notebooks\models\lin_model_b_bootstrap_f4_m1_p1.00.weights.h5
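When resuming from these checkpoints later (e.g. for DIAL/DRU training), remember that subclassed Keras models must be built by a forward pass before `load_weights` can restore them, just as the assembly cell above did before transferring weights. A self-contained sketch of the pattern with a hypothetical stand-in model (not the project's classes):

```python
import tempfile
from pathlib import Path

import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Stand-in for the subclassed assisted models (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)

src = TinyModel()
_ = src(tf.zeros((1, 16)))                     # build variables before saving
path = Path(tempfile.mkdtemp()) / "tiny.weights.h5"
src.save_weights(str(path))

dst = TinyModel()
_ = dst(tf.zeros((1, 16)))                     # must build BEFORE load_weights
dst.load_weights(str(path))
```

The same two-step build-then-load applies to `LinTrainableAssistedModelA`/`B`: reconstruct them with the constructor arguments used at save time, run one dummy call, then load the `.weights.h5` files printed above.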