|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.nn as nn |
|
|
|
|
|
# Text glyph per piece code: 0 empty, 1 white man, 2 white king, 3 black man, 4 black king.
PIECE_SYMBOLS = dict(enumerate('.wWbB'))
|
|
|
|
|
def _dark(sq): return (sq // 8 + sq % 8) % 2 == 1 |
|
|
def _promotes(sq, c): return (c==1 and sq//8==0) or (c==-1 and sq//8==7) |
|
|
|
|
|
class Move:
    """One move: origin square, destination square and the squares it captures.

    Squares are flat indices ``row * 8 + col``.
    """

    __slots__ = ('from_sq', 'to_sq', 'captures', 'is_king_move')

    def __init__(self, f, t, cap=None, king=False, *, is_king_move=False):
        """Create a move.

        Args:
            f: origin square.
            t: destination square.
            cap: squares of the pieces captured by this move; ``None`` or empty
                means a quiet move.
            king: whether the moving piece is a king.
            is_king_move: keyword alias for ``king``, so callers may use the
                attribute's own name (``Move(a, b, is_king_move=True)``).
                Without it, that spelling raised TypeError.
        """
        self.from_sq = f
        self.to_sq = t
        self.captures = cap or []
        self.is_king_move = bool(king or is_king_move)

    def __repr__(self):
        return f'Move({self.from_sq}→{self.to_sq}, cap={self.captures})'
|
|
|
|
|
class Board:
    """An 8x8 draughts position with flying kings.

    ``pieces`` is a flat int8 array of 64 squares indexed ``row * 8 + col``;
    only dark squares (``(row + col) % 2 == 1``) are ever occupied.
    Piece codes: 0 empty, 1 white man, 2 white king, 3 black man, 4 black king.
    ``turn`` is 1 when white is to move and -1 when black is to move.
    White men advance toward row 0, black men toward row 7.
    """

    # Diagonal steps on the flat index. A step that wraps across the left/right
    # edge always lands on a light square, so the _dark() test doubles as the
    # edge check.
    _DIRS = (-9, -7, 7, 9)

    def __init__(self):
        self.pieces = np.zeros(64, dtype=np.int8)
        self.turn = 1
        self.reset()

    def reset(self):
        """Set up the 12-vs-12 starting position with white to move."""
        self.pieces[:] = 0
        # Black men on the dark squares of rows 0-2, white men on rows 5-7.
        for sq in [1, 3, 5, 7, 8, 10, 12, 14, 17, 19, 21, 23]:
            self.pieces[sq] = 3
        for sq in [40, 42, 44, 46, 49, 51, 53, 55, 56, 58, 60, 62]:
            self.pieces[sq] = 1
        self.turn = 1

    def copy(self):
        """Return an independent copy of this position."""
        # Skip __init__: it would allocate and reset a board we overwrite at once.
        b = Board.__new__(Board)
        b.pieces = self.pieces.copy()
        b.turn = self.turn
        return b

    def _dark(self, sq):
        """Return True if ``sq`` is a dark (playable) square."""
        return (sq // 8 + sq % 8) % 2 == 1

    @staticmethod
    def _enemy_codes(color):
        """Piece codes belonging to the side opposing ``color`` (1 white, -1 black)."""
        return (3, 4) if color == 1 else (1, 2)

    def _own_squares(self, color):
        """Ascending list of squares holding a piece of ``color`` (1 white, -1 black)."""
        mine = (1, 2) if color == 1 else (3, 4)
        return [sq for sq in range(64) if self.pieces[sq] in mine]

    def _man_captures(self, sq, color, captured=None):
        """Return the complete capture sequences for a man jumping from ``sq``.

        Args:
            sq: square the man currently jumps from.
            color: side to move (1 white, -1 black).
            captured: squares already taken earlier in this sequence.

        Returns:
            List of ``(landing_square, captured_squares_set)``. Only sequences
            that cannot be extended are returned, because a capturing piece
            must keep jumping while it can. Men capture in all four directions.
            Captured pieces stay on the board until the move is made, so they
            block and cannot be jumped twice.
            A man reaching the far row mid-sequence keeps capturing as a man.
        """
        captured = captured or set()
        enemy = self._enemy_codes(color)
        res = []
        for d in self._DIRS:
            mid = sq + d
            dst = mid + d
            if not (0 <= mid < 64 and 0 <= dst < 64 and self._dark(mid) and self._dark(dst)):
                continue
            if mid in captured or self.pieces[mid] not in enemy or self.pieces[dst] != 0:
                continue
            new_cap = captured | {mid}
            longer = self._man_captures(dst, color, new_cap)
            # Stop here only when no further jump is possible from dst.
            res.extend(longer if longer else [(dst, new_cap)])
        return res

    def _king_captures(self, sq, color, captured=None):
        """Return the complete capture sequences for a flying king on ``sq``.

        In each direction the king may jump exactly one enemy piece at any
        distance and land on any empty square beyond it. If the capture can
        continue from some of those landing squares, only the continuations are
        legal; otherwise every landing square ends a sequence.

        Returns:
            List of ``(landing_square, captured_squares_set)``.
        """
        captured = captured or set()
        enemy = self._enemy_codes(color)
        res = []
        for d in self._DIRS:
            target = None   # the single enemy piece jumped along this diagonal
            landings = []   # empty squares beyond it
            cur = sq + d
            while 0 <= cur < 64 and self._dark(cur):
                p = self.pieces[cur]
                if p == 0:
                    if target is not None:
                        landings.append(cur)
                elif target is None and cur not in captured and p in enemy:
                    target = cur
                else:
                    # Own piece, an already-captured piece, or a second piece in line.
                    break
                cur += d
            if target is None:
                continue
            new_cap = captured | {target}
            continued = []
            for land in landings:
                continued.extend(self._king_captures(land, color, new_cap))
            if continued:
                res.extend(continued)
            else:
                res.extend((land, new_cap) for land in landings)
        return res

    def _captures(self):
        """All capture moves for the side to move."""
        color = self.turn
        moves = []
        for sq in self._own_squares(color):
            p = self.pieces[sq]
            is_king = p in (2, 4)
            # Lift the piece so a capture chain may pass through or end on its origin.
            self.pieces[sq] = 0
            try:
                if is_king:
                    caps = self._king_captures(sq, color)
                else:
                    caps = self._man_captures(sq, color)
            finally:
                self.pieces[sq] = p
            moves.extend(Move(sq, to, list(cap), king=is_king) for to, cap in caps)
        return moves

    def _quiet(self):
        """All non-capturing moves for the side to move."""
        color = self.turn
        moves = []
        for sq in self._own_squares(color):
            if self.pieces[sq] in (1, 3):
                # Men step one square diagonally forward.
                for d in ((-9, -7) if color == 1 else (9, 7)):
                    dst = sq + d
                    if 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst))
            else:
                # Kings slide any distance along empty diagonals.
                for d in self._DIRS:
                    dst = sq + d
                    while 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst, king=True))
                        dst += d
        return moves

    def legal_moves(self):
        """Legal moves for the side to move; captures are compulsory when available."""
        caps = self._captures()
        return caps if caps else self._quiet()

    def make_move(self, move):
        """Apply ``move`` in place, promote a man that ends on the far row, and pass the turn."""
        p = self.pieces[move.from_sq]
        self.pieces[move.from_sq] = 0
        row = move.to_sq // 8
        # Decide promotion from the piece code itself, so a king can never be
        # bumped to the next code (which would be the opponent's man).
        if (p == 1 and row == 0) or (p == 3 and row == 7):
            p += 1
        self.pieces[move.to_sq] = p
        for cap_sq in move.captures:
            self.pieces[cap_sq] = 0
        self.turn = -self.turn

    def is_terminal(self):
        """Return ``(True, winner)`` if the side to move has no legal move, else ``(False, 0)``.

        A side with no legal move loses, so ``winner`` is ``-self.turn``.
        Draws are not detected here.
        """
        legal = self.legal_moves()
        if not legal:
            return True, -self.turn
        return False, 0

    def __str__(self):
        """Render the board as 8 text rows; light squares are shown blank."""
        rows = []
        for r in range(8):
            row = [PIECE_SYMBOLS[self.pieces[r * 8 + c]] if self._dark(r * 8 + c) else ' ' for c in range(8)]
            rows.append(" ".join(row))
        return "\n".join(rows)
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity skip connection.

    The channel count ``ch`` is preserved, and padding=1 keeps the spatial size.
    """

    def __init__(self, ch=64):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Add the skip before the final activation.
        return F.relu(hidden + x)
|
|
|
|
|
class ChekaNet(nn.Module):
    """Residual policy/value network.

    Input: ``(batch, 5, H, W)`` planes. ``H * W`` must equal 64, because the
    value head's first Linear expects 64 features from a one-channel map.
    Output: ``(policy_logits, value)`` with shapes ``(batch, H * W)`` and
    ``(batch, 1)``; ``value`` lies in [-1, 1] because of the final Tanh.
    """

    def __init__(self, blocks=3, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(5, channels, kernel_size=3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        tower = [ResidualBlock(channels) for _ in range(blocks)]
        self.residuals = nn.Sequential(*tower)
        self.policy_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        trunk = self.residuals(F.relu(self.bn_in(self.conv_in(x))))
        # One logit per square: (batch, 1, H, W) -> (batch, H * W).
        policy_logits = self.policy_conv(trunk).flatten(start_dim=1)
        value = self.value_fc(self.value_conv(trunk))
        return policy_logits, value
|
|
|
|
|
def board_to_tensor(b):
    """Encode a board as a ``(5, 8, 8)`` float32 tensor.

    Planes 0-3 are one-hot masks for piece codes 1-4 (square ``sq`` maps to
    row ``sq // 8``, column ``sq % 8``). Plane 4 is all ones when
    ``b.turn == 1`` and all zeros otherwise.
    """
    grid = np.asarray(b.pieces).reshape(8, 8)
    planes = np.zeros((5, 8, 8), np.float32)
    for code in range(1, 5):
        planes[code - 1][grid == code] = 1.0
    planes[4].fill(1.0 if b.turn == 1 else 0.0)
    return torch.from_numpy(planes)
|
|
|
|
|
import math, random |
|
|
|
|
|
class MCTSNode:
    """Search-tree node: a private copy of a position plus PUCT statistics.

    Attributes:
        board: copy of the position this node represents.
        parent: parent node, or None at the root.
        P: prior probability of reaching this node from its parent.
        N: visit count.
        W: accumulated value.
        children: mapping from move to child node; empty until expanded.
    """

    def __init__(self, board, parent=None, prior=0):
        self.board = board.copy()
        self.parent = parent
        self.P = prior
        self.N = 0
        self.W = 0.0
        self.children = {}

    def Q(self):
        """Mean value W / N; the tiny epsilon makes an unvisited node score 0."""
        return self.W / (self.N + 1e-8)

    def U(self, c_puct=1.0):
        """PUCT exploration bonus; requires a parent."""
        return c_puct * self.P * math.sqrt(self.parent.N) / (1 + self.N)

    def is_leaf(self):
        """True while the node has not been expanded."""
        return not self.children
|
|
|
|
|
def expand_leaf(node, net, device):
    """Evaluate ``node`` and, unless the side to move has no moves, attach its children.

    Args:
        node: unexpanded MCTSNode.
        net: model returning ``(policy_logits, value)``; the policy logits are
            indexed by destination square.
        device: torch device the encoded board is moved to.

    Returns:
        If the side to move has no legal move (it has lost): -1 when white
        (``turn == 1``) is to move, else +1, i.e. from white's perspective.
        Otherwise the network's scalar value prediction.
        NOTE(review): this function does not fix the perspective of the
        network's value; it must match the terminal convention above. Confirm
        against the training code.
    """
    board = node.board
    legal = board.legal_moves()
    if not legal:
        return -1 if board.turn == 1 else 1
    tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, v = net(tensor)
    logits = logits[0].cpu().numpy()
    v = v.item()

    # Softmax restricted to the destination squares reachable by a legal move.
    mask = np.full(64, -np.inf)
    for m in legal:
        mask[m.to_sq] = logits[m.to_sq]
    probs = torch.softmax(torch.tensor(mask), dim=0).numpy()

    # The policy scores destinations only, so moves sharing a destination split
    # that square's probability; otherwise the priors would sum to more than 1.
    share = np.bincount([m.to_sq for m in legal], minlength=64)
    for m in legal:
        # MCTSNode copies the board itself, so make_move only touches the child's copy.
        child = MCTSNode(board, parent=node, prior=probs[m.to_sq] / share[m.to_sq])
        child.board.make_move(m)
        node.children[m] = child
    return v
|
|
|
|
|
def backup(node, v):
    """Propagate a leaf evaluation from ``node`` up to the root.

    Args:
        node: the evaluated node. Every node on the path needs ``N``, ``W``,
            ``parent`` and ``board.turn`` (1 white to move, -1 black to move).
        v: value from white's perspective (+1 good for white), the convention
            used by expand_leaf's terminal branch.

    Each node's ``W`` accumulates value for the player who moved INTO it, which
    is the opposite of ``node.board.turn``. The parent selects the child with
    the highest Q, so Q must be expressed from the parent's mover's view.
    Deriving the sign from ``turn`` replaces the old blind alternation, which
    started with the wrong sign whenever black was to move at the leaf.
    NOTE(review): assumes the value head is also trained on white-perspective
    targets. Confirm against the training code.
    """
    while node is not None:
        node.N += 1
        node.W += -node.board.turn * v
        node = node.parent
|
|
|
|
|
def select_move(board, net, device, sims=400, c_puct=1.0, temp=0.0):
    """Search ``board`` with PUCT MCTS and choose a move from the root visit counts.

    Args:
        board: position to move from. The root node keeps its own copy.
        net, device: forwarded to ``expand_leaf``.
        sims: number of select -> expand/evaluate -> backup simulations.
        c_puct: exploration constant forwarded to ``MCTSNode.U``.
        temp: 0 picks the most-visited move (the first one on ties). A value
            > 0 samples a move with probability proportional to ``N ** (1 / temp)``.

    Returns:
        ``(move, root)``: the chosen move and the root of the search tree.

    Raises:
        ValueError: if the root has no children (no legal moves, or ``sims < 1``).
    """
    root = MCTSNode(board)
    for _ in range(sims):
        node = root
        # Descend along the child maximising Q + U until an unexpanded node is reached.
        while not node.is_leaf():
            node = max(node.children.values(), key=lambda n: n.Q() + n.U(c_puct))
        value = expand_leaf(node, net, device)
        backup(node, value)

    if not root.children:
        raise ValueError('select_move: root has no children to choose from')

    moves = list(root.children)
    visits = np.array([root.children[m].N for m in moves], dtype=np.float64)
    if temp == 0:
        return moves[int(np.argmax(visits))], root

    peak = visits.max()
    if peak > 0:
        # Divide by the largest count first: raw N ** (1/temp) overflows for a
        # small temp, and the common scale cancels out after normalisation.
        weights = (visits / peak) ** (1.0 / temp)
    else:
        # No child has been visited yet (e.g. sims == 1): sample uniformly.
        weights = np.ones_like(visits)
    weights /= weights.sum()
    move = random.choices(moves, weights)[0]
    return move, root
|
|
|