PyrTrainableAssistedModelA
Role: Keras model implementing the Player-A pyramid assisted architecture with per-level measurement/combine layers and per-level shared-resource assisted layers, producing a single communication logit per batch element.
Location: `Q_Sea_Battle.pyr_trainable_assisted_model_a.PyrTrainableAssistedModelA`
Derived constraints
- \(n2\) (the flattened field length) is inferred from `game_layout` as either `int(game_layout.n2)` or `int(game_layout.field_size) * int(game_layout.field_size)`, and must be positive and a power of two.
- \(m\) (communication size) is inferred as `int(game_layout.comms_size)` or `int(game_layout.M)`, defaulting to `1`, and must satisfy \(m = 1\).
- The model depth is \(K = \log_2(n2)\) (an integer), and there are exactly \(K\) per-level measurement layers, \(K\) combine layers, and \(K\) shared-resource layers.
- The input batch `field_batch` must be rank-2 with shape \((B, n2)\), where \(B\) is the batch size.
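The derivation above can be sketched in plain Python. `derive_dims` is a hypothetical helper, not part of `Q_Sea_Battle`; the lookup order (prefer `n2` over `field_size`, and `comms_size` over `M`) and the error messages are assumptions made for illustration.

```python
def derive_dims(game_layout):
    """Hypothetical helper mirroring the documented derivation of n2, m and K."""
    # Assumed order: an explicit n2 wins; otherwise square field_size.
    n2 = getattr(game_layout, "n2", None)
    if n2 is None:
        size = int(game_layout.field_size)
        n2 = size * size
    n2 = int(n2)

    # Assumed order: comms_size, then M, then the default of 1.
    m = getattr(game_layout, "comms_size", None)
    if m is None:
        m = getattr(game_layout, "M", 1)
    m = int(m)

    if m != 1:
        raise ValueError(f"m must be 1, got {m}")
    if n2 <= 0:
        raise ValueError(f"n2 must be positive, got {n2}")
    if n2 & (n2 - 1) != 0:
        raise ValueError(f"n2 must be a power of two, got {n2}")

    depth = n2.bit_length() - 1  # exact log2(n2) once n2 is a power of two
    return n2, m, depth


class Layout:
    field_size = 4
    comms_size = 1


print(derive_dims(Layout()))  # (16, 1, 4)
```

A field of size 3 would give \(n2 = 9\), which fails the power-of-two check.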
Constructor
| Parameter | Type | Description |
|---|---|---|
| `game_layout` | `Any`, GameLayout-like object | Must expose either `n2: int` or `field_size: int` (to derive \(n2\)), and may expose `comms_size: int` or `M: int` (to derive \(m\)). Constraints: derived \(n2\) must be positive and a power of two; derived \(m\) must equal 1. Shape: not applicable. |
| `p_high` | `float`, unconstrained | Correlation parameter forwarded to each `PRAssistedLayer`. |
| `sr_mode` | `str`, unconstrained | Mode forwarded to each `PRAssistedLayer` as `mode`. |
| `measure_layers` | `Optional[Sequence[tf.keras.layers.Layer]]`, length \(K\) if provided | Optional per-level measurement layers. Constraint: if not `None`, `len(measure_layers) == K`. |
| `combine_layers` | `Optional[Sequence[tf.keras.layers.Layer]]`, length \(K\) if provided | Optional per-level combine layers. Constraint: if not `None`, `len(combine_layers) == K`. |
| `name` | `Optional[str]`, unconstrained | Optional Keras model name. |
Preconditions
- `game_layout` provides enough attributes to infer \(n2\) and \(m\) as described under Derived constraints.
- Derived \(n2\) is a power of two and greater than 0.
- Derived \(m = 1\).
- If `measure_layers` is provided, its length equals \(K\).
- If `combine_layers` is provided, its length equals \(K\).
Postconditions
- `self.n2: int` is set to the inferred \(n2\).
- `self.M: int` is set to the inferred \(m\).
- `self.depth: int` is set to \(K = \log_2(n2)\).
- `self.measure_layers: List[tf.keras.layers.Layer]` has length \(K\); defaults to `PyrMeasurementLayerA()` repeated \(K\) times if not provided.
- `self.combine_layers: List[tf.keras.layers.Layer]` has length \(K\); defaults to `PyrCombineLayerA()` repeated \(K\) times if not provided.
- `self.measure_layer` aliases `self.measure_layers[0]` and `self.combine_layer` aliases `self.combine_layers[0]` (legacy compatibility).
- `self.sr_layers: List[PRAssistedLayer]` has length \(K\); the layer at level \(\ell\) (0-based) has `resource_index` \(= \ell\) and `length` \(= n2 / 2^{\ell+1}\).
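Tabulating the per-level sizes is a quick sanity check. This sketch uses a hypothetical helper `pyramid_shapes` (not part of the package) and the documented `sr_layers` length \(n2 / 2^{\ell+1}\); the state length \(n2 / 2^{\ell}\) entering each level is an assumption, based on each level halving the state until the final shape \((B, 1)\) is reached.

```python
def pyramid_shapes(n2):
    """Hypothetical helper: (level, state length, sr length) for each level.

    Assumes each level halves the state, so level `level` receives a state of
    length n2 / 2**level; the sr length n2 / 2**(level + 1) is the documented
    `length` of sr_layers[level].
    """
    if n2 <= 0 or n2 & (n2 - 1):
        raise ValueError("n2 must be a positive power of two")
    depth = n2.bit_length() - 1  # K = log2(n2)
    rows = []
    for level in range(depth):
        state_len = n2 >> level       # assumed L = n2 / 2**level
        sr_len = n2 >> (level + 1)    # documented sr_layers[level].length
        rows.append((level, state_len, sr_len))
    return rows


for level, state_len, sr_len in pyramid_shapes(16):
    print(f"level {level}: state length {state_len}, sr length {sr_len}")
```

For \(n2 = 16\) the last level has `length` 1, matching the \((B, 1)\) final state.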
Errors
- Raises `ValueError` if derived \(m \ne 1\).
- Raises `ValueError` if derived \(n2 \le 0\).
- Raises `ValueError` if derived \(n2\) is not a power of two.
- Raises `ValueError` if `measure_layers` is provided and `len(measure_layers) != K`.
- Raises `ValueError` if `combine_layers` is provided and `len(combine_layers) != K`.
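Both length checks on the optional per-level layers reduce to one comparison against \(K\). The sketch below uses a hypothetical helper `check_per_level`, with plain strings standing in for Keras layers; it mirrors only the documented check and does not build the default layers.

```python
from typing import Any, List, Optional, Sequence


def check_per_level(
    layers: Optional[Sequence[Any]], depth: int, name: str
) -> Optional[List[Any]]:
    """Hypothetical helper reproducing the documented length check.

    Returns None when nothing was supplied (the model then uses its defaults);
    otherwise returns a list copy after checking len(layers) == depth.
    """
    if layers is None:
        return None
    layers = list(layers)
    if len(layers) != depth:
        raise ValueError(f"{name} must have length {depth}, got {len(layers)}")
    return layers


measure = check_per_level(["m0", "m1", "m2", "m3"], depth=4, name="measure_layers")
try:
    check_per_level(["c0"], depth=4, name="combine_layers")
except ValueError as err:
    print(err)  # combine_layers must have length 4, got 1
```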
Example
```python
import tensorflow as tf

from Q_Sea_Battle.pyr_trainable_assisted_model_a import PyrTrainableAssistedModelA


class Layout:
    # Minimal GameLayout-like object: n2 = 4 * 4 = 16, so K = 4 levels.
    field_size = 4
    comms_size = 1


model = PyrTrainableAssistedModelA(game_layout=Layout(), p_high=0.9, sr_mode="sample")
x = tf.zeros((2, model.n2), dtype=tf.float32)  # batch of B = 2 flattened fields
logits = model(x)  # communication logits, shape (2, 1)
```
Public Methods
call
- Signature: `call(self, field_batch: tf.Tensor) -> tf.Tensor`
Parameters
- `field_batch`: `tf.Tensor`, dtype not specified, shape \((B, n2)\), rank must be 2.
Returns
- `comm_logits`: `tf.Tensor`, dtype float32, shape \((B, 1)\).
Errors
- Propagates `ValueError` from `compute_with_internal` if `field_batch` is not rank-2.
Notes
- Delegates computation to `compute_with_internal` and returns only the first element of its result (the communication logits).
compute_with_internal
- Signature: `compute_with_internal(self, field_batch: tf.Tensor) -> Tuple[tf.Tensor, List[tf.Tensor], List[tf.Tensor]]`
Parameters
- `field_batch`: `tf.Tensor`, dtype not specified, shape \((B, n2)\), rank must be 2.
Returns
- `comm_logits`: `tf.Tensor`, dtype float32, shape \((B, 1)\); computed as `(clip(state, 0, 1) * 2 - 1) * 10` from the final-level `state`, which is expected to have shape \((B, 1)\).
- `measurements`: `List[tf.Tensor]` of length \(K\); each element is a float32 `tf.Tensor` of shape \((B, L/2)\), where \(L\) is the length of the state entering that level (nominally \(n2 / 2^{\ell}\) at level \(\ell\); the per-level \(L\) is not validated in code).
- `outcomes`: `List[tf.Tensor]` of length \(K\); each element is a float32 `tf.Tensor` whose shape is not specified (it depends on the `PRAssistedLayer` output).
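For reference, the documented logit mapping can be reproduced in NumPy. `state_to_comm_logits`, its `scale` parameter, and the explicit \((B, 1)\) shape check are illustrative assumptions; the model computes the same expression on TensorFlow tensors.

```python
import numpy as np


def state_to_comm_logits(state, scale=10.0):
    """NumPy sketch of the documented mapping (clip(state, 0, 1) * 2 - 1) * 10."""
    state = np.asarray(state, dtype=np.float32)
    # Assumed check: the final state should already be reduced to (B, 1).
    if state.ndim != 2 or state.shape[1] != 1:
        raise ValueError(f"expected final state of shape (B, 1), got {state.shape}")
    # Clip to [0, 1], map affinely onto [-1, 1], then scale to [-scale, scale].
    return (np.clip(state, 0.0, 1.0) * 2.0 - 1.0) * scale


final_state = np.array([[0.0], [1.0], [0.5], [1.7]], dtype=np.float32)
print(state_to_comm_logits(final_state).ravel().tolist())  # [-10.0, 10.0, 0.0, 10.0]
```

Out-of-range states saturate at \(\pm 10\) because of the clip, and a state of exactly 0.5 yields a logit of 0.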
Errors
- Raises `ValueError` if `field_batch` is not rank-2, i.e. `x.shape.rank != 2`.
Data & State
- `n2`: `int`, inferred field length; constraint: positive power of two.
- `M`: `int`, inferred communication size; constraint: equals 1.
- `depth`: `int`, \(K = \log_2(n2)\).
- `measure_layers`: `List[tf.keras.layers.Layer]`, length \(K\); per-level measurement layers.
- `combine_layers`: `List[tf.keras.layers.Layer]`, length \(K\); per-level combine layers.
- `measure_layer`: `tf.keras.layers.Layer`, alias to `measure_layers[0]` (legacy compatibility).
- `combine_layer`: `tf.keras.layers.Layer`, alias to `combine_layers[0]` (legacy compatibility).
- `sr_layers`: `List[PRAssistedLayer]`, length \(K\); per-level shared-resource assisted layers with `resource_index = level`.
Planned (design-spec)
- Not specified.
Deviations
- Not specified.
Notes for Contributors
- Input validation in `compute_with_internal` checks only the tensor rank, not the size of the second dimension; a batch whose second dimension differs from \(n2\) may fail later inside the measurement or combine layers.
- The final conversion from the last-level `state` to `comm_logits` is a fixed mapping: the state is clipped to \([0, 1]\) and mapped affinely onto \([-10, +10]\), giving exactly \(\pm 10\) when the state is 0 or 1. Per the module docstring, this step is intentionally non-trainable.
- The shared-resource layer invocation always passes `first_measurement` as ones and passes zero tensors for the previous measurement and outcome, regardless of level.
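To catch shape mismatches early, a stricter check could compare the static second dimension against \(n2\) in addition to the rank. This is a hypothetical hardening, not current behavior: `validate_field_shape` takes a plain shape list such as the one returned by `x.shape.as_list()`, where `None` marks an unknown (dynamic) dimension.

```python
from typing import Optional, Sequence


def validate_field_shape(shape: Sequence[Optional[int]], n2: int) -> None:
    """Hypothetical stricter check: rank 2 and, when known, second dim == n2."""
    if len(shape) != 2:
        raise ValueError(f"field_batch must be rank-2 (B, n2), got rank {len(shape)}")
    width = shape[1]
    # A None width (dynamic dimension) cannot be checked statically, so accept it.
    if width is not None and width != n2:
        raise ValueError(f"field_batch must have {n2} columns, got {width}")


validate_field_shape([None, 16], n2=16)  # passes: the batch size may be unknown
try:
    validate_field_shape([2, 15], n2=16)
except ValueError as err:
    print(err)  # field_batch must have 16 columns, got 15
```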
Related
- `Q_Sea_Battle.pyr_measurement_layer_a.PyrMeasurementLayerA`
- `Q_Sea_Battle.pyr_combine_layer_a.PyrCombineLayerA`
- `Q_Sea_Battle.pr_assisted_layer.PRAssistedLayer`
Changelog
- Not specified.