# SnowFlash383935's picture
# Update engine.py
# fdd8467 verified
# =======================
# engine.py in a single cell
# =======================
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
# Character printed for each piece code stored in a board's `pieces` array.
PIECE_SYMBOLS = {
    0: '.',  # empty square
    1: 'w',  # white man
    2: 'W',  # white king
    3: 'b',  # black man
    4: 'B',  # black king
}
def _dark(sq):
    """Return True when square index `sq` (row-major, 8 per row) is dark.

    A square counts as dark when the sum of its row and column is odd.
    """
    row, col = divmod(sq, 8)
    return (row + col) % 2 == 1
def _promotes(sq, c):
    """Return True if square `sq` lies on the promotion row for colour `c`.

    Colour 1 promotes on row 0 and colour -1 on row 7; any other colour
    value never promotes.
    """
    row = sq // 8
    if c == 1:
        return row == 0
    if c == -1:
        return row == 7
    return False
class Move:
    """One move: origin square, destination square and captured squares.

    Attributes:
        from_sq: index of the square the piece starts on.
        to_sq: index of the square the piece ends on.
        captures: list of square indices whose pieces the move removes.
        is_king_move: True when the move was generated for a king.
    """

    __slots__ = ('from_sq', 'to_sq', 'captures', 'is_king_move')

    def __init__(self, f, t, cap=None, king=False, is_king_move=None):
        """Create a move from square `f` to square `t`.

        Args:
            f: origin square index.
            t: destination square index.
            cap: iterable of captured square indices; None or empty means
                a quiet move.
            king: whether the moving piece is a king.
            is_king_move: keyword alias for `king`, matching the attribute
                name. When given (not None) it takes precedence over `king`.
        """
        self.from_sq = f
        self.to_sq = t
        self.captures = cap or []
        # Callers in this file pass `is_king_move=...`; accept both spellings.
        self.is_king_move = king if is_king_move is None else is_king_move

    def __repr__(self):
        # The separator keeps e.g. 12->3 and 1->23 distinguishable.
        return f'Move({self.from_sq}->{self.to_sq}, cap={self.captures})'
