# =======================
# engine.py in a single cell
# =======================
# Draughts engine (8x8 board), a small ResNet policy/value network and an
# MCTS move selector.
#
# Board layout: squares are indexed 0..63 row-major (sq = row * 8 + col),
# only dark squares ((row + col) odd) are playable.  White (turn == 1) starts
# on rows 5..7 and moves towards row 0; black (turn == -1) starts on rows 0..2
# and moves towards row 7.
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Piece code stored in Board.pieces -> character printed by Board.__str__.
# 0 empty, 1 white man, 2 white king, 3 black man, 4 black king.
PIECE_SYMBOLS = {0: '.', 1: 'w', 2: 'W', 3: 'b', 4: 'B'}

# Index offsets of the four diagonal neighbours of a square.
_DIAGONALS = (-9, -7, 7, 9)


def _dark(sq):
    """Return True when square index ``sq`` lies on a dark (playable) square."""
    return (sq // 8 + sq % 8) % 2 == 1


def _promotes(sq, c):
    """Return True when ``sq`` is on the promotion row of colour ``c`` (1 white, -1 black)."""
    return (c == 1 and sq // 8 == 0) or (c == -1 and sq // 8 == 7)


def _own_codes(color):
    """Return the ``(man, king)`` piece codes of ``color`` (1 white, -1 black)."""
    return (1, 2) if color == 1 else (3, 4)


def _enemy_codes(color):
    """Return the ``(man, king)`` piece codes of the opponent of ``color``."""
    # White (1) fights codes (3, 4); black (-1) fights codes (1, 2).
    return (2 + color, 3 + color)


class Move:
    """A single turn: origin square, destination square and captured squares.

    Attributes:
        from_sq: index of the square the piece leaves.
        to_sq: index of the square the piece lands on.
        captures: list of square indices whose pieces are removed.
        is_king_move: True when the moving piece is a king.
    """

    __slots__ = ('from_sq', 'to_sq', 'captures', 'is_king_move')

    def __init__(self, f, t, cap=None, king=False):
        self.from_sq, self.to_sq, self.captures, self.is_king_move = f, t, cap or [], king

    def __repr__(self):
        return f'Move({self.from_sq}→{self.to_sq}, cap={self.captures})'


class Board:
    """Mutable draughts position: 64 piece codes plus the side to move."""

    def __init__(self):
        self.pieces = np.zeros(64, dtype=np.int8)
        self.turn = 1  # 1 = white, -1 = black
        self.reset()

    # ---------------------------------
    def reset(self):
        """Restore the starting position with white to move."""
        self.pieces[:] = 0
        for sq in [1, 3, 5, 7, 8, 10, 12, 14, 17, 19, 21, 23]:
            self.pieces[sq] = 3  # black men
        for sq in [40, 42, 44, 46, 49, 51, 53, 55, 56, 58, 60, 62]:
            self.pieces[sq] = 1  # white men
        self.turn = 1

    def copy(self):
        """Return an independent copy of this position."""
        # Bypass __init__/reset: both fields are overwritten immediately.
        b = Board.__new__(Board)
        b.pieces = self.pieces.copy()
        b.turn = self.turn
        return b

    # ---------------------------------
    def _dark(self, sq):
        """Return True when ``sq`` is a dark square (see module-level ``_dark``)."""
        return _dark(sq)

    # ---------------------------------
    def _man_captures(self, sq, color, captured=None):
        """Enumerate capture sequences of a man of ``color`` standing on ``sq``.

        Men may jump in all four diagonal directions.  The dark-square test on
        both the jumped and the landing square also rejects index arithmetic
        that wrapped around a board edge.

        Args:
            sq: current square of the capturing man.
            color: 1 for white, -1 for black.
            captured: set of squares already jumped in this sequence.

        Returns:
            List of ``(landing_square, captured_set)`` tuples.  Every prefix of
            a multi-jump is included, not only the completed sequence.
        """
        # NOTE(review): partial sequences are returned as legal moves and a man
        # reaching the promotion row mid-capture keeps jumping as a man; confirm
        # this matches the intended rule set.
        captured = captured or set()
        enemy = _enemy_codes(color)
        res = []
        for d in _DIAGONALS:
            mid = sq + d
            dst = mid + d
            if not (0 <= mid < 64 and 0 <= dst < 64):
                continue
            if not (self._dark(mid) and self._dark(dst)):
                continue
            if mid in captured or self.pieces[mid] not in enemy:
                continue
            if self.pieces[dst] != 0:
                continue
            new_cap = captured | {mid}
            res.append((dst, new_cap))
            res.extend(self._man_captures(dst, color, new_cap))
        return res

    def _king_captures(self, sq, color, captured=None):
        """Enumerate capture sequences of a king of ``color`` standing on ``sq``.

        The king slides along each diagonal, may jump exactly one enemy piece
        that has not been jumped yet, and may land on any empty square behind
        it.  Already-jumped pieces stay on the board until the move is made
        and therefore block the ray.

        Args:
            sq: current square of the capturing king.
            color: 1 for white, -1 for black.
            captured: set of squares already jumped in this sequence.

        Returns:
            List of ``(landing_square, captured_set)`` tuples, prefixes included.
        """
        captured = captured or set()
        enemy = _enemy_codes(color)
        res = []
        for d in _DIAGONALS:
            victim = None
            step = 1
            while True:
                cur = sq + d * step
                if not (0 <= cur < 64 and self._dark(cur)):
                    break
                if self.pieces[cur] != 0:
                    if victim is not None:
                        break  # a second piece behind the victim blocks the jump
                    if cur in captured or self.pieces[cur] not in enemy:
                        break  # own piece, or a piece already taken this turn
                    victim = cur
                elif victim is not None:
                    new_cap = captured | {victim}
                    res.append((cur, new_cap))
                    res.extend(self._king_captures(cur, color, new_cap))
                step += 1
        return res

    def _captures(self):
        """Return every capturing ``Move`` available to the side to move."""
        color = self.turn
        own_man, own_king = _own_codes(color)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p != own_man and p != own_king:
                continue
            is_king = bool(p == own_king)
            generate = self._king_captures if is_king else self._man_captures
            # Lift the moving piece so it does not block its own multi-jump
            # path; always put it back, even if generation raises.
            self.pieces[sq] = 0
            try:
                caps = generate(sq, color)
            finally:
                self.pieces[sq] = p
            moves.extend(Move(sq, to, sorted(cap), king=is_king) for to, cap in caps)
        return moves

    def _quiet(self):
        """Return every non-capturing ``Move`` available to the side to move."""
        color = self.turn
        own_man, own_king = _own_codes(color)
        forward = (-9, -7) if color == 1 else (9, 7)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p == own_man:
                for d in forward:
                    dst = sq + d
                    if 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst))
            elif p == own_king:
                for d in _DIAGONALS:
                    dst = sq + d
                    # Slide until the edge (non-dark square after wrap) or a piece.
                    while 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst, king=True))
                        dst += d
        return moves

    # ---------------------------------
    def legal_moves(self):
        """Return the legal moves; captures are mandatory when any exist."""
        caps = self._captures()
        return caps if caps else self._quiet()

    # ---------------------------------
    def make_move(self, move):
        """Apply ``move`` in place and pass the turn to the other side.

        A man that lands on its promotion row becomes a king.  Captured pieces
        listed in ``move.captures`` are removed.
        """
        p = int(self.pieces[move.from_sq])
        self.pieces[move.from_sq] = 0
        # Decide promotion from the piece itself rather than from the flag on
        # the move, so a mislabelled Move cannot corrupt a king's code.
        if p in (1, 3) and _promotes(move.to_sq, self.turn):
            p += 1  # man -> king
        self.pieces[move.to_sq] = p
        for cap_sq in move.captures:
            self.pieces[cap_sq] = 0
        self.turn = -self.turn

    # ---------------------------------
    def is_terminal(self):
        """Return ``(finished, winner)``.

        The game is finished when the side to move has no legal move; the
        winner is then the opponent (``-self.turn``).  Otherwise ``(False, 0)``.
        """
        legal = self.legal_moves()
        if not legal:
            return True, -self.turn  # the opponent wins
        # A 15-move draw rule could be added; for now only "no moves" ends the game.
        return False, 0

    # ---------------------------------
    def __str__(self):
        rows = []
        for r in range(8):
            row = [
                PIECE_SYMBOLS[int(self.pieces[r * 8 + c])] if self._dark(r * 8 + c) else ' '
                for c in range(8)
            ]
            rows.append(" ".join(row))
        return "\n".join(rows)