class Board:
    """8x8 draughts position stored as a flat array of 64 piece codes.

    Squares are indexed row-major (``sq = row * 8 + col``) and only dark
    squares (row + col odd) are used. Piece codes: 0 empty, 1 white man,
    2 white king, 3 black man, 4 black king. ``turn`` is 1 when white is to
    move and -1 when black is to move.
    """

    def __init__(self):
        self.pieces = np.zeros(64, dtype=np.int8)
        self.turn = 1  # 1 = white, -1 = black
        self.reset()

    # ---------------------------------
    def reset(self):
        """Restore the starting position with white to move."""
        self.pieces[:] = 0
        for sq in [1, 3, 5, 7, 8, 10, 12, 14, 17, 19, 21, 23]:
            self.pieces[sq] = 3  # black men
        for sq in [40, 42, 44, 46, 49, 51, 53, 55, 56, 58, 60, 62]:
            self.pieces[sq] = 1  # white men
        self.turn = 1

    def copy(self):
        """Return an independent copy of this position."""
        # Skip __init__/reset: building the start position only to overwrite
        # it is wasted work, and copy() is called for every search node.
        b = Board.__new__(Board)
        b.pieces = self.pieces.copy()
        b.turn = self.turn
        return b

    # ---------------------------------
    def _dark(self, sq):
        """Return True when `sq` is a dark (playable) square."""
        return (sq // 8 + sq % 8) % 2 == 1

    @staticmethod
    def _enemy_codes(color):
        """Return the (man, king) piece codes of the opponent of `color`."""
        return (3, 4) if color == 1 else (1, 2)

    # ---------------------------------
    def _man_captures(self, sq, color, captured=None, origin=None):
        """List the capture sequences available to a man standing on `sq`.

        Args:
            sq: square the capturing man currently stands on.
            color: side of the capturing man (1 white, -1 black).
            captured: set of squares already jumped earlier in the sequence.
            origin: square the man started the whole sequence from; defaults
                to `sq` on the outermost call. It is treated as empty because
                the moving piece has left it.

        Returns:
            A list of ``(landing_square, captured_squares)`` tuples. Every
            landing along a multi-jump is reported, so shorter prefixes of a
            longer sequence appear as separate results.
            NOTE(review): many draughts rule sets require a capture sequence
            to be completed — confirm that offering the prefixes is intended.
        """
        captured = captured or set()
        if origin is None:
            origin = sq
        enemy = self._enemy_codes(color)
        res = []
        for d in (-9, -7, 7, 9):
            mid = sq + d
            dst = mid + d
            if not (0 <= mid < 64 and 0 <= dst < 64):
                continue
            # A diagonal step that wraps around a board edge lands on a light
            # square, so the dark-square test doubles as the edge check.
            if not (self._dark(mid) and self._dark(dst)):
                continue
            if mid in captured or self.pieces[mid] not in enemy:
                continue
            if self.pieces[dst] != 0 and dst != origin:
                continue
            new_cap = captured | {mid}
            res.append((dst, new_cap))
            res.extend(self._man_captures(dst, color, new_cap, origin))
        return res

    def _king_captures(self, sq, color, captured=None, origin=None):
        """List the capture sequences available to a king standing on `sq`.

        The king slides along each diagonal, may jump exactly one enemy piece
        that has not been captured yet, and may land on any empty square
        behind it up to the next piece or the board edge.

        Args:
            sq: square the king currently stands on.
            color: side of the king (1 white, -1 black).
            captured: set of squares already jumped earlier in the sequence.
            origin: square the king started the whole sequence from; defaults
                to `sq` on the outermost call and is treated as empty.

        Returns:
            A list of ``(landing_square, captured_squares)`` tuples, including
            every intermediate landing of longer sequences.
        """
        captured = captured or set()
        if origin is None:
            origin = sq
        enemy = self._enemy_codes(color)
        res = []
        for d in (-9, -7, 7, 9):
            first = None  # enemy piece being jumped along this ray
            step = 1
            while True:
                mid = sq + d * step
                if not (0 <= mid < 64 and self._dark(mid)):
                    break
                # The moving king has left its starting square.
                piece_mid = 0 if mid == origin else self.pieces[mid]
                if piece_mid != 0:
                    if first is not None:
                        break  # a second piece behind the jumped one blocks the ray
                    if mid in captured or piece_mid not in enemy:
                        break  # own piece, or a piece already taken in this sequence
                    first = mid
                elif first is not None:
                    new_cap = captured | {first}
                    res.append((mid, new_cap))
                    res.extend(self._king_captures(mid, color, new_cap, origin))
                step += 1
        return res

    def _captures(self):
        """Return every capturing Move available to the side to move."""
        color = self.turn
        own_man, own_king = (1, 2) if color == 1 else (3, 4)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p == own_man:
                for to, cap in self._man_captures(sq, color):
                    moves.append(Move(sq, to, list(cap)))
            elif p == own_king:
                for to, cap in self._king_captures(sq, color):
                    moves.append(Move(sq, to, list(cap), king=True))
        return moves

    def _quiet(self):
        """Return every non-capturing Move available to the side to move."""
        color = self.turn
        own_man, own_king = (1, 2) if color == 1 else (3, 4)
        # Men only step forward: towards row 0 for white, row 7 for black.
        man_dirs = (-9, -7) if color == 1 else (9, 7)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p == own_man:
                for d in man_dirs:
                    dst = sq + d
                    if 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst))
            elif p == own_king:
                for d in (-9, -7, 7, 9):
                    step = 1
                    while True:
                        dst = sq + d * step
                        if not (0 <= dst < 64 and self._dark(dst)):
                            break
                        if self.pieces[dst] != 0:
                            break
                        moves.append(Move(sq, dst, king=True))
                        step += 1
        return moves

    # ---------------------------------
    def legal_moves(self):
        """Return the legal moves; captures, when any exist, are mandatory."""
        caps = self._captures()
        return caps if caps else self._quiet()

    # ---------------------------------
    def make_move(self, move):
        """Apply `move` in place and pass the turn to the other side.

        A man that ends on the far row for the side to move (row 0 for
        white, row 7 for black) is promoted to a king.
        """
        p = self.pieces[move.from_sq]
        self.pieces[move.from_sq] = 0
        last_row = 0 if self.turn == 1 else 7
        # Decide promotion from the piece code itself so that a king can
        # never be "promoted" into another piece code by a mislabelled move.
        if p in (1, 3) and move.to_sq // 8 == last_row:
            p += 1  # man -> king
        self.pieces[move.to_sq] = p
        for cap_sq in move.captures:
            self.pieces[cap_sq] = 0
        self.turn = -self.turn

    # ---------------------------------
    def is_terminal(self):
        """Return ``(finished, winner)`` for the current position.

        The game is finished when the side to move has no legal move; the
        winner is then the opponent (1 white, -1 black). Otherwise the result
        is ``(False, 0)``.
        """
        if not self.legal_moves():
            return True, -self.turn  # the opponent wins
        # A 15-move rule could be added; for now only "no legal moves" ends the game.
        return False, 0

    # ---------------------------------
    def __str__(self):
        """Render the board as 8 text rows using PIECE_SYMBOLS."""
        rows = []
        for r in range(8):
            row = [
                PIECE_SYMBOLS[int(self.pieces[r * 8 + c])] if self._dark(r * 8 + c) else ' '
                for c in range(8)
            ]
            rows.append(" ".join(row))
        return "\n".join(rows)
# ---------------------------------
# ResNet + board encoding
class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm layers with an identity skip connection.

    The channel count and spatial size of the input are preserved.
    """

    def __init__(self, ch=64):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        """Return ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``."""
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(x + branch)
class ChekaNet(nn.Module):
    """Small residual network with a 64-way policy head and a scalar value head.

    The input has 5 channels; the value head's first linear layer takes 64
    features, which fixes the spatial size at 8x8.
    """

    def __init__(self, blocks=3, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(5, channels, kernel_size=3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        tower = [ResidualBlock(channels) for _ in range(blocks)]
        self.residuals = nn.Sequential(*tower)
        self.policy_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards.

        `policy_logits` has one logit per square (batch, 64); `value` is the
        tanh-squashed output of the value head (batch, 1).
        """
        features = F.relu(self.bn_in(self.conv_in(x)))
        features = self.residuals(features)
        batch = features.size(0)
        policy_logits = self.policy_conv(features).squeeze(1).view(batch, -1)
        value_plane = self.value_conv(features).squeeze(1)
        value = self.value_fc(value_plane)
        return policy_logits, value
def board_to_tensor(b):
    """Encode board `b` as a float32 tensor of shape (5, 8, 8).

    Planes 0-3 are one-hot maps for piece codes 1-4 of ``b.pieces``;
    plane 4 is filled with 1.0 when ``b.turn == 1`` and 0.0 otherwise.
    """
    planes = np.zeros((5, 8, 8), np.float32)
    plane_of = {1: 0, 2: 1, 3: 2, 4: 3}  # piece code -> plane index
    for sq in range(64):
        plane = plane_of.get(int(b.pieces[sq]))
        if plane is None:
            continue  # empty square
        row, col = divmod(sq, 8)
        planes[plane, row, col] = 1
    planes[4] = 1.0 if b.turn == 1 else 0.0
    return torch.from_numpy(planes)
# ---------- MCTS + move selection ----------
import math, random
class MCTSNode:
    """Search-tree node holding a private board copy and visit statistics.

    Attributes:
        board: copy of the position this node represents.
        parent: parent node, or None for the root.
        P: prior probability assigned to the move leading here.
        N: visit count.
        W: sum of the values backed up through this node.
        children: dict mapping a move to the resulting child node.
    """

    def __init__(self, board, parent=None, prior=0):
        self.board = board.copy()
        self.parent = parent
        self.P = prior
        self.N = 0
        self.W = 0.0
        self.children = {}

    def Q(self):
        """Mean backed-up value; the small epsilon avoids division by zero."""
        return self.W / (self.N + 1e-8)

    def U(self, c_puct=1.0):
        """Exploration bonus computed from the prior and the parent's visits."""
        numerator = c_puct * self.P * math.sqrt(self.parent.N)
        return numerator / (1 + self.N)

    def is_leaf(self):
        """Return True while the node has no children."""
        return not self.children