# ---------------------------------
# ResNet + position encoding
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection."""

    def __init__(self, ch=64):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        return F.relu(x + self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x))))))


class ChekaNet(nn.Module):
    """Policy/value network over the 5-plane encoding from ``board_to_tensor``.

    ``forward`` returns ``(policy_logits, value)``: one logit per board square
    (64 per sample) and a tanh-squashed scalar per sample.
    """

    def __init__(self, blocks=3, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(5, channels, 3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        self.residuals = nn.Sequential(*[ResidualBlock(channels) for _ in range(blocks)])
        self.policy_conv = nn.Conv2d(channels, 1, 1)
        self.value_conv = nn.Conv2d(channels, 1, 1)
        self.value_fc = nn.Sequential(
            nn.Flatten(), nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh()
        )

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))
        x = self.residuals(x)
        pol = self.policy_conv(x).squeeze(1).view(x.size(0), -1)
        val = self.value_fc(self.value_conv(x).squeeze(1))
        return pol, val


def board_to_tensor(b):
    """Encode board ``b`` as a float32 tensor of shape (5, 8, 8).

    Planes 0..3 are one-hot masks for piece codes 1..4; plane 4 is all ones
    when white is to move and all zeros otherwise.
    """
    grid = np.asarray(b.pieces).reshape(8, 8)
    planes = np.zeros((5, 8, 8), np.float32)
    for code in (1, 2, 3, 4):
        planes[code - 1] = grid == code
    planes[4] = 1.0 if b.turn == 1 else 0.0
    return torch.from_numpy(planes)


# ---------- MCTS + move selection ----------
class MCTSNode:
    """Search-tree node.

    ``W`` accumulates values from the point of view of the side to move in
    ``board`` (that is the convention ``backup`` implements), ``N`` is the
    visit count and ``P`` the prior assigned by the parent's expansion.
    """

    def __init__(self, board, parent=None, prior=0):
        self.board = board.copy()
        self.parent, self.P, self.N, self.W, self.children = parent, prior, 0, 0.0, {}

    def Q(self):
        """Mean backed-up value (0 for an unvisited node)."""
        return self.W / (self.N + 1e-8)

    def U(self, c_puct=1.0):
        """PUCT exploration bonus; requires a parent, so not valid on the root."""
        return c_puct * self.P * math.sqrt(self.parent.N) / (1 + self.N)

    def is_leaf(self):
        return len(self.children) == 0


def expand_leaf(node, net, device):
    """Expand ``node`` with one child per legal move and return its value.

    The returned value is from the point of view of the side to move at
    ``node``: ``-1.0`` when that side has no legal move (it has lost),
    otherwise the network's value output.

    The caller is expected to have put ``net`` in eval mode; this function
    only disables gradient tracking.
    """
    board = node.board
    legal = board.legal_moves()
    if not legal:
        # Side to move is stuck -> loss for the side to move, whichever colour.
        return -1.0
    tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, v = net(tensor)
    logits = logits[0].cpu().numpy()
    v = v.item()
    # Softmax restricted to the destination squares of legal moves.
    # NOTE(review): priors are indexed by destination square only, so moves
    # sharing a destination share a prior and the priors may sum to more than 1.
    mask = np.full(64, -np.inf)
    for m in legal:
        mask[m.to_sq] = logits[m.to_sq]
    probs = torch.softmax(torch.tensor(mask), dim=0).numpy()
    for m in legal:
        # MCTSNode copies the board itself; no extra copy is needed here.
        child = MCTSNode(board, parent=node, prior=float(probs[m.to_sq]))
        child.board.make_move(m)
        node.children[m] = child
    return v


def backup(node, v):
    """Propagate ``v`` from ``node`` to the root, flipping the sign every ply."""
    while node is not None:
        node.N += 1
        node.W += v
        v = -v
        node = node.parent


def select_move(board, net, device, sims=400, c_puct=1.0, temp=0.0):
    """Run ``sims`` MCTS simulations from ``board`` and pick a move.

    Args:
        board: position to search from (not modified).
        net: model returning ``(policy_logits, value)`` for a batch of positions.
        device: torch device the input tensor is moved to.
        sims: number of simulations.
        c_puct: exploration constant of the PUCT formula.
        temp: 0 picks the most visited move; otherwise moves are sampled with
            probability proportional to ``visits ** (1 / temp)``.

    Returns:
        ``(move, root)`` - the chosen ``Move`` and the root ``MCTSNode``.

    Raises:
        ValueError: if the root has no children after the search (no legal
            moves, or ``sims`` is 0).
    """
    root = MCTSNode(board)
    for _ in range(sims):
        node = root
        while not node.is_leaf():
            # A child's Q is measured for the side to move in the child, i.e.
            # the opponent of the player choosing here, hence the negation.
            node = max(node.children.values(), key=lambda n: n.U(c_puct) - n.Q())
        v = expand_leaf(node, net, device)
        backup(node, v)

    if not root.children:
        raise ValueError("select_move: no moves to choose from (terminal position or sims=0)")

    moves = list(root.children)
    counts = np.array([child.N for child in root.children.values()], dtype=np.float64)
    if temp == 0:
        move = moves[int(np.argmax(counts))]
    else:
        weights = counts ** (1 / temp)
        total = weights.sum()
        if total > 0:
            weights = weights / total
        else:
            # No child visited yet (e.g. sims=1): fall back to a uniform choice.
            weights = np.full(len(moves), 1.0 / len(moves))
        move = random.choices(moves, weights.tolist())[0]
    return move, root