def expand_leaf(node, net, device):
    """Expand `node` with one child per legal move and return its value.

    Args:
        node: leaf MCTSNode to expand; its ``children`` dict is filled in place.
        net: network called on a (1, 5, 8, 8) tensor, returning
            ``(policy_logits, value)``.
        device: device the input tensor is moved to before calling `net`.

    Returns:
        When the side to move has no legal move: -1 if that side is white
        (``turn == 1``), otherwise 1. In all other cases the scalar value
        produced by `net`.
    """
    board = node.board
    legal = board.legal_moves()
    if not legal:
        return -1 if board.turn == 1 else 1

    tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, v = net(tensor)
    logits = logits[0].cpu().numpy()
    v = v.item()

    # Softmax over destination squares of legal moves only; every other
    # square stays at -inf and therefore gets probability 0.
    mask = np.full(64, -np.inf)
    for m in legal:
        mask[m.to_sq] = logits[m.to_sq]
    probs = torch.softmax(torch.tensor(mask), dim=0).numpy()

    for m in legal:
        # MCTSNode.__init__ already copies the board it is given, so passing
        # `board` directly avoids a second deep copy per child.
        child = MCTSNode(board, parent=node, prior=probs[m.to_sq])
        child.board.make_move(m)
        node.children[m] = child
    return v
def backup(node, v):
    """Propagate value `v` from `node` up to the root.

    Each node on the path gets one more visit and has the value added to
    its total; the sign of the value flips at every step towards the root.
    """
    current, value = node, v
    while current:
        current.N += 1
        current.W += value
        value = -value
        current = current.parent
def select_move(board, net, device, sims=400, c_puct=1.0, temp=0.0):
    """Run MCTS from `board` and pick a move from the root visit counts.

    Args:
        board: position to search from (it is copied, not modified).
        net: policy/value network passed to expand_leaf.
        device: device used for network evaluation.
        sims: number of simulations to run.
        c_puct: exploration constant used during selection.
        temp: 0 picks the most visited move; any other value samples a move
            with probability proportional to ``visits ** (1 / temp)``.

    Returns:
        ``(move, root)``. `move` is None when the root ends up without
        children (no legal moves, or ``sims == 0``).
    """
    root = MCTSNode(board)
    for _ in range(sims):
        node = root
        while not node.is_leaf():
            node = max(node.children.values(), key=lambda n: n.Q() + n.U(c_puct))
        v = expand_leaf(node, net, device)
        # expand_leaf reports a finished game in White's frame (-1 when White
        # is stuck, +1 when Black is), while backup() alternates the sign per
        # ply, which only makes sense for a value seen by the player who just
        # moved into `node`. Convert before backing up; otherwise a position
        # where White has no moves is scored as bad for Black, who just won.
        # NOTE(review): this assumes the network's value output uses the same
        # White-positive convention as the terminal value — confirm against
        # the training code.
        backup(node, -node.board.turn * v)

    if root.is_leaf():
        # No legal moves (or no simulations ran): there is nothing to choose.
        return None, root

    moves = list(root.children)
    visit_counts = np.array([root.children[m].N for m in moves], dtype=np.float64)
    if temp == 0:
        # argmax returns the first maximum, like max() over the visit list.
        return moves[int(np.argmax(visit_counts))], root

    weights = visit_counts ** (1 / temp)
    total = weights.sum()
    if not np.isfinite(total) or total <= 0:
        # All-zero (or overflowing) counts cannot be normalised; sample uniformly.
        weights = np.ones(len(moves))
    else:
        weights = weights / total
    move = random.choices(moves, weights.tolist())[0]
    return move, root