diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..d251795a1 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,153 @@ +name: Build and Test + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + workflow_dispatch: + +jobs: + build-linux: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y cmake build-essential zlib1g-dev libzip-dev opencl-headers ocl-icd-opencl-dev + + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + key: ${{ runner.os }}-cmake-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake- + + - name: Configure CMake + working-directory: cpp + # Use '-DCMAKE_CXX_FLAGS_RELEASE="-s"' to strip debug symbols. Otherwise, the executable is too big + run: | + cmake . -DUSE_BACKEND=OPENCL -DNO_GIT_REVISION=1 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS_RELEASE="-s" + + - name: Build + working-directory: cpp + run: | + make -j$(nproc) + + - name: Run tests + run: | + cpp/katago runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-linux-opencl + path: cpp/katago + + build-macos: + runs-on: macos-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + brew install zlib libzip opencl-headers + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + cpp/build.ninja + cpp/.ninja_deps + cpp/.ninja_log + key: ${{ runner.os }}-cmake-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake- + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G Ninja -DUSE_BACKEND=OPENCL -DNO_GIT_REVISION=1 -DCMAKE_BUILD_TYPE=Release + - name: Build + working-directory: cpp + run: | + ninja + - name: Run tests + run: | + cpp/katago runtests + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-macos-opencl + path: cpp/katago + + build-windows: + runs-on: windows-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup MSVC + uses: microsoft/setup-msbuild@v2 + + - name: Cache vcpkg packages + uses: actions/cache@v4 + with: + path: | + ${{ env.VCPKG_INSTALLATION_ROOT }}/installed + ${{ env.VCPKG_INSTALLATION_ROOT }}/packages + key: ${{ runner.os }}-vcpkg-${{ hashFiles('**/vcpkg.json') }}-opencl + restore-keys: | + ${{ runner.os }}-vcpkg- + + - name: Install vcpkg dependencies + run: | + vcpkg install zlib:x64-windows libzip:x64-windows opencl:x64-windows + + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G "Visual Studio 17 2022" -A x64 ` + -DUSE_BACKEND=OPENCL ` + -DNO_GIT_REVISION=1 ` + -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" + + - name: Build + working-directory: cpp + run: | + cmake --build . --config Release -j 4 + + - name: Copy required DLLs + working-directory: cpp + run: | + $vcpkgRoot = $env:VCPKG_INSTALLATION_ROOT + Copy-Item "$vcpkgRoot/installed/x64-windows/bin/*.dll" -Destination "Release/" -ErrorAction SilentlyContinue + + - name: Run tests + run: | + cpp/Release/katago.exe runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-windows-opencl + path: cpp/Release/ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c523f28b0..be0f6f1e0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -38,6 +38,7 @@ set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc") set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe") set(USE_AVX2 0 CACHE BOOL "Compile with AVX2") set(USE_BIGGER_BOARDS_EXPENSIVE 0 CACHE BOOL "Allow boards up to size 50. Compiling with this will use more memory and slow down KataGo, even when playing on boards of size 19.") +set(DOTS_GAME 0 CACHE BOOL "Configure KataGo to compile for Dots Game with possibility to run Go as well. It configures more optimal field sizes") set(USE_CACHE_TENSORRT_PLAN 0 CACHE BOOL "Use TENSORRT plan cache. May use a lot of disk space. Only applies when USE_BACKEND is TENSORRT.") mark_as_advanced(USE_CACHE_TENSORRT_PLAN) @@ -231,8 +232,10 @@ add_executable(katago core/threadtest.cpp core/timer.cpp game/board.cpp + game/dotsfield.cpp game/rules.cpp game/boardhistory.cpp + game/dotsboardhistory.cpp game/graphhash.cpp dataio/sgf.cpp dataio/numpywrite.cpp @@ -242,6 +245,7 @@ add_executable(katago dataio/homedata.cpp dataio/files.cpp neuralnet/nninputs.cpp + neuralnet/nninputsdots.cpp neuralnet/sgfmetadata.cpp neuralnet/modelversion.cpp neuralnet/nneval.cpp @@ -280,6 +284,11 @@ add_executable(katago ${GIT_HEADER_FILE_ALWAYS_UPDATED} tests/testboardarea.cpp tests/testboardbasic.cpp + tests/testdotsutils.cpp + tests/testdotsbasic.cpp + tests/testdotsstartposes.cpp + tests/testdotsstress.cpp + tests/testdotsextra.cpp tests/testbook.cpp tests/testcommon.cpp tests/testconfig.cpp @@ -447,9 +456,17 @@ elseif(USE_BACKEND STREQUAL "EIGEN") endif() endif() -if(USE_BIGGER_BOARDS_EXPENSIVE) - target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN=50) -endif() +if(DOTS_GAME) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_X=39) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_Y=39) + if(USE_BIGGER_BOARDS_EXPENSIVE) + message(SEND_ERROR "USE_BIGGER_BOARDS_EXPENSIVE is not yet supported for Dots Game") + endif() +else() + if(USE_BIGGER_BOARDS_EXPENSIVE) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN=50) + endif() +endif(DOTS_GAME) if(NO_GIT_REVISION AND (NOT BUILD_DISTRIBUTED)) target_compile_definitions(katago PRIVATE NO_GIT_REVISION) diff --git a/cpp/book/book.cpp b/cpp/book/book.cpp index 32b30962f..04dd18c2e 100644 --- a/cpp/book/book.cpp +++ b/cpp/book/book.cpp @@ -124,15 +124,15 @@ static Hash128 getExtraPosHash(const Board& board) { for(int y = 0; y& symmetriesRet, int bookVersion) { - Board boardsBySym[SymmetryHelpers::NUM_SYMMETRIES]; - BoardHistory histsBySym[SymmetryHelpers::NUM_SYMMETRIES]; + vector boardsBySym; + vector histsBySym; Hash128 accums[SymmetryHelpers::NUM_SYMMETRIES]; // Make sure the book all matches orientation for rectangular boards. @@ -143,8 +143,8 @@ void BookHash::getHashAndSymmetry(const BoardHistory& hist, int repBound, BookHa SymmetryHelpers::NUM_SYMMETRIES_WITHOUT_TRANSPOSE : SymmetryHelpers::NUM_SYMMETRIES; for(int symmetry = 0; symmetry < numSymmetries; symmetry++) { - boardsBySym[symmetry] = SymmetryHelpers::getSymBoard(hist.initialBoard,symmetry); - histsBySym[symmetry] = BoardHistory(boardsBySym[symmetry], hist.initialPla, hist.rules, hist.initialEncorePhase); + boardsBySym.emplace_back(SymmetryHelpers::getSymBoard(hist.initialBoard,symmetry)); + histsBySym.emplace_back(boardsBySym[symmetry], hist.initialPla, hist.rules, hist.initialEncorePhase); accums[symmetry] = Hash128(); } @@ -2621,7 +2621,7 @@ int64_t Book::exportToHtmlDir( const int symmetry = 0; SymBookNode symNode(node, symmetry); - BoardHistory hist; + BoardHistory hist(Rules::DEFAULT_GO); vector moveHistory; bool suc = symNode.getBoardHistoryReachingHere(hist,moveHistory); if(!suc) { @@ -2664,7 +2664,7 @@ int64_t Book::exportToHtmlDir( for(int y = 0; y= 2) { nodeData["id"] = nodeIdx; - nodeData["pla"] = PlayerIO::playerToStringShort(node->pla); + nodeData["pla"] = PlayerIO::playerToStringShort(node->pla, initialRules.isDots); nodeData["syms"] = node->symmetries; nodeData["wl"] = roundDouble(node->thisValuesNotInBook.winLossValue, 100000000); nodeData["sM"] = roundDouble(node->thisValuesNotInBook.scoreMean, 1000000); @@ -2939,7 +2939,7 @@ void Book::saveToFile(const string& fileName) const { } else { nodeData["hash"] = node->hash.toString(); - nodeData["pla"] = PlayerIO::playerToString(node->pla); + nodeData["pla"] = PlayerIO::playerToString(node->pla, initialRules.isDots); nodeData["symmetries"] = node->symmetries; nodeData["winLossValue"] = node->thisValuesNotInBook.winLossValue; nodeData["scoreMean"] = node->thisValuesNotInBook.scoreMean; @@ -3030,7 +3030,7 @@ Book* Book::loadFromFile(const std::string& fileName, int numThreadsForRecompute assertContains(params,"initialBoard"); Board initialBoard = Board::ofJson(params["initialBoard"]); assertContains(params,"initialRules"); - Rules initialRules = Rules::parseRules(params["initialRules"].dump()); + Rules initialRules = Rules::parseRules(params["initialRules"].dump(), params.value("dots", false)); Player initialPla = PlayerIO::parsePlayer(params["initialPla"].get()); int repBound = params["repBound"].get(); diff --git a/cpp/command/analysis.cpp b/cpp/command/analysis.cpp index e4cd37e92..5fa40e97a 100644 --- a/cpp/command/analysis.cpp +++ b/cpp/command/analysis.cpp @@ -161,13 +161,13 @@ int MainCmds::analysis(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); if(humanModelFile != "") { humanEval = Setup::initializeNNEvaluator( humanModelFile,humanModelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); if(!humanEval->requiresSGFMetadata()) { @@ -430,7 +430,8 @@ int MainCmds::analysis(const vector& args) { vector bots; for(int threadIdx = 0; threadIdxsetCopyOfExternalPatternBonusTable(patternBonusTable); bot->setExternalEvalCache(evalCache); threads.push_back(std::thread(analysisLoopProtected,bot,threadIdx)); @@ -702,19 +703,20 @@ int MainCmds::analysis(const vector& args) { { int64_t xBuf; int64_t yBuf; - static const string boardSizeError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN); + static const string boardSizeXError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN_X); + static const string boardSizeYError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN_Y); if(input.find("boardXSize") == input.end()) { - reportErrorForId(rbase.id, "boardXSize", boardSizeError.c_str()); + reportErrorForId(rbase.id, "boardXSize", boardSizeXError); continue; } if(input.find("boardYSize") == input.end()) { - reportErrorForId(rbase.id, "boardYSize", boardSizeError.c_str()); + reportErrorForId(rbase.id, "boardYSize", boardSizeYError); continue; } - if(!parseInteger(input, "boardXSize", xBuf, 2, Board::MAX_LEN, boardSizeError.c_str())) { + if(!parseInteger(input, "boardXSize", xBuf, 2, Board::MAX_LEN_X, boardSizeXError.c_str())) { continue; } - if(!parseInteger(input, "boardYSize", yBuf, 2, Board::MAX_LEN, boardSizeError.c_str())) { + if(!parseInteger(input, "boardYSize", yBuf, 2, Board::MAX_LEN_Y, boardSizeYError.c_str())) { continue; } boardXSize = (int)xBuf; @@ -882,14 +884,14 @@ int MainCmds::analysis(const vector& args) { if(input.find("rules") != input.end()) { if(input["rules"].is_string()) { string s = input["rules"].get(); - if(!Rules::tryParseRules(s,rules)) { + if(!Rules::tryParseRules(s, rules, input.value("dots", false))) { reportErrorForId(rbase.id, "rules", "Could not parse rules: " + s); continue; } } else if(input["rules"].is_object()) { string s = input["rules"].dump(); - if(!Rules::tryParseRules(s,rules)) { + if(!Rules::tryParseRules(s, rules, input.value("dots", false))) { reportErrorForId(rbase.id, "rules", "Could not parse rules: " + s); continue; } @@ -1110,8 +1112,7 @@ int MainCmds::analysis(const vector& args) { continue; } - - Board board(boardXSize,boardYSize); + Board board(boardXSize,boardYSize,rules); for(int i = 0; i& args) { } static void warmStartNNEval(const CompactSgf& sgf, Logger& logger, const SearchParams& params, NNEvaluator* nnEval, Rand& seedRand) { - Board board(sgf.xSize,sgf.ySize); + const Rules rules = Rules(sgf.isDots); + Board board(sgf.xSize,sgf.ySize,rules); Player nextPla = P_BLACK; - BoardHistory hist(board,nextPla,Rules(),0); + BoardHistory hist(board,nextPla,rules,0); SearchParams thisParams = params; thisParams.numThreads = 1; thisParams.maxVisits = 5; @@ -643,7 +644,7 @@ int MainCmds::genconfig(const vector& args, const string& firstCommand) string prompt = "What rules should KataGo use by default for play and analysis?\n" "(chinese, japanese, korean, tromp-taylor, aga, chinese-ogs, new-zealand, bga, stone-scoring, aga-button):\n"; - promptAndParseInput(prompt, [&](const string& line) { configRules = Rules::parseRules(line); }); + promptAndParseInput(prompt, [&](const string& line) { configRules = Rules::parseRules(line, sgf->isDots); }); // TODO: probably incorrect for Dots game? } cout << endl; diff --git a/cpp/command/contribute.cpp b/cpp/command/contribute.cpp index 62fd7bf42..cadbcec70 100644 --- a/cpp/command/contribute.cpp +++ b/cpp/command/contribute.cpp @@ -218,7 +218,7 @@ static void runAndUploadSingleGame( out << "Match: " << botSpecB.botName << " (black) vs " << botSpecW.botName << " (white)" << "\n"; } out << "Rules: " << hist.rules.toJsonString() << "\n"; - out << "Player: " << PlayerIO::playerToString(pla) << "\n"; + out << "Player: " << PlayerIO::playerToString(pla,hist.rules.isDots) << "\n"; out << "Move: " << Location::toString(moveLoc,board) << "\n"; out << "Num Visits: " << search->getRootVisits() << "\n"; if(winLossHist.size() > 0) @@ -242,7 +242,7 @@ static void runAndUploadSingleGame( json ret; // unique to this output ret["gameId"] = gameIdString; - ret["move"] = json::array({PlayerIO::playerToStringShort(pla), Location::toString(moveLoc, board)}); + ret["move"] = json::array({PlayerIO::playerToStringShort(pla, board.isDots()), Location::toString(moveLoc, board)}); ret["blackPlayer"] = botSpecB.botName; ret["whitePlayer"] = botSpecW.botName; @@ -253,7 +253,7 @@ static void runAndUploadSingleGame( json moves = json::array(); for(auto move: hist.moveHistory) { - moves.push_back(json::array({PlayerIO::playerToStringShort(move.pla), Location::toString(move.loc, board)})); + moves.push_back(json::array({PlayerIO::playerToStringShort(move.pla, board.isDots()), Location::toString(move.loc, board)})); } ret["moves"] = moves; @@ -264,11 +264,11 @@ static void runAndUploadSingleGame( Loc loc = Location::getLoc(x, y, initialBoard.x_size); Player locOwner = initialBoard.colors[loc]; if(locOwner != C_EMPTY) - initialStones.push_back(json::array({PlayerIO::playerToStringShort(locOwner), Location::toString(loc, initialBoard)})); + initialStones.push_back(json::array({PlayerIO::playerToStringShort(locOwner, board.isDots()), Location::toString(loc, initialBoard)})); } } ret["initialStones"] = initialStones; - ret["initialPlayer"] = PlayerIO::playerToStringShort(hist.initialPla); + ret["initialPlayer"] = PlayerIO::playerToStringShort(hist.initialPla, board.isDots()); ret["initialTurnNumber"] = hist.initialTurnNumber; // Usual analysis response fields @@ -905,7 +905,7 @@ int MainCmds::contribute(const vector& args) { const bool disableFP16 = false; nnEval = Setup::initializeNNEvaluator( modelName,modelFile,modelInfo.sha256,*userCfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); assert(!nnEval->isNeuralNetLess() || modelFile == "/dev/null"); @@ -919,7 +919,7 @@ int MainCmds::contribute(const vector& args) { const bool disableFP16 = true; nnEval32 = Setup::initializeNNEvaluator( modelName,modelFile,modelInfo.sha256,*userCfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); } diff --git a/cpp/command/demoplay.cpp b/cpp/command/demoplay.cpp index 09b826836..daa4fb049 100644 --- a/cpp/command/demoplay.cpp +++ b/cpp/command/demoplay.cpp @@ -34,7 +34,7 @@ static void writeLine( cout << nnYLen << " "; cout << baseHist.rules.komi << " "; if(baseHist.isGameFinished) { - cout << PlayerIO::playerToString(baseHist.winner) << " "; + cout << PlayerIO::playerToString(baseHist.winner, baseHist.rules.isDots) << " "; cout << baseHist.isResignation << " "; cout << baseHist.finalWhiteMinusBlackScore << " "; } @@ -110,7 +110,7 @@ static void initializeDemoGame(Board& board, BoardHistory& hist, Player& pla, Ra const int size = sizes[rand.nextUInt(sizeFreqs,numSizes)]; - board = Board(size,size); + board = Board(size,size, Rules::DEFAULT_GO); pla = P_BLACK; hist.clear(board,pla,Rules::getTrompTaylorish(),0); bot->setPosition(pla,board,hist); @@ -470,7 +470,7 @@ int MainCmds::demoplay(const vector& args) { ostringstream sout; sout << "genmove null location or illegal move!?!" << "\n"; sout << bot->getRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(pla) << "\n"; + sout << "Pla: " << PlayerIO::playerToString(pla, bot->getRootBoard().isDots()) << "\n"; sout << "MoveLoc: " << Location::toString(moveLoc,bot->getRootBoard()) << "\n"; logger.write(sout.str()); cerr << sout.str() << endl; diff --git a/cpp/command/evalsgf.cpp b/cpp/command/evalsgf.cpp index e561d1545..3aa11a65d 100644 --- a/cpp/command/evalsgf.cpp +++ b/cpp/command/evalsgf.cpp @@ -172,20 +172,21 @@ int MainCmds::evalsgf(const vector& args) { return 1; } - //Parse rules ------------------------------------------------------------------- - Rules defaultRules = Rules::getTrompTaylorish(); - Player perspective = Setup::parseReportAnalysisWinrates(cfg,P_BLACK); - //Parse sgf file and board ------------------------------------------------------------------ std::unique_ptr sgf = CompactSgf::loadFile(sgfFile); - Board board; + //Parse rules ------------------------------------------------------------------- + Rules defaultRules = Rules::getDefaultOrTrompTaylorish(sgf->isDots); + Player perspective = Setup::parseReportAnalysisWinrates(cfg,P_BLACK); + + Board board(defaultRules); Player nextPla; - BoardHistory hist; + BoardHistory hist(defaultRules); auto setUpBoardUsingRules = [&board,&nextPla,&hist,overrideKomi,&sgf,&extraMoves](const Rules& initialRules, int moveNum) { - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + board = hist.initialBoard; vector& moves = sgf->moves; if(!isnan(overrideKomi)) { @@ -219,7 +220,7 @@ int MainCmds::evalsgf(const vector& args) { [](const string& msg) { cout << msg << endl; } ); if(overrideRules != "") { - initialRules = Rules::parseRules(overrideRules); + initialRules = Rules::parseRules(overrideRules, initialRules.isDots); } // Set up once now for error catcihng @@ -379,7 +380,7 @@ int MainCmds::evalsgf(const vector& args) { continue; } - AsyncBot* bot = new AsyncBot(params, nnEval, humanEval, &logger, searchRandSeed); + AsyncBot* bot = new AsyncBot(params, nnEval, humanEval, &logger, searchRandSeed, initialRules); bot->setPosition(nextPla,board,hist); if(hintLoc != "") { @@ -664,8 +665,15 @@ int MainCmds::evalsgf(const vector& args) { int nnXLen = nnEval->getNNXLen(); int nnYLen = nnEval->getNNYLen(); int modelVersion = nnEval->getModelVersion(); - int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + + bool nnEvalIsInDotsMode = nnEval->getDotsGame(); + assert(nnEvalIsInDotsMode == sgf->isDots); + if (nnEvalIsInDotsMode != sgf->isDots) { + cout << "SGF and model mismatch (Go and Dots games)"; + } + + int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, nnEvalIsInDotsMode); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, nnEvalIsInDotsMode); NumpyBuffer binaryInputNCHW(std::vector({1,numSpatialFeatures,nnXLen,nnYLen})); NumpyBuffer globalInputNC(std::vector({1,numGlobalFeatures})); diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index ff04ffec0..5d18e0fe7 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -329,8 +329,8 @@ int MainCmds::genbook(const vector& args) { const bool hasHumanModel = humanModelFile != ""; const SearchParams params = Setup::loadSingleParams(cfg,Setup::SETUP_FOR_GTP,hasHumanModel); - const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN); - const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN); + const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN_X); + const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN_Y); const int repBound = cfg.getInt("repBound",3,1000); const double bonusFileScale = cfg.contains("bonusFileScale") ? cfg.getDouble("bonusFileScale",0.0,1000000.0) : 1.0; @@ -357,9 +357,9 @@ int MainCmds::genbook(const vector& args) { std::map expandBonusByHash; std::map visitsRequiredByHash; std::map branchRequiredByHash; - Board bonusInitialBoard; + Board bonusInitialBoard(rules); Player bonusInitialPla; - bonusInitialBoard = Board(boardSizeX,boardSizeY); + bonusInitialBoard = Board(boardSizeX,boardSizeY,rules); bonusInitialPla = P_BLACK; for(const std::string& bonusFile: bonusFiles) { @@ -444,14 +444,14 @@ int MainCmds::genbook(const vector& args) { if(!bonusInitialBoard.isEqualForTesting(book->getInitialHist().getRecentBoard(0), false, false)) throw StringError( "Book initial board and initial board in bonus sgf file do not match\n" + - Board::toStringSimple(book->getInitialHist().getRecentBoard(0),'\n') + "\n" + - Board::toStringSimple(bonusInitialBoard,'\n') + Board::toStringSimple(book->getInitialHist().getRecentBoard(0)) + "\n" + + Board::toStringSimple(bonusInitialBoard) ); if(bonusInitialPla != book->initialPla) throw StringError( "Book initial player and initial player in bonus sgf file do not match\n" + - PlayerIO::playerToString(book->initialPla) + " book \n" + - PlayerIO::playerToString(bonusInitialPla) + " bonus" + PlayerIO::playerToString(book->initialPla,book->initialRules.isDots) + " book \n" + + PlayerIO::playerToString(bonusInitialPla,book->initialRules.isDots) + " bonus" ); } @@ -570,7 +570,7 @@ int MainCmds::genbook(const vector& args) { Board board = hist.getRecentBoard(0); bool hasAtLeastOneLegalNewMove = false; for(Loc moveLoc = 0; moveLoc < Board::MAX_ARR_SIZE; moveLoc++) { - if(hist.isLegal(board,moveLoc,pla)) { + if(moveLoc != Board::RESIGN_LOC && hist.isLegal(board,moveLoc,pla)) { if(!isReExpansion && constNode.isMoveInBook(moveLoc)) avoidMoveUntilByLoc[moveLoc] = 1; else @@ -1526,15 +1526,15 @@ int MainCmds::writebook(const vector& args) { const bool loadKomiFromCfg = true; Rules rules = Setup::loadSingleRules(cfg,loadKomiFromCfg); - const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN); - const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN); + const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN_X); + const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN_Y); const int repBound = cfg.getInt("repBound",3,1000); std::map bonusByHash; std::map expandBonusByHash; std::map visitsRequiredByHash; std::map branchRequiredByHash; - Board bonusInitialBoard; + Board bonusInitialBoard(rules); Player bonusInitialPla; maybeParseBonusFile( diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 0420bf766..87ea9d19b 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -1,24 +1,29 @@ -#include "../core/global.h" +#include + +#include "../command/commandline.h" #include "../core/commandloop.h" #include "../core/config_parser.h" -#include "../core/fileutils.h" -#include "../core/timer.h" #include "../core/datetime.h" -#include "../core/makedir.h" +#include "../core/fileutils.h" +#include "../core/global.h" #include "../core/test.h" +#include "../core/timer.h" #include "../dataio/sgf.h" -#include "../search/searchnode.h" +#include "../main.h" +#include "../program/play.h" +#include "../program/playutils.h" +#include "../program/setup.h" #include "../search/asyncbot.h" #include "../search/patternbonustable.h" -#include "../program/setup.h" -#include "../program/playutils.h" -#include "../program/play.h" +#include "../search/searchnode.h" #include "../tests/tests.h" -#include "../command/commandline.h" -#include "../main.h" using namespace std; +const string GET_MOVES_COMMAND = "get_moves"; +const string GET_POSITION_COMMAND = "get_position"; +const string GET_BOARDSIZE = "get_boardsize"; + static const vector knownCommands = { //Basic GTP commands "protocol_version", @@ -31,14 +36,17 @@ static const vector knownCommands = { //GTP extension - specify "boardsize X:Y" or "boardsize X Y" for non-square sizes //rectangular_boardsize is an alias for boardsize, intended to make it more evident that we have such support "boardsize", + GET_BOARDSIZE, "rectangular_boardsize", "clear_board", "set_position", + GET_POSITION_COMMAND, "komi", //GTP extension - get KataGo's current komi setting "get_komi", "play", + GET_MOVES_COMMAND, "undo", //GTP extension - specify rules @@ -198,7 +206,7 @@ static bool noWhiteStonesOnBoard(const Board& board) { for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { Loc loc = Location::getLoc(x,y,board.x_size); - if(board.colors[loc] == P_WHITE) + if(board.getColor(loc) == P_WHITE) return false; } } @@ -415,7 +423,7 @@ struct GTPEngine { isGenmoveParams(true), bTimeControls(), wTimeControls(), - initialBoard(), + initialBoard(initialRules), initialPla(P_BLACK), moveHistory(), recentWinLossValues(), @@ -457,9 +465,10 @@ struct GTPEngine { //Specify -1 for the sizes for a default void setOrResetBoardSize(ConfigParser& cfg, Logger& logger, Rand& seedRand, int boardXSize, int boardYSize, bool loggingToStderr) { bool wasDefault = false; + bool isDots = cfg.getBoolOrDefault(DOTS_KEY, false); if(boardXSize == -1 || boardYSize == -1) { - boardXSize = Board::DEFAULT_LEN; - boardYSize = Board::DEFAULT_LEN; + boardXSize = isDots ? Board::DEFAULT_LEN_X_DOTS : Board::DEFAULT_LEN_X; + boardYSize = isDots ? Board::DEFAULT_LEN_Y_DOTS : Board::DEFAULT_LEN_Y; wasDefault = true; } @@ -469,8 +478,8 @@ struct GTPEngine { if(cfg.contains("gtpForceMaxNNSize") && cfg.getBool("gtpForceMaxNNSize")) { defaultRequireExactNNLen = false; - nnXLen = Board::MAX_LEN; - nnYLen = Board::MAX_LEN; + nnXLen = Board::MAX_LEN_X; + nnYLen = Board::MAX_LEN_Y; } //If the neural net is wrongly sized, we need to create or recreate it @@ -550,14 +559,16 @@ struct GTPEngine { else searchRandSeed = Global::uint64ToString(seedRand.nextUInt64()); - bot = new AsyncBot(genmoveParams, nnEval, humanEval, &logger, searchRandSeed); + bot = new AsyncBot(genmoveParams, nnEval, humanEval, &logger, searchRandSeed, currentRules); bot->setCopyOfExternalPatternBonusTable(patternBonusTable); isGenmoveParams = true; - Board board(boardXSize,boardYSize); - Player pla = P_BLACK; + Board board(boardXSize,boardYSize,currentRules); + board.setStartPos(seedRand); + constexpr Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); - vector newMoveHistory; + hist.setInitialTurnNumber(board.numStonesOnBoard()); + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); } @@ -569,7 +580,7 @@ struct GTPEngine { bot->setCopyOfExternalPatternBonusTable(patternBonusTable); } - void setPositionAndRules(Player pla, const Board& board, const BoardHistory& h, const Board& newInitialBoard, Player newInitialPla, const vector newMoveHistory) { + void setPositionAndRules(Player pla, const Board& board, const BoardHistory& h, const Board& newInitialBoard, Player newInitialPla, const vector& newMoveHistory) { BoardHistory hist(h); //Ensure we always have this value correct hist.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); @@ -587,7 +598,8 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); int newXSize = bot->getRootBoard().x_size; int newYSize = bot->getRootBoard().y_size; - Board board(newXSize,newYSize); + Board board(newXSize,newYSize,currentRules); + board.setStartPos(gtpRand); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); vector newMoveHistory; @@ -597,24 +609,33 @@ struct GTPEngine { bool setPosition(const vector& initialStones) { assert(bot->getRootHist().rules == currentRules); - int newXSize = bot->getRootBoard().x_size; - int newYSize = bot->getRootBoard().y_size; - Board board(newXSize,newYSize); - bool suc = board.setStonesFailIfNoLibs(initialStones); - if(!suc) - return false; + const int newXSize = bot->getRootBoard().x_size; + const int newYSize = bot->getRootBoard().y_size; + + vector startPosMoves; + bool startPosIsRandom; + vector remainingPlacementMoves; + const int startPos = + Rules::recognizeStartPos(initialStones, newXSize, newYSize, startPosMoves, startPosIsRandom, &remainingPlacementMoves); + + currentRules.startPos = startPos; + currentRules.startPosIsRandom = startPosIsRandom; + + Board board(newXSize,newYSize,currentRules); + if (!board.setStonesFailIfNoLibs(startPosMoves, true)) return false; + if (!board.setStonesFailIfNoLibs(remainingPlacementMoves)) return false; //Sanity check - for(int i = 0; i newMoveHistory; + hist.setInitialTurnNumber(board.numStonesOnBoard()); // Heuristic to guess at what turn this is + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); return true; @@ -635,30 +656,46 @@ struct GTPEngine { } bool play(Loc loc, Player pla) { - assert(bot->getRootHist().rules == currentRules); - bool suc = bot->makeMove(loc,pla,preventEncore); - if(suc) - moveHistory.push_back(Move(loc,pla)); - return suc; + auto& hist = bot->getRootHist(); + assert(hist.rules == currentRules); + if (hist.isGameFinished) { + return false; + } + + if(bot->makeMove(loc,pla,preventEncore)) { + moveHistory.emplace_back(loc,pla); + return true; + } + return false; } - bool undo() { - if(moveHistory.size() <= 0) + bool undo(const int movesCount) { + auto& hist = bot->getRootHist(); + assert(hist.rules == currentRules); + + if (movesCount == 0) { + return true; + } + + const int newMovesCount = static_cast(moveHistory.size() - movesCount); + if (newMovesCount < 0) { return false; - assert(bot->getRootHist().rules == currentRules); + } - vector moveHistoryCopy = moveHistory; + if(moveHistory.empty()) + return false; - Board undoneBoard = initialBoard; + const vector moveHistoryCopy = moveHistory; + + const Board undoneBoard = initialBoard; BoardHistory undoneHist(undoneBoard,initialPla,currentRules,0); - undoneHist.setInitialTurnNumber(bot->getRootHist().initialTurnNumber); - vector emptyMoveHistory; + undoneHist.setInitialTurnNumber(hist.initialTurnNumber); + const vector emptyMoveHistory; setPositionAndRules(initialPla,undoneBoard,undoneHist,initialBoard,initialPla,emptyMoveHistory); - for(int i = 0; igetRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(pla) << "\n"; + const auto& rootBoard = search->getRootBoard(); + sout << rootBoard << "\n"; + sout << "Pla: " << PlayerIO::playerToString(pla,rootBoard.isDots()) << "\n"; sout << "MoveLoc: " << Location::toString(moveLoc,search->getRootBoard()) << "\n"; logger.write(sout.str()); genmoveTimeSum += genmoveTimer.getSeconds(); @@ -1289,24 +1327,28 @@ struct GTPEngine { if(args.analyzing) { response = "play " + response; } - - return; } - void clearCache() { + void clearCache() const { bot->clearSearch(); bot->clearEvalCache(); nnEval->clearCache(); - if(humanEval != NULL) + if(humanEval != nullptr) humanEval->clearCache(); } + // TODO: Fix handicap placement for Dots game void placeFixedHandicap(int n, string& response, bool& responseIsError) { - int xSize = bot->getRootBoard().x_size; - int ySize = bot->getRootBoard().y_size; - Board board(xSize,ySize); + const int xSize = bot->getRootBoard().x_size; + const int ySize = bot->getRootBoard().y_size; + Board board(xSize,ySize,currentRules); + board.setStartPos(gtpRand); + vector handicapLocs; try { - PlayUtils::placeFixedHandicap(board,n); + handicapLocs = PlayUtils::generateFixedHandicap(board, n); + for (const auto loc : handicapLocs) { + board.setStoneFailIfNoLibs(loc, P_BLACK); + } } catch(const StringError& e) { responseIsError = true; @@ -1325,18 +1367,13 @@ struct GTPEngine { pla = P_WHITE; response = ""; - for(int y = 0; y newMoveHistory; + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); } @@ -1344,9 +1381,9 @@ struct GTPEngine { void placeFreeHandicap(int n, string& response, bool& responseIsError, Rand& rand) { stopAndWait(); - //If asked to place more, we just go ahead and only place up to 30, or a quarter of the board - int xSize = bot->getRootBoard().x_size; - int ySize = bot->getRootBoard().y_size; + // If asked to place more, we just go ahead and only place up to 30, or a quarter of the board + const int xSize = bot->getRootBoard().x_size; + const int ySize = bot->getRootBoard().y_size; int maxHandicap = xSize*ySize / 4; if(maxHandicap > 30) maxHandicap = 30; @@ -1355,11 +1392,12 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); - Board board(xSize,ySize); + Board board(xSize,ySize,currentRules); + board.setStartPos(rand); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); double extraBlackTemperature = 0.25; - PlayUtils::playExtraBlack(bot->getSearchStopAndWait(), n, board, hist, extraBlackTemperature, rand); + const auto& handicapLocs = PlayUtils::playExtraBlack(bot->getSearchStopAndWait(), n, board, hist, extraBlackTemperature, rand); //Also switch the initial player, expecting white should be next. hist.clear(board,P_WHITE,currentRules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); @@ -1367,13 +1405,8 @@ struct GTPEngine { pla = P_WHITE; response = ""; - for(int y = 0; y parseMovesSequence(const vector& pieces, const Board& board, bool passIsAllowed, vector& movesToPlay) { + optional response = std::nullopt; + + auto renderPieces = [pieces](const int& index) { + const vector subvector(pieces.begin(), pieces.begin() + index + 1); + return " (" + Global::concat(subvector, " ") + ")"; + }; + + for (int pieceInd = 0; pieceInd < pieces.size(); pieceInd += 2) { + Player pla; + Loc loc; + if(!PlayerIO::tryParsePlayer(pieces[pieceInd], pla)) { + response = "Could not parse color: '" + pieces[pieceInd] + "'" + renderPieces(pieceInd); + break; + } + + const int locIndex = pieceInd + 1; + if (locIndex >= pieces.size()) { + response = "Expected location after color: '" + pieces[pieceInd] + "'" + renderPieces(pieceInd); + break; + } + + if (!tryParseLoc(pieces[locIndex], board, loc)) { + response = "Could not parse location: '" + pieces[locIndex] + "'" + renderPieces(locIndex); + break; + } + + if (loc == Board::PASS_LOC && !passIsAllowed) { + response = Location::toString(loc, board) + " is disallowed" + renderPieces(locIndex); + break; + } + + movesToPlay.emplace_back(loc, pla); + } + + return response; +} + +static string printMoves(const vector& moves, const Board& board) { + std::ostringstream builder; + for (const auto move : moves) { + builder << PlayerIO::playerToStringShort(move.pla, board.isDots()) << " " << Location::toString(move.loc, board) << " "; + } + return builder.str(); +} int MainCmds::gtp(const vector& args) { Board::initHash(); @@ -1929,9 +2007,9 @@ int MainCmds::gtp(const vector& args) { //Defaults to 7.5 komi, gtp will generally override this const bool loadKomiFromCfg = false; Rules initialRules = Setup::loadSingleRules(cfg,loadKomiFromCfg); - logger.write("Using " + initialRules.toStringNoKomiMaybeNice() + " rules initially, unless GTP/GUI overrides this"); + logger.write("Using " + initialRules.toStringNoSgfDefinedPropertiesMaybeNice() + " rules initially, unless GTP/GUI overrides this"); if(startupPrintMessageToStderr && !logger.isLoggingToStderr()) { - cerr << "Using " + initialRules.toStringNoKomiMaybeNice() + " rules initially, unless GTP/GUI overrides this" << endl; + cerr << "Using " + initialRules.toStringNoSgfDefinedPropertiesMaybeNice() + " rules initially, unless GTP/GUI overrides this" << endl; } bool isForcingKomi = false; float forcedKomi = 0; @@ -2002,6 +2080,7 @@ int MainCmds::gtp(const vector& args) { const double initialDelayMoveScale = cfg.contains("delayMoveScale") ? cfg.getDouble("delayMoveScale",0.0,10000.0) : 0.0; const double initialDelayMoveMax = cfg.contains("delayMoveMax") ? cfg.getDouble("delayMoveMax",0.0,1000000.0) : 1000000.0; + int defaultBoardXSize = -1; int defaultBoardYSize = -1; Setup::loadDefaultBoardXYSize(cfg,logger,defaultBoardXSize,defaultBoardYSize); @@ -2225,7 +2304,7 @@ int MainCmds::gtp(const vector& args) { } else if(command == "name") { - response = "KataGo"; + response = "KataGoDots"; } else if(command == "version") { @@ -2301,15 +2380,26 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "unacceptable size"; } - else if(newXSize > Board::MAX_LEN || newYSize > Board::MAX_LEN) { + else if(newXSize > Board::MAX_LEN_X) { responseIsError = true; - response = Global::strprintf("unacceptable size (Board::MAX_LEN is %d, consider increasing and recompiling)",(int)Board::MAX_LEN); + response = Global::strprintf("unacceptable size (Board::MAX_LEN_X is %d, consider increasing and recompiling)",(int)Board::MAX_LEN_X); + } + else if(newYSize > Board::MAX_LEN_Y) { + responseIsError = true; + response = Global::strprintf("unacceptable size (Board::MAX_LEN_Y is %d, consider increasing and recompiling)",(int)Board::MAX_LEN_Y); } else { engine->setOrResetBoardSize(cfg,logger,seedRand,newXSize,newYSize,logger.isLoggingToStderr()); } } + else if (command == GET_BOARDSIZE) { + const auto& rootBoard = engine->bot->getRootBoard(); + response = Global::intToString(rootBoard.x_size) + (rootBoard.x_size == rootBoard.y_size + ? "" + : ":" + Global::intToString(rootBoard.y_size)); + } + else if(command == "clear_board") { maybeSaveAvoidPatterns(false); if(autoAvoidPatterns && shouldReloadAutoAvoidPatterns) { @@ -2362,7 +2452,8 @@ int MainCmds::gtp(const vector& args) { bool parseSuccess = false; Rules newRules; try { - newRules = Rules::parseRulesWithoutKomi(rest,engine->getCurrentRules().komi); + Rules currentRules = engine->getCurrentRules(); + newRules = Rules::parseRulesWithoutKomi(rest, currentRules.komi, currentRules.isDots); parseSuccess = true; } catch(const StringError& err) { @@ -2376,9 +2467,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } @@ -2406,9 +2497,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } } @@ -2423,19 +2514,19 @@ int MainCmds::gtp(const vector& args) { else { string s = Global::toLower(Global::trim(pieces[0])); if(s == "chinese") { - newRules = Rules::parseRulesWithoutKomi("chinese-kgs",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("chinese-kgs", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "aga") { - newRules = Rules::parseRulesWithoutKomi("aga",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("aga", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "new_zealand") { - newRules = Rules::parseRulesWithoutKomi("new_zealand",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("new_zealand", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "japanese") { - newRules = Rules::parseRulesWithoutKomi("japanese",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("japanese", engine->getCurrentRules().komi); parseSuccess = true; } else { @@ -2450,9 +2541,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } @@ -2910,77 +3001,71 @@ int MainCmds::gtp(const vector& args) { } else if(command == "play") { - Player pla; - Loc loc; - if(pieces.size() != 2) { - responseIsError = true; - response = "Expected two arguments for play but got '" + Global::concat(pieces," ") + "'"; - } - else if(!PlayerIO::tryParsePlayer(pieces[0],pla)) { - responseIsError = true; - response = "Could not parse color: '" + pieces[0] + "'"; - } - else if(!tryParseLoc(pieces[1],engine->bot->getRootBoard(),loc)) { - responseIsError = true; - response = "Could not parse vertex: '" + pieces[1] + "'"; - } - else { - bool suc = engine->play(loc,pla); - if(!suc) { - responseIsError = true; - response = "illegal move"; - } - maybeStartPondering = true; - } - } + vector movesToPlay; + const Board& rootBoard = engine->bot->getRootBoard(); - else if(command == "set_position") { - if(pieces.size() % 2 != 0) { - responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "'"; - } - else { - vector initialStones; - for(int i = 0; ibot->getRootBoard(),loc)) { + if (auto parseError = parseMovesSequence(pieces, rootBoard, true, movesToPlay); parseError == std::nullopt) { + for (int moveInd = 0; moveInd < movesToPlay.size(); moveInd++) { + if (Move move = movesToPlay[moveInd]; !engine->play(move.loc, move.pla)) { responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "': "; - response += "could not parse vertex: '" + pieces[i+1] + "'"; + response = "Illegal move " + PlayerIO::playerToString(move.pla, rootBoard.isDots()) + " " + Location::toString(move.loc, rootBoard); + assert(engine->undo(moveInd)); // Rollback already placed moves break; } - else if(loc == Board::PASS_LOC) { - responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "': "; - response += "could not parse vertex: '" + pieces[i+1] + "'"; - break; - } - initialStones.push_back(Move(loc,pla)); } - if(!responseIsError) { - maybeSaveAvoidPatterns(false); - bool suc = engine->setPosition(initialStones); - if(!suc) { - responseIsError = true; - response = "Illegal stone placements - overlapping stones or stones with no liberties?"; - } - maybeStartPondering = false; + + if (!responseIsError) { + maybeStartPondering = true; + } + } else { + responseIsError = true; + response = parseError.value(); + } + } + + else if(command == "set_position") { + vector initialStones; + const Board& rootBoard = engine->bot->getRootBoard(); + + if (auto parseError = parseMovesSequence(pieces, rootBoard, false, initialStones); parseError == std::nullopt) { + maybeSaveAvoidPatterns(false); + if(!engine->setPosition(initialStones)) { + responseIsError = true; + response = "Illegal stone placements - overlapping stones or stones with no liberties?"; } + maybeStartPondering = false; + } else { + responseIsError = true; + response = parseError.value(); } } + else if (command == GET_MOVES_COMMAND) { + const BoardHistory& history = engine->bot->getRootHist(); + response = printMoves(history.moveHistory, engine->bot->getRootBoard()); + } + + else if (command == GET_POSITION_COMMAND) { + const auto& rootBoard = engine->bot->getRootBoard(); + // TODO: fix for handicap start moves + response = printMoves(rootBoard.start_pos_moves, rootBoard); + } + else if(command == "undo") { - bool suc = engine->undo(); - if(!suc) { - responseIsError = true; - response = "cannot undo"; + int undoCount = 1; + if (!pieces.empty()) { + if (!Global::tryStringToInt(pieces[0], undoCount) || undoCount < 0) { + responseIsError = true; + response = "Expected nonnegative integer for undo count"; + } + } + + if (!responseIsError) { + if (!engine->undo(undoCount)) { + responseIsError = true; + response = "cannot undo"; + break; + } } } @@ -3129,7 +3214,7 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "Number of handicap stones less than 2: '" + pieces[0] + "'"; } - else if(!engine->bot->getRootBoard().isEmpty()) { + else if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } @@ -3153,7 +3238,7 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "Number of handicap stones less than 2: '" + pieces[0] + "'"; } - else if(!engine->bot->getRootBoard().isEmpty()) { + else if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } @@ -3164,23 +3249,25 @@ int MainCmds::gtp(const vector& args) { } else if(command == "set_free_handicap") { - if(!engine->bot->getRootBoard().isEmpty()) { + if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } else { vector locs; - int xSize = engine->bot->getRootBoard().x_size; - int ySize = engine->bot->getRootBoard().y_size; - Board board(xSize,ySize); - for(int i = 0; ibot->getRootBoard(); + int xSize = rootBoard->x_size; + int ySize = rootBoard->y_size; + Board board(xSize,ySize,rootBoard->rules); + board.setStartPos(seedRand); + for(const auto & piece : pieces) { Loc loc; - bool suc = tryParseLoc(pieces[i],board,loc); + bool suc = tryParseLoc(piece,board,loc); if(!suc || loc == Board::PASS_LOC) { responseIsError = true; - response = "Invalid handicap location: " + pieces[i]; + response = "Invalid handicap location: " + piece; } - locs.push_back(Move(loc,P_BLACK)); + locs.emplace_back(loc,P_BLACK); } bool suc = board.setStonesFailIfNoLibs(locs); if(!suc) { @@ -3205,12 +3292,12 @@ int MainCmds::gtp(const vector& args) { double finalWhiteMinusBlackScore = 0.0; engine->computeAnticipatedWinnerAndScore(winner,finalWhiteMinusBlackScore); - if(winner == C_EMPTY) + if (winner == C_EMPTY) response = "0"; - else if(winner == C_BLACK) - response = "B+" + Global::strprintf("%.1f",-finalWhiteMinusBlackScore); - else if(winner == C_WHITE) - response = "W+" + Global::strprintf("%.1f",finalWhiteMinusBlackScore); + else if (winner == C_BLACK || winner == C_WHITE) { + double finalScore = winner == C_BLACK ? -finalWhiteMinusBlackScore : finalWhiteMinusBlackScore; + response = PlayerIO::playerToStringShort(winner, engine->currentRules.isDots) + "+" + Global::strprintf("%.1f", finalScore); + } else ASSERT_UNREACHABLE; } @@ -3243,7 +3330,7 @@ int MainCmds::gtp(const vector& args) { for(int y = 0; y& args) { for(int y = 0; y& args) { response = "Expected one or two arguments for loadsgf but got '" + Global::concat(pieces," ") + "'"; } else { - string filename = pieces[0]; + const string& filename = pieces[0]; bool parseFailed = false; bool moveNumberSpecified = false; int moveNumber = 0; @@ -3295,7 +3382,6 @@ int MainCmds::gtp(const vector& args) { else { Board sgfInitialBoard; Player sgfInitialNextPla; - BoardHistory sgfInitialHist; Rules sgfRules; Board sgfBoard; Player sgfNextPla; @@ -3315,7 +3401,7 @@ int MainCmds::gtp(const vector& args) { engine->getCurrentRules(), //Use current rules as default [&logger](const string& msg) { logger.write(msg); cerr << msg << endl; } ); - if(engine->nnEval != NULL) { + if(engine->nnEval != nullptr) { bool rulesWereSupported; Rules supportedRules = engine->nnEval->getSupportedRules(sgfRules,rulesWereSupported); if(!rulesWereSupported) { @@ -3345,7 +3431,8 @@ int MainCmds::gtp(const vector& args) { } } - sgf->setupInitialBoardAndHist(sgfRules, sgfInitialBoard, sgfInitialNextPla, sgfInitialHist); + BoardHistory sgfInitialHist = sgf->setupInitialBoardAndHist(sgfRules, sgfInitialNextPla); + sgfInitialBoard = sgfInitialHist.initialBoard; sgfInitialHist.setInitialTurnNumber(sgfInitialBoard.numStonesOnBoard()); //Should give more accurate temperaure and time control behavior sgfBoard = sgfInitialBoard; sgfNextPla = sgfInitialNextPla; @@ -3574,59 +3661,65 @@ int MainCmds::gtp(const vector& args) { } if(parsed) { - engine->stopAndWait(); - - int boardSizeX = engine->bot->getRootBoard().x_size; - int boardSizeY = engine->bot->getRootBoard().y_size; - if(boardSizeX != boardSizeY) { + const auto& rootBoard = engine->bot->getRootBoard(); + if (rootBoard.rules.isDots) { responseIsError = true; - response = - "Current board size is " + Global::intToString(boardSizeX) + "x" + Global::intToString(boardSizeY) + - ", no built-in benchmarks for rectangular boards"; - } - else { - std::unique_ptr sgf = nullptr; - try { - string sgfData = TestCommon::getBenchmarkSGFData(boardSizeX); - sgf = CompactSgf::parse(sgfData); - } - catch(const StringError& e) { + response = "Not yet supported for Dots"; + } else { + engine->stopAndWait(); + int boardSizeX = rootBoard.x_size; + int boardSizeY = rootBoard.y_size; + if(boardSizeX != boardSizeY) { responseIsError = true; - response = e.what(); + response = + "Current board size is " + Global::intToString(boardSizeX) + "x" + Global::intToString(boardSizeY) + + ", no built-in benchmarks for rectangular boards"; } - if(sgf != nullptr) { - const PlayUtils::BenchmarkResults* baseline = NULL; - const double secondsPerGameMove = 1.0; - const bool printElo = false; - SearchParams params = engine->getGenmoveParams(); - params.maxTime = 1.0e20; - params.maxPlayouts = ((int64_t)1) << 50; - params.maxVisits = numVisits; - //Make sure the "equals" for GTP is printed out prior to the benchmark line - printGTPResponseHeader(); - + else { + std::unique_ptr sgf = nullptr; try { - PlayUtils::BenchmarkResults results = PlayUtils::benchmarkSearchOnPositionsAndPrint( - params, - *sgf, - 10, - engine->nnEval, - baseline, - secondsPerGameMove, - printElo - ); - (void)results; + string sgfData = TestCommon::getBenchmarkSGFData(boardSizeX); + sgf = CompactSgf::parse(sgfData); } catch(const StringError& e) { responseIsError = true; response = e.what(); - sgf = nullptr; } if(sgf != nullptr) { - //Act of benchmarking will write to stdout with a newline at the end, so we just need one more newline ourselves - //to complete GTP protocol. - suppressResponse = true; - cout << endl; + const PlayUtils::BenchmarkResults* baseline = nullptr; + const double secondsPerGameMove = 1.0; + const bool printElo = false; + SearchParams params = engine->getGenmoveParams(); + params.maxTime = 1.0e20; + params.maxPlayouts = ((int64_t)1) << 50; + params.maxVisits = numVisits; + //Make sure the "equals" for GTP is printed out prior to the benchmark line + printGTPResponseHeader(); + + try { + PlayUtils::BenchmarkResults results = PlayUtils::benchmarkSearchOnPositionsAndPrint( + params, + *sgf, + 10, + engine->nnEval, + baseline, + secondsPerGameMove, + printElo + ); + (void)results; + } + catch(const StringError& e) { + responseIsError = true; + response = e.what(); + + sgf = nullptr; + } + if(sgf != nullptr) { + //Act of benchmarking will write to stdout with a newline at the end, so we just need one more newline ourselves + //to complete GTP protocol. + suppressResponse = true; + cout << endl; + } } } } diff --git a/cpp/command/match.cpp b/cpp/command/match.cpp index 9d7fbda90..d235634b4 100644 --- a/cpp/command/match.cpp +++ b/cpp/command/match.cpp @@ -108,7 +108,7 @@ int MainCmds::match(const vector& args) { } if(cfg.contains("extraPairs")) { - std::vector> pairs = cfg.getNonNegativeIntDashedPairs("extraPairs",0,numBots-1); + std::vector> pairs = cfg.getNonNegativeIntDashedPairs("extraPairs", 0, numBots - 1, numBots - 1); for(const std::pair& pair: pairs) { int p0 = pair.first; int p1 = pair.second; diff --git a/cpp/command/misc.cpp b/cpp/command/misc.cpp index 001bc5b3b..255a61498 100644 --- a/cpp/command/misc.cpp +++ b/cpp/command/misc.cpp @@ -209,9 +209,9 @@ int MainCmds::evalrandominits(const vector& args) { Rand gameRand; while(true) { - Board board(19,19); - Player pla = P_BLACK; Rules rules = Rules::parseRules("japanese"); + Board board(19,19,rules); + Player pla = P_BLACK; BoardHistory hist(board,pla,rules,0); int numInitialMovesToPlay = (int)gameRand.nextUInt(200); double temperature = 1.0; @@ -224,6 +224,7 @@ int MainCmds::evalrandominits(const vector& args) { pla = getOpp(pla); hist.endGameIfAllPassAlive(board); + hist.endGameIfNoLegalMoves(board); if(hist.isGameFinished) break; } @@ -322,11 +323,10 @@ int MainCmds::searchentropyanalysis(const vector& args) { std::unique_ptr sgfObj = CompactSgf::parse(sgf); for(int turnIdx = 0; turnIdx < sgfObj->moves.size(); turnIdx++) { - Board board; Player pla; - BoardHistory hist; - Rules initialRules; - sgfObj->setupInitialBoardAndHist(initialRules, board, pla, hist); + Rules rules = sgfObj->getRulesOrFailAllowUnspecified(Rules::getSimpleTerritory()); + BoardHistory hist = sgfObj->setupInitialBoardAndHist(rules, pla); + Board& board = hist.initialBoard; for(int i = 0; i < turnIdx; i++) { Loc moveLoc = sgfObj->moves[i].loc; diff --git a/cpp/command/runtests.cpp b/cpp/command/runtests.cpp index 75bb9760c..62e2ef83b 100644 --- a/cpp/command/runtests.cpp +++ b/cpp/command/runtests.cpp @@ -30,6 +30,7 @@ int MainCmds::runtests(const vector& args) { Board::initHash(); ScoreValue::initTables(); + Global::runTests(); BSearch::runTests(); Rand::runTests(); DateTime::runTests(); @@ -38,6 +39,16 @@ int MainCmds::runtests(const vector& args) { Base64::runTests(); ThreadTest::runTests(); + Tests::runDotsFieldTests(); + Tests::runDotsGroundingTests(); + Tests::runDotsBoardHistoryGroundingTests(); + Tests::runDotsPosHashTests(); + Tests::runDotsStartPosTests(); + + Tests::runDotsSymmetryTests(); + Tests::runDotsOwnershipTests(); + Tests::runDotsCapturingTests(); + Tests::runBoardIOTests(); Tests::runBoardBasicTests(); @@ -47,7 +58,6 @@ int MainCmds::runtests(const vector& args) { Tests::runBoardUndoTest(); Tests::runBoardHandicapTest(); - Tests::runBoardStressTest(); Tests::runSgfTests(); Tests::runBasicSymmetryTests(); @@ -55,6 +65,9 @@ int MainCmds::runtests(const vector& args) { Tests::runSymmetryDifferenceTests(); Tests::runBoardReplayTest(); + Tests::runDotsStressTests(); + Tests::runBoardStressTest(); + ScoreValue::freeTables(); Tests::runInlineConfigTests(); @@ -434,7 +447,7 @@ int MainCmds::runnnevalcanarytests(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } @@ -493,7 +506,7 @@ int MainCmds::runbeginsearchspeedtest(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - Board::MAX_LEN,Board::MAX_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + Board::MAX_LEN_X,Board::MAX_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } @@ -617,7 +630,7 @@ int MainCmds::runownershipspeedtest(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - Board::MAX_LEN,Board::MAX_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + Board::MAX_LEN_X,Board::MAX_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index b03fb8c7d..0d771c4d8 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -93,10 +93,14 @@ int MainCmds::selfplay(const vector& args) { //Width and height of the board to use when writing data, typically 19 const int dataBoardLen = cfg.getInt("dataBoardLen",3,Board::MAX_LEN); + const int dataBoardLenX = cfg.contains(DATA_LEN_X_KEY) ? cfg.getInt(DATA_LEN_X_KEY,3,Board::MAX_LEN_X) : dataBoardLen; + const int dataBoardLenY = cfg.contains(DATA_LEN_Y_KEY) ? cfg.getInt(DATA_LEN_Y_KEY,3,Board::MAX_LEN_Y) : dataBoardLen; + + const bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); const int inputsVersion = cfg.contains("inputsVersion") ? cfg.getInt("inputsVersion",0,10000) : - NNModelVersion::getInputsVersion(NNModelVersion::defaultModelVersion); + NNModelVersion::getInputsVersion(NNModelVersion::defaultModelVersion, dotsGame); //Max number of games that we will allow to be queued up and not written out const int maxDataQueueSize = cfg.getInt("maxDataQueueSize",1,1000000); const int maxRowsPerTrainFile = cfg.getInt("maxRowsPerTrainFile",1,100000000); @@ -135,9 +139,9 @@ int MainCmds::selfplay(const vector& args) { //Returns true if a new net was loaded. auto loadLatestNeuralNetIntoManager = - [inputsVersion,&manager,maxRowsPerTrainFile,firstFileRandMinProp,dataBoardLen, + [inputsVersion,&manager,maxRowsPerTrainFile,firstFileRandMinProp,dataBoardLenX,dataBoardLenY, &modelsDir,&outputDir,&logger,&cfg,numGameThreads, - minBoardXSizeUsed,maxBoardXSizeUsed,minBoardYSizeUsed,maxBoardYSizeUsed](const string* lastNetName) -> bool { + minBoardXSizeUsed,maxBoardXSizeUsed,minBoardYSizeUsed,maxBoardYSizeUsed,dotsGame](const string* lastNetName) -> bool { string modelName; string modelFile; @@ -212,9 +216,9 @@ int MainCmds::selfplay(const vector& args) { //Note that this inputsVersion passed here is NOT necessarily the same as the one used in the neural net self play, it //simply controls the input feature version for the written data - TrainingDataWriter* tdataWriter = new TrainingDataWriter( - tdataOutputDir, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); - ofstream* sgfOut = NULL; + auto tdataWriter = new TrainingDataWriter( + tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLenX, dataBoardLenY, Global::uint64ToHexString(rand.nextUInt64()), 1, dotsGame); + ofstream* sgfOut = nullptr; if(sgfOutputDir.length() > 0) { sgfOut = new ofstream(); FileUtils::open(*sgfOut, sgfOutputDir + "/" + Global::uint64ToHexString(rand.nextUInt64()) + ".sgfs"); diff --git a/cpp/command/startposes.cpp b/cpp/command/startposes.cpp index d96cb61e8..4e739492e 100644 --- a/cpp/command/startposes.cpp +++ b/cpp/command/startposes.cpp @@ -436,11 +436,11 @@ int MainCmds::samplesgfs(const vector& args) { else { string fileName = sgf.fileName; CompactSgf compactSgf(sgf); - Board board; + Player nextPla; - BoardHistory hist; Rules rules = compactSgf.getRulesOrFailAllowUnspecified(Rules::getSimpleTerritory()); - compactSgf.setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = compactSgf.setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; if(valueFluctuationMakeKomiFair) { Rand rand; @@ -502,7 +502,7 @@ int MainCmds::samplesgfs(const vector& args) { //Only log on errors that aren't simply due to ko rules, but quit out regardless suc = hist.makeBoardMoveTolerant(board,sgfMoves[m].loc,sgfMoves[m].pla,preventEncore); if(!suc) - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size, board.isDots())); break; } hist.makeBoardMoveAssumeLegal(board,sgfMoves[m].loc,sgfMoves[m].pla,NULL,preventEncore); @@ -1386,10 +1386,10 @@ int MainCmds::dataminesgfs(const vector& args) { //Don't use the SGF rules - randomize them for a bit more entropy Rules rules = gameInit->createRules(); - Board board; Player nextPla; - BoardHistory hist; - sgf.setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = sgf.setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; + if(!gameInit->isAllowedBSize(board.x_size,board.y_size)) { numFilteredSgfs.fetch_add(1); return; @@ -1469,7 +1469,7 @@ int MainCmds::dataminesgfs(const vector& args) { //Only log on errors that aren't simply due to ko rules, but quit out regardless suc = hist.makeBoardMoveTolerant(board,sgfMoves[m].loc,sgfMoves[m].pla,preventEncore); if(!suc) - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size, board.isDots())); break; } hist.makeBoardMoveAssumeLegal(board,sgfMoves[m].loc,sgfMoves[m].pla,NULL,preventEncore); @@ -1662,7 +1662,7 @@ int MainCmds::dataminesgfs(const vector& args) { for(int i = 0; i& args) { if(bot != NULL || !checkLegality) { cout << "StartPos: " << s << "/" << startPoses.size() << "\n"; - cout << "Next pla: " << PlayerIO::playerToString(pla) << "\n"; + cout << "Next pla: " << PlayerIO::playerToString(pla, board.isDots()) << "\n"; cout << "Weight: " << startPos.weight << "\n"; cout << "TrainingWeight: " << startPos.trainingWeight << "\n"; - cout << "StartPosInitialNextPla: " << PlayerIO::playerToString(startPos.nextPla) << "\n"; + cout << "StartPosInitialNextPla: " << PlayerIO::playerToString(startPos.nextPla, board.isDots()) << "\n"; cout << "StartPosMoves: "; for(int i = 0; i<(int)startPos.moves.size(); i++) cout << (startPos.moves[i].pla == P_WHITE ? "w" : "b") << Location::toString(startPos.moves[i].loc,board) << " "; diff --git a/cpp/command/writetrainingdata.cpp b/cpp/command/writetrainingdata.cpp index 47fba131f..ce988fa38 100644 --- a/cpp/command/writetrainingdata.cpp +++ b/cpp/command/writetrainingdata.cpp @@ -681,10 +681,12 @@ int MainCmds::writetrainingdata(const vector& args) { const int numTotalThreads = numWorkerThreads * numSearchThreads; const int dataBoardLen = cfg.getInt("dataBoardLen",3,Board::MAX_LEN); + const int dataBoardLenX = cfg.contains(DATA_LEN_X_KEY) ? cfg.getInt(DATA_LEN_X_KEY,3,Board::MAX_LEN_X) : dataBoardLen; + const int dataBoardLenY = cfg.contains(DATA_LEN_Y_KEY) ? cfg.getInt(DATA_LEN_Y_KEY,3,Board::MAX_LEN_Y) : dataBoardLen; const int maxApproxRowsPerTrainFile = cfg.getInt("maxApproxRowsPerTrainFile",1,100000000); const std::vector> allowedBoardSizes = - cfg.getNonNegativeIntDashedPairs("allowedBoardSizes", 2, Board::MAX_LEN); + cfg.getNonNegativeIntDashedPairs("allowedBoardSizes", 2, Board::MAX_LEN_X, Board::MAX_LEN_Y); if(dataBoardLen > Board::MAX_LEN) throw StringError("dataBoardLen > maximum board len, must recompile to increase"); @@ -712,7 +714,7 @@ int MainCmds::writetrainingdata(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( nnModelFile,nnModelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); } @@ -814,8 +816,8 @@ int MainCmds::writetrainingdata(const vector& args) { (maxApproxRowsPerTrainFile * 4/3 + Board::MAX_PLAY_SIZE * 2 + 100), numBinaryChannels, numGlobalChannels, - dataBoardLen, - dataBoardLen, + dataBoardLenX, + dataBoardLenY, hasMetadataInput ) ); @@ -868,12 +870,13 @@ int MainCmds::writetrainingdata(const vector& args) { return; } - if(xySize.x > dataBoardLen || xySize.y > dataBoardLen) { + if(xySize.x > dataBoardLenX || xySize.y > dataBoardLenY) { logger.write( - "SGF board size > dataBoardLen in " + fileName + ":" + "SGF board sizeX > dataBoardLenX or sizeY > dataBoardLenY in " + fileName + ":" + " " + Global::intToString(xySize.x) + " " + Global::intToString(xySize.y) - + " " + Global::intToString(dataBoardLen) + + " " + Global::intToString(dataBoardLenX) + + " " + Global::intToString(dataBoardLenY) ); reportSgfDone(false,"SGFGreaterThanDataBoardLen"); return; @@ -1372,11 +1375,13 @@ int MainCmds::writetrainingdata(const vector& args) { // No friendly pass since we want to complete consistent with strict rules rules.friendlyPassOk = false; - Board board; + // TODO: Fix construction of board and hist + Board board(Rules::DEFAULT_GO); Player nextPla; - BoardHistory hist; + BoardHistory hist(Rules::DEFAULT_GO); try { - sgf->setupInitialBoardAndHist(rules, board, nextPla, hist); + hist = sgf->setupInitialBoardAndHist(rules, nextPla); + board = hist.initialBoard; } catch(const StringError& e) { logger.write("Bad initial setup in sgf " + fileName + " " + e.what()); @@ -1878,7 +1883,7 @@ int MainCmds::writetrainingdata(const vector& args) { } bool suc = hist.isLegal(board,move.loc,move.pla); if(!suc) { - logger.write("Illegal move near start in " + fileName + " move " + Location::toString(move.loc, board.x_size, board.y_size) + sizeStr); + logger.write("Illegal move near start in " + fileName + " move " + Location::toString(move.loc, board) + sizeStr); reportSgfDone(false,"MovesIllegalMoveNearStart"); return; } @@ -1943,7 +1948,7 @@ int MainCmds::writetrainingdata(const vector& args) { } bool suc = hist.isLegal(board,move.loc,move.pla); if(!suc) { - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(move.loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(move.loc, board)); reportSgfDone(false,"MovesIllegal"); return; } diff --git a/cpp/configs/gtp_example_dots.cfg b/cpp/configs/gtp_example_dots.cfg new file mode 100644 index 000000000..be47f7387 --- /dev/null +++ b/cpp/configs/gtp_example_dots.cfg @@ -0,0 +1,744 @@ +# Configuration for KataGo C++ GTP engine + +# Run the program using: `./katago.exe gtp` + +# In this example config, when a parameter is given as a commented out value, +# that value also is the default value, unless described otherwise. You can +# uncomment it (remove the pound sign) and change it if you want. + +# =========================================================================== +# Running on an online server or in a real tournament or match +# =========================================================================== +# If you plan to run online or in a tournament, read through the "Rules" +# section below for proper handling of komi, handicaps, end-of-game cleanup, +# and other details. + +# =========================================================================== +# Notes about performance and memory usage +# =========================================================================== +# Important: For good performance, you will very likely want to tune the +# "numSearchThreads" parameter in the Search limits section below! Run +# "./katago benchmark" to test KataGo and to suggest a reasonable value +# of this parameter. + +# For multi-GPU systems, read "OpenCL GPU settings" or "CUDA GPU settings". +# +# When using OpenCL, verify that KataGo picks the correct device! Some systems +# may have both an Intel CPU OpenCL and GPU OpenCL. If # KataGo picks the wrong +# one, correct this by specifying "openclGpuToUse". +# +# Consider adjusting "maxVisits", "ponderingEnabled", "resignThreshold", and +# other parameters depending on your intended usage. + +# =========================================================================== +# Command-line usage +# =========================================================================== +# All of the below values may be set or overridden via command-line arguments: +# +# -override-config KEY=VALUE,KEY=VALUE,... + +# =========================================================================== +# Logs and files +# =========================================================================== +# This section defines where and what logging information is produced. + +# Each run of KataGo will log to a separate file in this dir. +# This is the default. +logDir = gtp_logs +# Uncomment and specify this instead of logDir to write separate dated subdirs +# logDirDated = gtp_logs +# Uncomment and specify this instead of logDir to log to only a single file +# logFile = gtp.log + +# Logging options +logAllGTPCommunication = true +logSearchInfo = true +logSearchInfoForChosenMove = false +logToStderr = false + +# KataGo will display some info to stderr on GTP startup +# Uncomment the next line and set it to false to suppress that and remain silent +# startupPrintMessageToStderr = true + +# Write information to stderr, for use in things like malkovich chat to OGS. +# ogsChatToStderr = false + +# Uncomment and set this to a directory to override where openCLTuner files +# and other cached data is written. By default it saves into a subdir of the +# current directory on windows, and a subdir of ~/.katago on Linux. +# homeDataDir = PATH_TO_DIRECTORY + +# =========================================================================== +# Analysis +# =========================================================================== +# This section configures analysis settings. +# +# The maximum number of moves after the first move displayed in variations +# from analysis commands like kata-analyze or lz-analyze. +# analysisPVLen = 15 + +# Report winrates for chat and analysis as (BLACK|WHITE|SIDETOMOVE). +# Most GUIs and analysis tools will expect SIDETOMOVE. +# reportAnalysisWinratesAs = SIDETOMOVE + +# Extra noise for wider exploration. Large values will force KataGo to +# analyze a greater variety of moves than it normally would. +# An extreme value like 1 distributes playouts across every move on the board, +# even very bad moves. +# Affects analysis only, does not affect play. +# analysisWideRootNoise = 0.04 + +# Try to limit the effect of possible bad or bogus move sequences in the +# history leading to this position from affecting KataGo's move predictions. +# analysisIgnorePreRootHistory = true + +# =========================================================================== +# Rules +# =========================================================================== +# This section configures the scoring and playing rules. Rules can also be +# changed mid-run by issuing custom GTP commands. +# +# See https://lightvector.github.io/KataGo/rules.html for rules details. +# +# See https://github.com/lightvector/KataGo/blob/master/docs/GTP_Extensions.md +# for GTP commands. + +# Specify the rules as a string. +# Some legal values include: +# chinese, japanese, korean, aga, chinese-ogs, new-zealand, stone-scoring, +# ancient-territory, bga, aga-button +# +# For some human rulesets that require complex adjudication in tricky cases +# (e.g. japanese, korean) KataGo may not precisely match the ruleset in such +# cases but will do its best. + +dots = true +rules = dots +defaultBoardXSize = 14 +defaultBoardYSize = 14 + +# By default, the "rules" parameter is used, but if you comment it out and +# uncomment one option in each of the sections below, you can specify an +# arbitrary combination of individual rules. + +# koRule = SIMPLE # Simple ko rules (triple ko = no result) +# koRule = POSITIONAL # Positional superko +# koRule = SITUATIONAL # Situational superko + +# scoringRule = AREA # Area scoring +# scoringRule = TERRITORY # Territory scoring (special computer-friendly territory rules) + +# taxRule = NONE # All surrounded empty points are scored +# taxRule = SEKI # Eyes in seki do NOT count as points +# taxRule = ALL # All groups are taxed up to 2 points for the two eyes needed to live + +# Is multiple-stone suicide legal? (Single-stone suicide is always illegal). +# multiStoneSuicideLegal = false +# multiStoneSuicideLegal = true + +# "Button go" - the first pass when area scoring awards 0.5 points and does +# not count for ending the game. +# Allows area scoring rulesets that have far simpler rules to achieve the same +# final scoring precision and reward for precise play as territory scoring. +# hasButton = false +# hasButton = true + +# Is this a human ruleset where it's okay to pass before having physically +# captured and removed all dead stones? +# friendlyPassOk = false +# friendlyPassOk = true + +# How handicap stones in handicap games are compensated +# whiteHandicapBonus = 0 # White gets no compensation for black's handicap stones (Tromp-taylor, NZ, JP) +# whiteHandicapBonus = N-1 # White gets N-1 points for black's N handicap stones (AGA) +# whiteHandicapBonus = N # White gets N points for black's N handicap stones (Chinese) + +# ------------------------------ +# Other rules hacks +# ------------------------------ +# Uncomment and change to adjust what board size KataGo uses upon startup +# by default when GTP doesn't specify. +# defaultBoardSize = 19 + +# By default, Katago will use the komi that the GUI or GTP controller tries to set. +# Uncomment and set this to have KataGo ignore the controller and always use this komi. +# ignoreGTPAndForceKomi = 7 + +# =========================================================================== +# Bot behavior +# =========================================================================== + +# ------------------------------ +# Resignation +# ------------------------------ + +# Resignation occurs if for at least resignConsecTurns in a row, the +# winLossUtility (on a [-1,1] scale) is below resignThreshold. +allowResignation = true +resignThreshold = -0.90 +resignConsecTurns = 3 + +# By default, KataGo may resign games that it is confidently losing even if they +# are very close in score. Uncomment and set this to avoid resigning games +# if the estimated difference is points is less than or equal to this. +# resignMinScoreDifference = 10 + +# Disallow resignation if turn number < resignMinMovesPerBoardArea * area of board. +# e.g 0.25 would prohibit resignation on 19x19 until after turn 361 * 0.25 ~= 90. +# resignMinMovesPerBoardArea = 0.00 + +# ------------------------------ +# Handicap +# ------------------------------ +# Assume that if black makes many moves in a row right at the start of the +# game, then the game is a handicap game. This is necessary on some servers +# and for some GUIs and also when initializing from many SGF files, which may +# set up a handicap game using repeated GTP "play" commands for black rather +# than GTP "place_free_handicap" commands; however, it may also lead to +# incorrect understanding of komi if whiteHandicapBonus is used and a server +# does not have such a practice. Uncomment and set to false to disable. +# assumeMultipleStartingBlackMovesAreHandicap = true + +# Makes katago dynamically adjust in handicap or altered-komi games to assume +# based on those game settings that it must be stronger or weaker than the +# opponent and to play accordingly. Greatly improves handicap strength by +# biasing winrates and scores to favor appropriate safe/aggressive play. +# Does NOT affect analysis (lz-analyze, kata-analyze, used by programs like +# Lizzie) so analysis remains unbiased. Uncomment and set this to 0 to disable +# this and make KataGo play the same always. +# dynamicPlayoutDoublingAdvantageCapPerOppLead = 0.045 + +# Instead of "dynamicPlayoutDoublingAdvantageCapPerOppLead", you can comment +# that out and uncomment and set "playoutDoublingAdvantage" to a fixed value +# from -3.0 to 3.0 that will not change dynamically. +# ALSO affects analysis tools (lz-analyze, kata-analyze, used by e.g. Lizzie). +# Negative makes KataGo behave as if it is much weaker than the opponent. +# Positive makes KataGo behave as if it is much stronger than the opponent. +# KataGo will adjust to favor safe/aggressive play as appropriate based on +# the combination of who is ahead and how much stronger/weaker it thinks it is, +# and report winrates and scores taking the strength difference into account. +# +# If this and "dynamicPlayoutDoublingAdvantageCapPerOppLead" are both set +# then dynamic will be used for all games and this fixed value will be used +# for analysis tools. +# playoutDoublingAdvantage = 0.0 + +# Uncomment one of these when using "playoutDoublingAdvantage" to enforce +# that it will only apply when KataGo plays as the specified color and will be +# negated when playing as the opposite color. +# playoutDoublingAdvantagePla = BLACK +# playoutDoublingAdvantagePla = WHITE + +# ------------------------------ +# Passing and cleanup +# ------------------------------ +# Make the bot never assume that its pass will end the game, even if passing +# would end and "win" under Tromp-Taylor rules. Usually this is a good idea +# when using it for analysis or playing on servers where scoring may be +# implemented non-tromp-taylorly. Uncomment and set to false to disable. +# conservativePass = true + +# When using territory scoring, self-play games continue beyond two passes +# with special cleanup rules that may be confusing for human players. This +# option prevents the special cleanup phases from being reachable when using +# the bot for GTP play. Uncomment and set to false to enable entering special +# cleanup. For example, if you are testing it against itself, or against +# another bot that has precisely implemented the rules documented at +# https://lightvector.github.io/KataGo/rules.html +# preventCleanupPhase = true + +# ------------------------------ +# Miscellaneous behavior +# ------------------------------ +# If the board is symmetric, search only one copy of each equivalent move. +# Attempts to also account for ko/superko, will not theoretically perfect for +# superko. Uncomment and set to false to disable. +# rootSymmetryPruning = true + +# Uncomment and set to true to avoid a particular joseki that some networks +# misevaluate, and also to improve opening diversity versus some particular +# other bots that like to play it all the time. +# avoidMYTDaggerHack = false + +# Prefer to avoid playing the same joseki in every corner of the board. +# Uncomment to set to a specific value. See "Avoid SGF patterns" section. +# By default: 0 (even games), 0.005 (handicap games) +# avoidRepeatedPatternUtility = 0.0 + +# Experimental logic to fight against mirror Go even with unfavorable komi. +# Uncomment to set to a specific value to use for both playing and analysis. +# By default: true when playing via GTP, but false when analyzing. +# antiMirror = true + +# Enable some hacks that mitigate rare instances when passing messes up deeper searches. +# enablePassingHacks = true + +# Uncomment and set this to true to prevent bad or bogus move sequences +# in the history leading to this position from affecting KataGo's move choices. +# Same as analysisIgnorePreRootHistory (see above) but applies to actual play. +# You can enable this if KataGo is being asked to play from positions that it did not +# choose the moves to reach. +# ignorePreRootHistory = false + +# =========================================================================== +# Search limits +# =========================================================================== + +# Terminology: +# "Playouts" is the number of new playouts of search performed each turn. +# "Visits" is the same as "Playouts" but also counts search performed on +# previous turns that is still applicable to this turn. +# "Time" is the time in seconds. + +# For example, if KataGo searched 200 nodes on the previous turn, and then +# after the opponent's reply, 50 nodes of its search tree was still valid, +# then a visit limit of 200 would allow KataGo to search 150 new nodes +# (for a final tree size of 200 nodes), whereas a playout limit of of 200 +# would allow KataGo to search 200 nodes (for a final tree size of 250 nodes). + +# Additionally, KataGo may also move before than the limit in order to +# obey time controls (e.g. byo-yomi, etc) if the GTP controller has +# told KataGo that the game has is being played with a given time control. + +# Limits for search on the current turn. +# If commented out or unspecified, the default is to have no limit. +maxVisits = 500 +# maxPlayouts = 300 +# maxTime = 10.0 + +# Ponder on the opponent's turn? +ponderingEnabled = false + +# Limits for search when pondering on the opponent's turn. +# If commented out or unspecified, the default is to have no limit. +# Limiting the maximum time is recommended so that KataGo won't burn CPU/GPU +# forever and/or run out of RAM if left unattended while pondering is enabled. +# maxVisitsPondering = 5000 +# maxPlayoutsPondering = 3000 +maxTimePondering = 60.0 + + +# ------------------------------ +# Other search limits and behavior +# ------------------------------ + +# Approx number of seconds to buffer for lag for GTP time controls - will +# move a bit faster assuming there is this much lag per move. +lagBuffer = 1.0 + +# YOU PROBABLY WANT TO TUNE THIS PARAMETER! +# The number of threads to use when searching. On powerful GPUs the optimal +# threads may be much higher than the number of CPU cores you have because +# many threads are needed to feed efficient large batches to the GPU. +# +# Run "./katago benchmark" to tune this parameter and test the effect +# of changes to any of other parameters. +numSearchThreads = 6 + +# Play a little faster if the opponent is passing, for human-friendliness. +# Comment these out to disable them, such as if running a controlled match +# where you are testing KataGo with fixed compute per move vs other bots. +searchFactorAfterOnePass = 0.50 +searchFactorAfterTwoPass = 0.25 + +# Play a little faster if super-winning, for human-friendliness. +# Comment these out to disable them, such as if running a controlled match +# where you are testing KataGo with fixed compute per move vs other bots. +searchFactorWhenWinning = 0.40 +searchFactorWhenWinningThreshold = 0.95 + +# =========================================================================== +# GPU settings +# =========================================================================== +# This section configures GPU settings. +# +# Maximum number of positions to send to a single GPU at once. The default +# value is roughly equal to numSearchThreads, but can be specified manually +# if running out of memory, or using multiple GPUs that expect to share work. +# nnMaxBatchSize = + +# Controls the neural network cache size, which is the primary RAM/memory use. +# KataGo will cache up to (2 ** nnCacheSizePowerOfTwo) many neural net +# evaluations in case of transpositions in the tree. +# Increase this to improve performance for searches with tens of thousands +# of visits or more. Decrease this to limit memory usage. +# If you're happy to do some math - each neural net entry takes roughly +# 1.5KB, except when using whole-board ownership/territory +# visualizations, where each entry will take roughly 3KB. The number of +# entries is (2 ** nnCacheSizePowerOfTwo). (E.g. 2 ** 18 = 262144.) +# You can compute roughly how much memory the cache will use based on this. +# nnCacheSizePowerOfTwo = 20 + +# Size of mutex pool for nnCache is (2 ** this). +# nnMutexPoolSizePowerOfTwo = 16 + +# Randomize board orientation when running neural net evals? Uncomment and +# set to false to disable. +# nnRandomize = true + +# If provided, force usage of a specific seed for nnRandomize. +# The default is to use a randomly generated seed. +# nnRandSeed = abcdefg + +# Uncomment and set to true to force GTP to use the maximum board size for +# internal buffers for the neural net. This will make KataGo slower when +# evaluating small boards, but will avoid a lengthy initialization time on every +# change of board size due to having to re-size the neural net buffers on the GPU. +# This can be useful for example, for OGS's persistent bot mode that uses a single +# bot instance to handle multiple games and may thrash between different board sizes +# if there are concurrent games of multiple sizes. +# gtpForceMaxNNSize = false + +# ------------------------------ +# Multiple GPUs +# ------------------------------ +# Set this to the number of GPUs to use or that are available. +# IMPORTANT: If more than 1, also uncomment the appropriate TensorRT +# or CUDA or OpenCL section. +# numNNServerThreadsPerModel = 1 + +# ------------------------------ +# TENSORRT GPU settings +# ------------------------------ +# These only apply when using the TENSORRT version of KataGo. + +# For one GPU: optionally uncomment this option and change if the GPU to +# use is not device 0. +# trtDeviceToUse = 0 + +# For two GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1. +# trtDeviceToUseThread0 = 0 +# trtDeviceToUseThread1 = 1 + +# For three GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1 and 2. +# trtDeviceToUseThread0 = 0 +# trtDeviceToUseThread1 = 1 +# trtDeviceToUseThread2 = 2 + +# The pattern continues for additional GPUs. + +# ------------------------------ +# CUDA GPU settings +# ------------------------------ +# These only apply when using the CUDA version of KataGo. + +# For one GPU: optionally uncomment and change this if the GPU you want to +# use is not device 0 +# cudaDeviceToUse = 0 + +# For two GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1. +# cudaDeviceToUseThread0 = 0 +# cudaDeviceToUseThread1 = 1 + +# For three GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1 and 2. +# cudaDeviceToUseThread0 = 0 +# cudaDeviceToUseThread1 = 1 +# cudaDeviceToUseThread2 = 2 + +# The pattern continues for additional GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability +# of your NVIDIA GPU. If you want to try to force a particular behavior +# you can uncomment these lines and change them to "true" or "false". +# cudaUseFP16 = auto +# cudaUseNHWC = auto + +# ------------------------------ +# Metal GPU settings +# ------------------------------ +# These only apply when using the METAL version of KataGo. + +# For one Metal instance: KataGo will automatically use the default device. +# metalDeviceToUse = 0 + +# For two Metal instance: Uncomment these options, AND set numNNServerThreadsPerModel = 2 above. +# This will create two Metal instances, best overlapping the GPU and CPU execution. +# metalDeviceToUseThread0 = 0 +# metalDeviceToUseThread1 = 1 + +# The pattern continues for additional Metal instances. + +# ------------------------------ +# OpenCL GPU settings +# ------------------------------ +# These only apply when using the OpenCL version of KataGo. + +# Uncomment and set to true to tune OpenCL for every board size separately, +# rather than only the largest possible size. +# openclReTunePerBoardSize = false + +# For one GPU: optionally uncomment and change this if the best device to use is guessed incorrectly. +# The default behavior tries to guess the 'best' GPU or device on your system to use, usually it will be a good guess. +# openclDeviceToUse = 0 + +# For two GPUs: Uncomment these two lines and replace X and Y with the device ids of the devices you want to use. +# It might NOT be 0 and 1, some computers will have many OpenCL devices. You can see what the devices are when +# KataGo starts up - it should print or log all the devices it finds. +# (AND also set numNNServerThreadsPerModel above) +# openclDeviceToUseThread0 = X +# openclDeviceToUseThread1 = Y + +# For three GPUs: Uncomment these three lines and replace X and Y and Z with the device ids of the devices you want to use. +# It might NOT be 0 and 1 and 2, some computers will have many OpenCL devices. You can see what the devices are when +# KataGo starts up - it should print or log all the devices it finds. +# (AND also set numNNServerThreadsPerModel above) +# openclDeviceToUseThread0 = X +# openclDeviceToUseThread1 = Y +# openclDeviceToUseThread2 = Z + +# The pattern continues for additional GPUs. + +# KataGo will automatically use FP16 or not based on testing your GPU during +# tuning. If you want to try to force a particular behavior though you can +# uncomment this option and change it to "true" or "false". This is a fairly +# blunt setting - more detailed settings are testable by rerunning the tuner +# with various arguments (./katago tuner). +# openclUseFP16 = auto + +# ------------------------------ +# Eigen-specific settings +# ------------------------------ +# These only apply when using the Eigen (pure CPU) version of KataGo. + +# Number of CPU threads for evaluating the neural net on the Eigen backend. +# +# Default: numSearchThreads +# numEigenThreadsPerModel = X + +# =========================================================================== +# Root move selection and biases +# =========================================================================== +# Uncomment and edit any of the below values to change them from their default. + +# If provided, force usage of a specific seed for various random things in +# the search. The default is to use a random seed. +# searchRandSeed = hijklmn + +# Temperature for the early game, randomize between chosen moves with +# this temperature +# chosenMoveTemperatureEarly = 0.5 + +# Decay temperature for the early game by 0.5 every this many moves, +# scaled with board size. +# chosenMoveTemperatureHalflife = 19 + +# At the end of search after the early game, randomize between chosen +# moves with this temperature +# chosenMoveTemperature = 0.10 + +# Subtract this many visits from each move prior to applying +# chosenMoveTemperature (unless all moves have too few visits) to downweight +# unlikely moves +# chosenMoveSubtract = 0 + +# The same as chosenMoveSubtract but only prunes moves that fall below +# the threshold. This setting does not affect chosenMoveSubtract. +# chosenMovePrune = 1 + +# Number of symmetries to sample (without replacement) and average at the root +# rootNumSymmetriesToSample = 1 + +# Using LCB for move selection? +# useLcbForSelection = true + +# How many stdevs a move needs to be better than another for LCB selection +# lcbStdevs = 5.0 + +# Only use LCB override when a move has this proportion of visits as the +# top move. +# minVisitPropForLCB = 0.15 + +# =========================================================================== +# Internal params +# =========================================================================== +# Uncomment and edit any of the below values to change them from their default. + +# Scales the utility of winning/losing +# winLossUtilityFactor = 1.0 + +# Scales the utility for trying to maximize score +# staticScoreUtilityFactor = 0.10 +# dynamicScoreUtilityFactor = 0.30 + +# Adjust dynamic score center this proportion of the way towards zero, +# capped at a reasonable amount. +# dynamicScoreCenterZeroWeight = 0.20 +# dynamicScoreCenterScale = 0.75 + +# The utility of getting a "no result" due to triple ko or other long cycle +# in non-superko rulesets (-1 to 1) +# noResultUtilityForWhite = 0.0 + +# The number of wins that a draw counts as, for white. (0 to 1) +# drawEquivalentWinsForWhite = 0.5 + +# Exploration constant for mcts +# cpuctExploration = 1.0 +# cpuctExplorationLog = 0.45 + +# Parameters that control exploring more in volatile positions, exploring +# less in stable positions. +# cpuctUtilityStdevPrior = 0.40 +# cpuctUtilityStdevPriorWeight = 2.0 +# cpuctUtilityStdevScale = 0.85 + +# FPU reduction constant for mcts +# fpuReductionMax = 0.2 +# rootFpuReductionMax = 0.1 +# fpuParentWeightByVisitedPolicy = true + +# Parameters that control weighting of evals based on the net's own +# self-reported uncertainty. +# useUncertainty = true +# uncertaintyExponent = 1.0 +# uncertaintyCoeff = 0.25 + +# Explore using optimistic policy +# rootPolicyOptimism = 0.2 +# policyOptimism = 1.0 + +# Amount to apply a downweighting of children with very bad values relative +# to good ones. +# valueWeightExponent = 0.25 + +# Slight incentive for the bot to behave human-like with regard to passing at +# the end, filling the dame, not wasting time playing in its own territory, +# etc., and not play moves that are equivalent in terms of points but a bit +# more unfriendly to humans. +# rootEndingBonusPoints = 0.5 + +# Make the bot prune useless moves that are just prolonging the game to +# avoid losing yet. +# rootPruneUselessMoves = true + +# Apply bias correction based on local pattern keys +# subtreeValueBiasFactor = 0.45 +# subtreeValueBiasWeightExponent = 0.85 + +# Use graph search rather than tree search - identify and share search for +# transpositions. +# useGraphSearch = true + +# How much to shard the node table for search synchronization +# nodeTableShardsPowerOfTwo = 16 + +# How many virtual losses to add when a thread descends through a node +# numVirtualLossesPerThread = 1 + +# Improve the quality of evals under heavy multithreading +# useNoisePruning = true + +# =========================================================================== +# Automatic avoid patterns +# =========================================================================== +# The parameters in this section provide a way to bias away from moves that match +# patterns that this instance of KataGo has played in previous games, by auto-saving +# moves to a directory and then auto-loading and biasing against them each new game. +# Uncomment them to use them. When using this feature, all parameters must be specified. + +# Different board sizes are tracked separately, but all board sizes share the same +# parameters by default. Every parameter *except* for autoAvoidRepeatDir and +# autoAvoidRepeatSaveChunkSize can be overridden per board size, +# e.g. "autoAvoidRepeatMinTurnNumber13x13". You must ALSO specify the +# defaults even if you have specified board-size-specific values. + +# Directory to auto-save moves KataGo plays, to avoid them in future games. +# You can create a new empty directory and put its path here. +# If you run parallel instances of KataGo, use different directories if you +# want them not to share their biases, use the same directory if you want +# all of them to bias away from past moves that any of them have played. +# KataGo will also automatically DELETE old data in this directory, so it is +# recommended that if you do share the same directory between parallel instances, +# that they all use the same settings that affect data saving/deletion. +# autoAvoidRepeatDir = PATH_TO_NEW_DIRECTORY + +# Penalize this much utility per matching move. +# Values that are too large may lead to bad play. The value of 0.004 is fairly large +# and might be large enough to result in some early weird/bad moves when trying +# to avoid past games' moves if enough games begin the same way. You can experiment with it. +# autoAvoidRepeatUtility = 0.004 + +# Per each new move saved, exponentially decay prior saved moves by this factor. +# This way, the bias against moves from many games ago is gradually phased out. +# For example, 0.9995 = 1 - 1/2000, so would be roughly 2000 prior moves worth of +# penalty weight remembered, in steady state. Depending on what turn number range +# you are saving, this might equate to a different number of games For example +# saving the first 50 moves per game would make this roughly 2000 / 50 = 40 games +# worth of memory. +# autoAvoidRepeatLambda = 0.9995 + +# Affects data saving/deletion. +# When the number of saved moves exceeds this, outright delete them to avoid too many +# files and disk space building up. Also may affect the speed of saving/loading on start +# of each game if this is set large and a lot of data builds up. +# autoAvoidRepeatMaxPoses = 10000 + +# Affects data saving/deletion. +# Only save data for moves within this turn number range of those games. +# E.g. setting autoAvoidRepeatMinTurnNumber to a number like 4 or 5 would tend to make +# KataGo not develop a bias against the initial 3-4 and 4-4 corner moves in almost every game. +# autoAvoidRepeatMinTurnNumber = 0 +# autoAvoidRepeatMaxTurnNumber = 50 + +# Affects data saving/deletion. +# Within a single run of a program, wait to accumulate this many samples +# (possibly across multiple clear_boards/games) before saving the data. +# Can help to avoid writing too many small files to disk, especially when GTP is used +# in a way that clears the board very frequently (e.g. gtp2ogs pooled manager). +# autoAvoidRepeatSaveChunkSize = 200 + +# =========================================================================== +# Avoid SGF patterns +# =========================================================================== +# The parameters in this section provide a way to avoid moves that follow +# specific patterns based on a set of SGF files loaded upon startup. +# This is basically the same as the above "Automatic avoid patterns" section +# above except you supply your own SGF files to avoid moves from. +# Uncomment them to use this feature. Additionally, if the SGF file +# contains the string %SKIP% in a comment on a move, that move will be +# ignored for this purpose. + +# Load SGF files from this directory when the engine is started +# (only on startup, will not reload unless engine is restarted) +# avoidSgfPatternDirs = path/to/directory/with/sgfs/ +# You can also surround the file path in double quotes if the file path contains trailing spaces or hash signs. +# Within double quotes, backslashes are escape characters. +# avoidSgfPatternDirs = "path/to/directory/with/sgfs/" + +# Penalize this much utility per matching move. +# Set this negative if you instead want to favor SGF patterns instead of +# penalizing them. This number does not need to be large, even 0.001 will +# make a difference. Values that are too large may lead to bad play. +# avoidSgfPatternUtility = 0.001 + +# Optional - load only the newest this many files +# avoidSgfPatternMaxFiles = 20 + +# Optional - Penalty is multiplied by this per each older SGF file, so that +# old SGF files matter less than newer ones. +# avoidSgfPatternLambda = 0.90 + +# Optional - pay attention only to moves made by players with this name. +# For example, set it to the name that your bot's past games will show up +# as in the SGF, so that the bot will only avoid repeating moves that itself +# made in past games, not the moves that its opponents made. +# avoidSgfPatternAllowedNames = my-ogs-bot-name1,my-ogs-bot-name2 + +# Optional - Ignore moves in SGF files that occurred before this turn number. +# avoidSgfPatternMinTurnNumber = 0 + +# For more avoid patterns: +# You can also specify a second set of parameters, and a third, fourth, +# etc. by numbering 2,3,4,... +# +# avoidSgf2PatternDirs = ... +# avoidSgf2PatternUtility = ... +# avoidSgf2PatternMaxFiles = ... +# avoidSgf2PatternLambda = ... +# avoidSgf2PatternAllowedNames = ... +# avoidSgf2PatternMinTurnNumber = ... + diff --git a/cpp/configs/training/gatekeeper1_dots.cfg b/cpp/configs/training/gatekeeper1_dots.cfg new file mode 100644 index 000000000..dc9eee536 --- /dev/null +++ b/cpp/configs/training/gatekeeper1_dots.cfg @@ -0,0 +1,117 @@ + +# Logs------------------------------------------------------------------------------------ + +logSearchInfo = false +logMoves = false +logGamesEvery = 10 +logToStdout = true + +# Fancy game selfplay settings-------------------------------------------------------------------- + +# startPosesProb = 0.0 # Play this proportion of games starting from SGF positions +# startPosesFromSgfDir = DIRECTORYPATH # Load SGFs from this dir +# startPosesLoadProb = 1.0 # Only load each position from each SGF with this chance (save memory) +# startPosesTurnWeightLambda = 0 # 0 = equal weight 0.01 = decrease probability by 1% per turn -0.01 = increase probability by 1% per turn. + +# Match----------------------------------------------------------------------------------- + +numGameThreads = 16 +maxMovesPerGame = 1600 +numGamesPerGating = 100 + +allowResignation = true +resignThreshold = -0.90 +resignConsecTurns = 5 + +# Disabled, since we're not using any root noise and such +# Could have a slight weirdness on rootEndingBonusPoints, but shouldn't be a big deal. +# clearBotBeforeSearch = true + +# Rules------------------------------------------------------------------------------------ + +dots = true +multiStoneSuicideLegals = false,true + +#bSizesXY=20-20,36-36,39-32,24-24,30-30 +#bSizeRelProbs=10,2,1 + +bSizesXY=10-10,12-12,14-14 +bSizeRelProbs=1,1,2 + +startPoses=CROSS,CROSS,CROSS,CROSS,CROSS,CROSS,CROSS_2 +startPosesAreRandom=false,true +dotsCaptureEmptyBases=false,false,false,false,false,true +#dotsFreeCapturedDots=true,true,true,false + +#komiAuto = True # Automatically adjust komi to what the neural nets think are fair +komiMean = 0.0 # Specify explicit komi +komiStdev = 0.5 # Standard deviation of random variation to komi. +handicapProb = 0.0 # Probability of handicap game +handicapCompensateKomiProb = 0.0 # Probability of compensating komi to fair during handicap game +# numExtraBlackFixed = 3 # When playing handicap games, always use exactly this many extra black moves + +# Search limits----------------------------------------------------------------------------------- +maxVisits = 150 +numSearchThreads = 16 + +# GPU Settings------------------------------------------------------------------------------- + +nnMaxBatchSize = 128 +nnCacheSizePowerOfTwo = 21 +nnMutexPoolSizePowerOfTwo = 15 +numNNServerThreadsPerModel = 1 +nnRandomize = true + +# CUDA GPU settings-------------------------------------- +# cudaDeviceToUse = 0 #use device 0 for all server threads (numNNServerThreadsPerModel) unless otherwise specified per-model or per-thread-per-model +# cudaDeviceToUseModel0 = 3 #use device 3 for model 0 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel1 = 2 #use device 2 for model 1 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel0Thread0 = 3 #use device 3 for model 0, server thread 0 +# cudaDeviceToUseModel0Thread1 = 2 #use device 2 for model 0, server thread 1 + +cudaUseFP16 = auto +cudaUseNHWC = auto + +# Root move selection and biases------------------------------------------------------------------------------ + +chosenMoveTemperatureEarly = 0.5 +chosenMoveTemperatureHalflife = 19 +chosenMoveTemperature = 0.2 +chosenMoveSubtract = 0 +chosenMovePrune = 1 + +useLcbForSelection = true +lcbStdevs = 5.0 +minVisitPropForLCB = 0.15 + +# Internal params------------------------------------------------------------------------------ + +winLossUtilityFactor = 1.0 +staticScoreUtilityFactor = 0.00 +dynamicScoreUtilityFactor = 0.25 +dynamicScoreCenterZeroWeight = 0.25 +dynamicScoreCenterScale = 0.50 +noResultUtilityForWhite = 0.0 +drawEquivalentWinsForWhite = 0.5 + +rootEndingBonusPoints = 0.5 +rootPruneUselessMoves = true + +cpuctExploration = 1.1 +cpuctExplorationLog = 0.0 +fpuReductionMax = 0.2 +rootFpuReductionMax = 0.1 +valueWeightExponent = 0.5 +policyOptimism = 1.0 +rootPolicyOptimism = 0.0 + +subtreeValueBiasFactor = 0.35 +subtreeValueBiasWeightExponent = 0.8 +useNonBuggyLcb = true +useGraphSearch = true + +useUncertainty = true +uncertaintyExponent = 1.0 +uncertaintyCoeff = 0.25 + +numVirtualLossesPerThread = 1 diff --git a/cpp/configs/training/selfplay1_dots.cfg b/cpp/configs/training/selfplay1_dots.cfg new file mode 100644 index 000000000..77f535726 --- /dev/null +++ b/cpp/configs/training/selfplay1_dots.cfg @@ -0,0 +1,196 @@ + +# Logs------------------------------------------------------------------------------------ + +logSearchInfo = false +logMoves = false +logGamesEvery = 1 +logToStdout = true + +# Data writing----------------------------------------------------------------------------------- + +# Spatial size of tensors of data, must match -pos-len in python/train.py and be at least as large as the +# largest boardsize in the data. If you train only on smaller board sizes, you can decrease this and +# pass in a smaller -pos-len to python/train.py (e.g. modify "-pos-len 19" in selfplay/train.sh script), +# to train faster, although it may be awkward if you ever want to increase it again later since data with +# different values cannot be shuffled together. +dataBoardLen = 39 +dataBoardLenY = 39 # TODO: fix handling of rectangular size + +maxDataQueueSize = 2000 +maxRowsPerTrainFile = 10000 +firstFileRandMinProp = 0.15 + +# Fancy game selfplay settings-------------------------------------------------------------------- + +# These take dominance - if a game is forked for any of these reasons, that will be the next game played. +# For forked games, randomization of rules and "initGamesWithPolicy" is disabled, komi is made fair with prob "forkCompensateKomiProb". +earlyForkGameProb = 0.04 # Fork to try alternative opening variety with this probability +earlyForkGameExpectedMoveProp = 0.025 # Fork roughly within the first (boardArea * this) many moves +forkGameProb = 0.01 # Fork to try alternative crazy move anywhere in game with this probability if not early forking +forkGameMinChoices = 3 # Choose from the best of at least this many random choices +earlyForkGameMaxChoices = 12 # Choose from the best of at most this many random choices +forkGameMaxChoices = 36 # Choose from the best of at most this many random choices +# Otherwise, with this probability, learn a bit more about the differing evaluation of seki in different rulesets. +sekiForkHackProb = 0.02 + +# Otherwise, play some proportion of games starting from SGF positions, with randomized rules (ignoring the sgf rules) +# On SGF positions, high temperature policy init is allowed +# startPosesProb = 0.0 # Play this proportion of games starting from SGF positions +# startPosesFromSgfDir = DIRECTORYPATH # Load SGFs from this dir +# startPosesLoadProb = 1.0 # Only load each position from each SGF with this chance (save memory) +# startPosesTurnWeightLambda = 0 # 0 = equal weight 0.01 = decrease probability by 1% per turn -0.01 = increase probability by 1% per turn. +# startPosesPolicyInitAreaProp = 0.0 # Same as policyInitAreaProp but for SGF positions + +# Otherwise, play some proportion of games starting from hint positions (generated using "dataminesgfs" command), with randomized rules. +# On hint positions, "initGamesWithPolicy" does not apply. +# hintPosesProb = 0.0 +# hintPosesDir = DIRECTORYPATH + +# Otherwise we are playing a "normal" game, potentially with handicap stones, depending on "handicapProb", and +# potentially with komi randomization, depending on things like "komiStdev", and potentially with different +# board sizes, etc. + +# Most of the remaining parameters here below apply regardless of the initialization, although a few of them +# vary depending on handicap vs normal game, and some are explicitly disabled (e.g. initGamesWithPolicy on hint positions). + +initGamesWithPolicy = true # Play the first few moves of a game high-temperaturely from policy +policyInitAreaProp = 0.04 # The avg number of moves to play +compensateAfterPolicyInitProb = 0.2 # Additionally make komi fair this often after the high-temperature moves. +sidePositionProb = 0.020 # With this probability, train on refuting bad alternative moves. + +cheapSearchProb = 0.75 # Do cheap searches with this probaiblity +cheapSearchVisits = 100 # Number of visits for cheap search +cheapSearchTargetWeight = 0.0 # Training weight for cheap search + +reduceVisits = true # Reduce visits when one side is winning +reduceVisitsThreshold = 0.9 # How winning a side needs to be (winrate) +reduceVisitsThresholdLookback = 3 # How many consecutive turns needed to be that winning +reducedVisitsMin = 100 # Minimum number of visits (never reduce below this) +reducedVisitsWeight = 0.1 # Minimum training weight + +handicapAsymmetricPlayoutProb = 0.5 # In handicap games, play with unbalanced players with this probablity +normalAsymmetricPlayoutProb = 0.01 # In regular games, play with unbalanced players with this probability +maxAsymmetricRatio = 8.0 # Max ratio to unbalance visits between players +minAsymmetricCompensateKomiProb = 0.4 # Compensate komi with at least this probability for unbalanced players + +policySurpriseDataWeight = 0.5 # This proportion of training weight should be concentrated on surprising moves +valueSurpriseDataWeight = 0.1 # This proportion of training weight should be concentrated on surprising position results + +estimateLeadProb = 0.05 # Train lead, rather than just scoremean. Consumes a decent number of extra visits, can be quite slow using low visits to set too high. +switchNetsMidGame = true # When a new neural net is loaded, switch to it immediately instead of waiting for new game +# fancyKomiVarying = true # In non-compensated handicap and fork games, vary komi to better learn komi and large score differences that would never happen in even games. + +# Match----------------------------------------------------------------------------------- + +numGameThreads = 16 +maxMovesPerGame = 1600 + +# Rules------------------------------------------------------------------------------------ + +dots = true +multiStoneSuicideLegals = false,true + +#bSizesXY=20-20,36-36,39-32,24-24,30-30 +#bSizeRelProbs=10,2,1 + +#bSizesXY=20-20,16-16,10-10,24-24 +#bSizeRelProbs=3,2,2,1 + +#bSizesXY=8-8,10-10,9-9,8-10,10-9,12-12 +#bSizeRelProbs=4,4,2,1,1,1 + +bSizesXY=10-10,12-12,14-14 +bSizeRelProbs=1,2,1 + +startPoses=CROSS,CROSS,CROSS,CROSS_2,CROSS_4,CROSS_4 +startPosesAreRandom=false +dotsCaptureEmptyBases=false,false,false,false,false,false,false,true +#dotsFreeCapturedDots=true,true,true,false + +# komiAuto = True # Automatically adjust komi to what the neural nets think are fair based on the empty board, but still apply komiStdev. +komiMean = 0.0 # Specify explicit komi +komiStdev = 0.5 # Standard deviation of random variation to komi. +komiBigStdevProb = 0.0 # Probability of applying komiBigStdev +komiBigStdev = 0.0 # Standard deviation of random big variation to komi + +handicapProb = 0.1 # Probability of handicap game +handicapCompensateKomiProb = 0.00 # In handicap games, adjust komi to fair with this probability based on the handicap placement +forkCompensateKomiProb = 0.00 # For forks, adjust komi to fair with this probability based on the forked position +sgfCompensateKomiProb = 0.00 # For sgfs, adjust komi to fair with this probability based on the specific starting position + +drawRandRadius = 0.5 +noResultStdev = 0.166666666 + +# Search limits----------------------------------------------------------------------------------- + +maxVisits = 600 +numSearchThreads = 16 + +# GPU Settings------------------------------------------------------------------------------- + +nnMaxBatchSize = 128 +nnCacheSizePowerOfTwo = 21 +nnMutexPoolSizePowerOfTwo = 15 +numNNServerThreadsPerModel = 1 +nnRandomize = true + +# CUDA GPU settings-------------------------------------- +# cudaDeviceToUse = 0 #use device 0 for all server threads (numNNServerThreadsPerModel) unless otherwise specified per-model or per-thread-per-model +# cudaDeviceToUseModel0 = 3 #use device 3 for model 0 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel1 = 2 #use device 2 for model 1 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel0Thread0 = 3 #use device 3 for model 0, server thread 0 +# cudaDeviceToUseModel0Thread1 = 2 #use device 2 for model 0, server thread 1 + +cudaUseFP16 = auto +cudaUseNHWC = auto + +# Root move selection and biases------------------------------------------------------------------------------ + +chosenMoveTemperatureEarly = 0.75 +chosenMoveTemperatureHalflife = 19 +chosenMoveTemperature = 0.15 +chosenMoveSubtract = 0 +chosenMovePrune = 1 + +rootNoiseEnabled = true +rootDirichletNoiseTotalConcentration = 10.83 +rootDirichletNoiseWeight = 0.25 + +rootDesiredPerChildVisitsCoeff = 2 +rootNumSymmetriesToSample = 4 + +useLcbForSelection = true +lcbStdevs = 5.0 +minVisitPropForLCB = 0.15 + +# Internal params------------------------------------------------------------------------------ + +winLossUtilityFactor = 1.0 +staticScoreUtilityFactor = 0.00 +dynamicScoreUtilityFactor = 0.40 +dynamicScoreCenterZeroWeight = 0.25 +dynamicScoreCenterScale = 0.50 +noResultUtilityForWhite = 0.0 +drawEquivalentWinsForWhite = 0.5 + +rootEndingBonusPoints = 0.5 +rootPruneUselessMoves = true + +rootPolicyTemperatureEarly = 1.25 +rootPolicyTemperature = 1.1 + +cpuctExploration = 1.1 +cpuctExplorationLog = 0.0 +fpuReductionMax = 0.2 +rootFpuReductionMax = 0.0 + +numVirtualLossesPerThread = 1 + +# These parameters didn't exist historically during early KataGo runs +valueWeightExponent = 0.5 +subtreeValueBiasFactor = 0.30 +subtreeValueBiasWeightExponent = 0.8 +useNonBuggyLcb = true +useGraphSearch = true +fpuParentWeightByVisitedPolicy = true +fpuParentWeightByVisitedPolicyPow = 2.0 diff --git a/cpp/core/commontypes.h b/cpp/core/commontypes.h index 8f1d57656..51441726f 100644 --- a/cpp/core/commontypes.h +++ b/cpp/core/commontypes.h @@ -1,6 +1,8 @@ #ifndef COMMONTYPES_H #define COMMONTYPES_H +#include + struct enabled_t { enum value { False, True, Auto }; value x; @@ -11,7 +13,7 @@ struct enabled_t { constexpr bool operator==(enabled_t a) const { return x == a.x; } constexpr bool operator!=(enabled_t a) const { return x != a.x; } - std::string toString() { + [[nodiscard]] std::string toString() const { return x == True ? "true" : x == False ? "false" : "auto"; } diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index e78388b63..82fe34c9e 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -2,8 +2,8 @@ #include "../core/fileutils.h" -#include #include +#include #include using namespace std; @@ -113,8 +113,6 @@ void ConfigParser::processIncludedFile(const std::string &fname) { baseDirs.pop_back(); } - - bool ConfigParser::parseKeyValue(const std::string& trimmedLine, std::string& key, std::string& value) { // Parse trimmed line, taking into account comments and quoting. key.clear(); @@ -519,56 +517,69 @@ std::string ConfigParser::firstFoundOrEmpty(const std::vector& poss if(contains(key)) return key; } - return string(); + return {}; } +string ConfigParser::getString(const string& key, const set& possibles) { + const auto str = tryGetString(key); -string ConfigParser::getString(const string& key) { - auto iter = keyValues.find(key); - if(iter == keyValues.end()) + if (str == std::nullopt) throw IOError("Could not find key '" + key + "' in config file " + fileName); - { - std::lock_guard lock(usedKeysMutex); - usedKeys.insert(key); - } + auto value = str.value(); - return iter->second; -} + if (!possibles.empty()) { + if(possibles.find(value) == possibles.end()) + throw IOError("Key '" + key + "' must be one of (" + Global::concat(possibles,"|") + ") in config file " + fileName); + } -string ConfigParser::getString(const string& key, const set& possibles) { - string value = getString(key); - if(possibles.find(value) == possibles.end()) - throw IOError("Key '" + key + "' must be one of (" + Global::concat(possibles,"|") + ") in config file " + fileName); return value; } -vector ConfigParser::getStrings(const string& key) { - return Global::split(getString(key),','); -} +vector ConfigParser::getStrings(const string& key, const set& possibles, const bool nonEmptyTrim) { + vector values = Global::split(getString(key),','); -vector ConfigParser::getStringsNonEmptyTrim(const string& key) { - vector raw = Global::split(getString(key),','); - vector trimmed; - for(size_t i = 0; i trimmedStrings; + for(const auto& s : values) { + if (string trimmed = Global::trim(s); !trimmed.empty()) { + trimmedStrings.push_back(trimmed); + } + } + values = trimmedStrings; } - return trimmed; -} -vector ConfigParser::getStrings(const string& key, const set& possibles) { - vector values = getStrings(key); - for(size_t i = 0; i ConfigParser::tryGetString(const string& key) { + const auto iter = keyValues.find(key); + if(iter == keyValues.end()) { + return std::nullopt; + } + + { + std::lock_guard lock(usedKeysMutex); + usedKeys.insert(key); + } + + return iter->second; +} + +bool ConfigParser::getBoolOrDefault(const std::string& key, const bool defaultValue) { + if (contains(key)) { + return getBool(key); + } + + return defaultValue; +} bool ConfigParser::getBool(const string& key) { string value = getString(key); @@ -577,6 +588,7 @@ bool ConfigParser::getBool(const string& key) { throw IOError("Could not parse '" + value + "' as bool for key '" + key + "' in config file " + fileName); return x; } + vector ConfigParser::getBools(const string& key) { vector values = getStrings(key); vector ret; @@ -598,14 +610,7 @@ enabled_t ConfigParser::getEnabled(const string& key) { return x; } -int ConfigParser::getInt(const string& key) { - string value = getString(key); - int x; - if(!Global::tryStringToInt(value,x)) - throw IOError("Could not parse '" + value + "' as int for key '" + key + "' in config file " + fileName); - return x; -} -int ConfigParser::getInt(const string& key, int min, int max) { +int ConfigParser::getInt(const string& key, const int min, const int max) { assert(min <= max); string value = getString(key); int x; @@ -615,19 +620,8 @@ int ConfigParser::getInt(const string& key, int min, int max) { throw IOError("Key '" + key + "' must be in the range " + Global::intToString(min) + " to " + Global::intToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInts(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInts(const string& key, int min, int max) { + +vector ConfigParser::getInts(const string& key, const int min, const int max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getInts(const string& key, int min, int max) { } return ret; } -vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, int min, int max) { + +vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, const int min, const int max1, const int max2) { std::vector pairStrs = getStrings(key); std::vector> ret; for(const string& pairStr: pairStrs) { @@ -662,23 +657,15 @@ vector> ConfigParser::getNonNegativeIntDashedPairs(const stri if(!suc) throw IOError("Could not parse '" + pairStr + "' as a pair of integers separated by a dash for key '" + key + "' in config file " + fileName); - if(p0 < min || p0 > max || p1 < min || p1 > max) - throw IOError("Expected key '" + key + "' to have all values range " + Global::intToString(min) + " to " + Global::intToString(max) + " in config file " + fileName); + if(p0 < min || p0 > max1 || p1 < min || p1 > max2) + throw IOError("Expected key '" + key + "' to have all values range " + Global::intToString(min) + " to (" + Global::intToString(max1) + ", " + Global::intToString(max2) + ") in config file " + fileName); ret.push_back(std::make_pair(p0,p1)); } return ret; } - -int64_t ConfigParser::getInt64(const string& key) { - string value = getString(key); - int64_t x; - if(!Global::tryStringToInt64(value,x)) - throw IOError("Could not parse '" + value + "' as int64_t for key '" + key + "' in config file " + fileName); - return x; -} -int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { +int64_t ConfigParser::getInt64(const string& key, const int64_t min, const int64_t max) { assert(min <= max); string value = getString(key); int64_t x; @@ -688,19 +675,8 @@ int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { throw IOError("Key '" + key + "' must be in the range " + Global::int64ToString(min) + " to " + Global::int64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t max) { + +vector ConfigParser::getInt64s(const string& key, const int64_t min, const int64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t return ret; } - -uint64_t ConfigParser::getUInt64(const string& key) { - string value = getString(key); - uint64_t x; - if(!Global::tryStringToUInt64(value,x)) - throw IOError("Could not parse '" + value + "' as uint64_t for key '" + key + "' in config file " + fileName); - return x; -} -uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) { +uint64_t ConfigParser::getUInt64(const string& key, const uint64_t min, const uint64_t max) { assert(min <= max); string value = getString(key); uint64_t x; @@ -733,19 +701,8 @@ uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) throw IOError("Key '" + key + "' must be in the range " + Global::uint64ToString(min) + " to " + Global::uint64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getUInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint64_t max) { + +vector ConfigParser::getUInt64s(const string& key, const uint64_t min, const uint64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint6 return ret; } - -float ConfigParser::getFloat(const string& key) { - string value = getString(key); - float x; - if(!Global::tryStringToFloat(value,x)) - throw IOError("Could not parse '" + value + "' as float for key '" + key + "' in config file " + fileName); - return x; -} -float ConfigParser::getFloat(const string& key, float min, float max) { +float ConfigParser::getFloat(const string& key, const float min, const float max) { assert(min <= max); string value = getString(key); float x; @@ -780,19 +729,8 @@ float ConfigParser::getFloat(const string& key, float min, float max) { throw IOError("Key '" + key + "' must be in the range " + Global::floatToString(min) + " to " + Global::floatToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getFloats(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { + +vector ConfigParser::getFloats(const string& key, const float min, const float max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { return ret; } - -double ConfigParser::getDouble(const string& key) { - string value = getString(key); - double x; - if(!Global::tryStringToDouble(value,x)) - throw IOError("Could not parse '" + value + "' as double for key '" + key + "' in config file " + fileName); - return x; -} -double ConfigParser::getDouble(const string& key, double min, double max) { +double ConfigParser::getDouble(const string& key, const double min, const double max) { assert(min <= max); string value = getString(key); double x; @@ -829,19 +759,8 @@ double ConfigParser::getDouble(const string& key, double min, double max) { throw IOError("Key '" + key + "' must be in the range " + Global::doubleToString(min) + " to " + Global::doubleToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getDoubles(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getDoubles(const string& key, double min, double max) { + +vector ConfigParser::getDoubles(const string& key, const double min, const double max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i +#include +#include "../core/commontypes.h" #include "../core/global.h" #include "../core/logger.h" -#include "../core/commontypes.h" /* Parses simple configs like: @@ -57,39 +58,28 @@ class ConfigParser { std::string firstFoundOrFail(const std::vector& possibleKeys) const; std::string firstFoundOrEmpty(const std::vector& possibleKeys) const; - std::string getString(const std::string& key); + std::string getString(const std::string& key, const std::set& possibles = {}); + std::vector getStrings(const std::string& key, const std::set& possibles = {}, bool nonEmptyTrim = false); + std::optional tryGetString(const std::string& key); + + bool getBoolOrDefault(const std::string& key, bool defaultValue); bool getBool(const std::string& key); enabled_t getEnabled(const std::string& key); - int getInt(const std::string& key); - int64_t getInt64(const std::string& key); - uint64_t getUInt64(const std::string& key); - float getFloat(const std::string& key); - double getDouble(const std::string& key); - - std::string getString(const std::string& key, const std::set& possibles); - int getInt(const std::string& key, int min, int max); - int64_t getInt64(const std::string& key, int64_t min, int64_t max); - uint64_t getUInt64(const std::string& key, uint64_t min, uint64_t max); - float getFloat(const std::string& key, float min, float max); - double getDouble(const std::string& key, double min, double max); - - std::vector getStrings(const std::string& key); - std::vector getStringsNonEmptyTrim(const std::string& key); + + int getInt(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + int64_t getInt64(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + uint64_t getUInt64(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + float getFloat(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + double getDouble(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); + std::vector getBools(const std::string& key); - std::vector getInts(const std::string& key); - std::vector getInt64s(const std::string& key); - std::vector getUInt64s(const std::string& key); - std::vector getFloats(const std::string& key); - std::vector getDoubles(const std::string& key); - - std::vector getStrings(const std::string& key, const std::set& possibles); - std::vector getInts(const std::string& key, int min, int max); - std::vector getInt64s(const std::string& key, int64_t min, int64_t max); - std::vector getUInt64s(const std::string& key, uint64_t min, uint64_t max); - std::vector getFloats(const std::string& key, float min, float max); - std::vector getDoubles(const std::string& key, double min, double max); - - std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max); + std::vector getInts(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + std::vector getInt64s(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + std::vector getUInt64s(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + std::vector getFloats(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + std::vector getDoubles(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); + + std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max1, int max2); private: bool initialized; diff --git a/cpp/core/global.cpp b/cpp/core/global.cpp index b5620c022..a1e74635b 100644 --- a/cpp/core/global.cpp +++ b/cpp/core/global.cpp @@ -16,6 +16,8 @@ #include #include +#include "test.h" + using namespace std; //ERRORS---------------------------------- @@ -117,6 +119,41 @@ string Global::uint64ToHexString(uint64_t x) return s; } +static const std::string CoordAlphabet = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; +static const int CoordAlphabetLength = CoordAlphabet.length(); + +std::string Global::intToCoord(const int x) { + assert(x >= 0); + + int v = x + 1; + std::string out; + while (v > 0) { + out.push_back(CoordAlphabet[(v - 1) % CoordAlphabetLength]); + v = (v - 1) / CoordAlphabetLength; + } + std::reverse(out.begin(), out.end()); + return out; +} + +bool Global::tryCoordToInt(const std::string& coord, int& x) { + if (coord.empty()) return false; + + long long v = 0; + for (const char c : coord) { + const auto pos = CoordAlphabet.find(std::toupper(c)); + if (pos == std::string::npos) return false; + v = v * CoordAlphabetLength + (static_cast(pos) + 1); + if (v > static_cast(std::numeric_limits::max()) + 1LL) { + return false; + } + } + + // convert from 1-based bijective value back to 0-based index + x = static_cast(v - 1); + assert(x >= 0); + return true; +} + string Global::sizeToString(size_t x) { stringstream ss; @@ -706,3 +743,44 @@ double Global::roundDynamic(double x, int precision) { double inverseScale = pow(10.0,-roundingMagnitude); return roundStatic(x, inverseScale); } + +bool Global::isEqual(const float f1, const float f2) { + return std::fabs(f1 - f2) <= FLOAT_EPS; +} + +bool Global::isZero(const float f) { + return std::fabs(f) <= FLOAT_EPS; +} + +void Global::runTests() { + testAssert(isEqual(0.0f, 0.0f)); + testAssert(isEqual(FLOAT_EPS, FLOAT_EPS)); + testAssert(isZero(FLOAT_EPS)); + testAssert(isZero(-FLOAT_EPS)); + testAssert(!isEqual(42, 42 + 0.0001f)); + testAssert(!isEqual(-FLOAT_EPS, FLOAT_EPS)); + + auto checkCoordConversion = [](const string& coord, const int value) { + testAssert(toUpper(coord) == intToCoord(value)); + int newValue; + testAssert(tryCoordToInt(coord, newValue)); + testAssert(value == newValue); + }; + + checkCoordConversion("A", 0); + checkCoordConversion("B", 1); + checkCoordConversion("Z", 24); + checkCoordConversion("AA", 25); + checkCoordConversion("AB", 26); + checkCoordConversion("AZ", 49); + checkCoordConversion("az", 49); + checkCoordConversion("BA", 50); + checkCoordConversion("ZZ", 649); + checkCoordConversion("AAA", 650); + + int newValue; + testAssert(!tryCoordToInt("", newValue)); + testAssert(!tryCoordToInt("!", newValue)); + testAssert(tryCoordToInt("ABCDEFG", newValue)); + testAssert(!tryCoordToInt("ABCDEFGH", newValue)); // Too big number -> overflow +} diff --git a/cpp/core/global.h b/cpp/core/global.h index 62c34fa12..bb65afc5f 100644 --- a/cpp/core/global.h +++ b/cpp/core/global.h @@ -76,6 +76,10 @@ namespace Global std::string uint32ToHexString(uint32_t x); std::string uint64ToHexString(uint64_t x); std::string sizeToString(size_t x); + // Convert numbers to letters that are used on real Go board + // It matches the English alphabet except missing I that is ignored to avoid confusion with J + std::string intToCoord(int x); + bool tryCoordToInt(const std::string& coord, int& x); //String to conversions using the standard library parsing int stringToInt(const std::string& str); @@ -159,7 +163,13 @@ namespace Global //Round x to this many decimal digits of precision double roundDynamic(double x, int precision); -} + // Float comparison + constexpr float FLOAT_EPS = std::numeric_limits::epsilon(); + bool isEqual(float f1, float f2); + bool isZero(float f); + + void runTests(); +} // namespace Global struct StringError : public std::exception { std::string message; diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 5000d3139..433519565 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -127,8 +127,11 @@ static Loc parseSgfLocOrPass(const string& s, int xSize, int ySize) { static void writeSgfLoc(ostream& out, Loc loc, int xSize, int ySize) { if(xSize >= 53 || ySize >= 53) throw StringError("Writing coordinates for SGF files for board sizes >= 53 is not implemented"); - if(loc == Board::PASS_LOC || loc == Board::NULL_LOC) + if (loc == Board::PASS_LOC || loc == Board::NULL_LOC) return; + if (loc == Board::RESIGN_LOC) { + out << RESIGN_STR; + } int x = Location::getX(loc,xSize); int y = Location::getY(loc,xSize); const char* chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; @@ -136,6 +139,32 @@ static void writeSgfLoc(ostream& out, Loc loc, int xSize, int ySize) { out << chars[y]; } +static Rules getRulesFromSgf(const SgfNode& rootNode, const int xSize, const int ySize, const Rules* defaultRules) { + Rules rules; + const bool dotsGame = rootNode.getIsDotsGame(); + if (defaultRules == nullptr || rootNode.hasProperty("RU")) { + rules = rootNode.getRulesFromRUTagOrFail(dotsGame); + } else { + rules = *defaultRules; + } + + if (defaultRules == nullptr || rootNode.hasProperty("KM")) { + rules.komi = rootNode.getKomiOrFail(); + } + + vector placementMoves; + rootNode.accumPlacements(placementMoves, xSize, ySize); + vector startPosMoves; + bool randomized; + vector remainingMoves; + rules.startPos = Rules::recognizeStartPos(placementMoves, xSize, ySize, startPosMoves, randomized, &remainingMoves); + if (randomized && !rules.startPosIsRandom) { + propertyFail("Defined start pos is randomized but RU says it shouldn't"); + } + + return rules; +} + bool SgfNode::hasProperty(const char* key) const { if(props == nullptr) return false; @@ -273,13 +302,13 @@ Color SgfNode::getPLSpecifiedColor() const { return C_EMPTY; } -Rules SgfNode::getRulesFromRUTagOrFail() const { +Rules SgfNode::getRulesFromRUTagOrFail(const bool isDots) const { if(!hasProperty("RU")) throw StringError("SGF file does not specify rules"); string s = getSingleProperty("RU"); Rules parsed; - bool suc = Rules::tryParseRules(s,parsed); + bool suc = Rules::tryParseRules(s, parsed, isDots); if(!suc) throw StringError("Could not parse rules in sgf: " + s); return parsed; @@ -399,6 +428,10 @@ static void checkNonEmpty(const vector>& nodes) { throw StringError("Empty sgf"); } +bool Sgf::isDotsGame() const { + return nodes[0]->getIsDotsGame(); +} + XYSize Sgf::getXYSize() const { checkNonEmpty(nodes); int xSize = 0; //Initialize to 0 to suppress spurious clang compiler warning. @@ -424,9 +457,14 @@ XYSize Sgf::getXYSize() const { if(xSize <= 1 || ySize <= 1) propertyFail("Board size in sgf is <= 1: " + s); - if(xSize > Board::MAX_LEN || ySize > Board::MAX_LEN) + if(xSize > Board::MAX_LEN_X) + propertyFail( + "Board size x in sgf is > Board::MAX_LEN_X = " + Global::intToString((int)Board::MAX_LEN_X) + + ", if larger sizes are desired, consider increasing and recompiling: " + s + ); + if(ySize > Board::MAX_LEN_Y) propertyFail( - "Board size in sgf is > Board::MAX_LEN = " + Global::intToString((int)Board::MAX_LEN) + + "Board size y in sgf is > Board::MAX_LEN_Y = " + Global::intToString((int)Board::MAX_LEN_Y) + ", if larger sizes are desired, consider increasing and recompiling: " + s ); return XYSize(xSize,ySize); @@ -447,6 +485,11 @@ float SgfNode::getKomiOrFail() const { return getKomiOrDefault(0.0f); } +bool SgfNode::getIsDotsGame() const { + if(!hasProperty("GM")) return false; + return getSingleProperty("GM") == "40"; +} + float SgfNode::getKomiOrDefault(float defaultKomi) const { //Default, if SGF doesn't specify if(!hasProperty("KM")) @@ -515,9 +558,9 @@ bool Sgf::hasRules() const { Rules Sgf::getRulesOrFail() const { checkNonEmpty(nodes); - Rules rules = nodes[0]->getRulesFromRUTagOrFail(); - rules.komi = getKomiOrFail(); - return rules; + + const XYSize size = getXYSize(); + return getRulesFromSgf(*nodes[0], size.x, size.y, nullptr); } Player Sgf::getSgfWinner() const { @@ -755,7 +798,7 @@ void Sgf::loadAllUniquePositions( Rand* rand, vector& samples ) const { - std::function f = [&samples](PositionSample& sample, const BoardHistory& hist, const string& comments) { + std::function f = [&samples](PositionSample& sample, const BoardHistory& hist, const string& comments) { (void)hist; (void)comments; samples.push_back(sample); @@ -773,17 +816,20 @@ void Sgf::iterAllUniquePositions( Rand* rand, std::function f ) const { + bool isDots = isDotsGame(); XYSize size = getXYSize(); int xSize = size.x; int ySize = size.y; - Board board(xSize,ySize); + Rules rules = Rules::getDefaultOrTrompTaylorish(isDots); Player nextPla = nodes.size() > 0 ? nodes[0]->getPLSpecifiedColor() : C_EMPTY; if(nextPla == C_EMPTY) nextPla = C_BLACK; - Rules rules = Rules::getTrompTaylorish(); - rules.koRule = Rules::KO_SITUATIONAL; - rules.multiStoneSuicideLegal = true; + if (!isDots) { + rules.koRule = Rules::KO_SITUATIONAL; + rules.multiStoneSuicideLegal = true; + } + Board board(xSize,ySize,rules); BoardHistory hist(board,nextPla,rules,0); PositionSample sampleBuf; @@ -800,17 +846,20 @@ void Sgf::iterAllPositions( Rand* rand, std::function f ) const { + bool isDots = isDotsGame(); XYSize size = getXYSize(); int xSize = size.x; int ySize = size.y; - Board board(xSize,ySize); Player nextPla = nodes.size() > 0 ? nodes[0]->getPLSpecifiedColor() : C_EMPTY; if(nextPla == C_EMPTY) nextPla = C_BLACK; - Rules rules = Rules::getTrompTaylorish(); - rules.koRule = Rules::KO_SITUATIONAL; - rules.multiStoneSuicideLegal = true; + Rules rules = Rules::getDefaultOrTrompTaylorish(isDots); + if (!isDots) { + rules.koRule = Rules::KO_SITUATIONAL; + rules.multiStoneSuicideLegal = true; + } + Board board(xSize,ySize,rules); BoardHistory hist(board,nextPla,rules,0); PositionSample sampleBuf; @@ -860,9 +909,9 @@ void Sgf::iterAllPositionsHelper( int netStonesAdded = 0; if(buf.size() > 0) { for(size_t j = 0; j 0x3FFFFFFF) @@ -1089,21 +1138,25 @@ std::set Sgf::readExcludes(const vector& files) { string Sgf::PositionSample::toJsonLine(const Sgf::PositionSample& sample) { json data; - data["xSize"] = sample.board.x_size; - data["ySize"] = sample.board.y_size; - data["board"] = Board::toStringSimple(sample.board,'/'); - data["nextPla"] = PlayerIO::playerToStringShort(sample.nextPla); + const Board& board = sample.board; + if (board.rules.isDots) { + data[DOTS_KEY] = "true"; + } + data["xSize"] = board.x_size; + data["ySize"] = board.y_size; + data["board"] = Board::toStringSimple(board,'/'); + data["nextPla"] = PlayerIO::playerToStringShort(sample.nextPla, board.isDots()); vector moveLocs; vector movePlas; for(size_t i = 0; i 0) data["metadata"] = sample.metadata; @@ -1116,9 +1169,10 @@ Sgf::PositionSample Sgf::PositionSample::ofJsonLine(const string& s) { json data = json::parse(s); PositionSample sample; try { + bool isDots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - sample.board = Board::parseBoard(xSize,ySize,data["board"].get(),'/'); + sample.board = Board::parseBoard(xSize, ySize, data["board"].get(), Rules(isDots), '/'); sample.nextPla = PlayerIO::parsePlayer(data["nextPla"].get()); vector moveLocs = data["moveLocs"].get>(); vector movePlas = data["movePlas"].get>(); @@ -1160,7 +1214,7 @@ Sgf::PositionSample Sgf::PositionSample::ofJsonLine(const string& s) { Sgf::PositionSample Sgf::PositionSample::getColorFlipped() const { Sgf::PositionSample other = *this; - Board newBoard(other.board.x_size,other.board.y_size); + Board newBoard(other.board.x_size, other.board.y_size, other.board.rules); for(int y = 0; y < other.board.y_size; y++) { for(int x = 0; x < other.board.x_size; x++) { Loc loc = Location::getLoc(x,y,other.board.x_size); @@ -1215,7 +1269,9 @@ int64_t Sgf::PositionSample::getCurrentTurnNumber() const { } bool Sgf::PositionSample::isEqualForTesting(const Sgf::PositionSample& other, bool checkNumCaptures, bool checkSimpleKo) const { - if(!board.isEqualForTesting(other.board,checkNumCaptures,checkSimpleKo)) + // Skips the rules check because default `koRule` value `KO_POSITIONAL` differs in `Rules` constructor and in `Sgf::iterAllPositions` (`KO_SITUATIONAL`) + // TODO: fix it + if(!board.isEqualForTesting(other.board,checkNumCaptures,checkSimpleKo,false)) return false; if(nextPla != other.nextPla) return false; @@ -1454,8 +1510,11 @@ static std::unique_ptr maybeParseSgf(const string& str, size_t& pos) { && handicap >= 2 && handicap <= 9 ) { - Board board(19,19); - PlayUtils::placeFixedHandicap(board, handicap); + Board board(rootSgf->getRulesOrFail()); + const auto& handicapLocs = PlayUtils::generateFixedHandicap(board, handicap); + for (const auto loc : handicapLocs) { + board.setStoneFailIfNoLibs(loc, P_BLACK); + } // Older fox sgfs used handicaps with side stones on the north and south rather than east and west if(handicap == 6 || handicap == 7) { if(rootSgf->hasRootProperty("DT")) { @@ -1472,15 +1531,10 @@ static std::unique_ptr maybeParseSgf(const string& str, size_t& pos) { } } - for(int y = 0; yaddRootProperty("AB",out.str()); - } - } + for (const auto loc : handicapLocs) { + ostringstream out; + writeSgfLoc(out, loc, board.x_size, board.y_size); + rootSgf->addRootProperty("AB",out.str()); } } return rootSgf; @@ -1582,6 +1636,7 @@ std::vector> Sgf::loadSgfOrSgfsLogAndIgnoreErrors(const str CompactSgf::CompactSgf(const Sgf& sgf) :fileName(sgf.fileName), + isDots(), rootNode(), placements(), moves(), @@ -1591,6 +1646,7 @@ CompactSgf::CompactSgf(const Sgf& sgf) sgfWinner(), hash(sgf.hash) { + isDots = sgf.isDotsGame(); XYSize size = sgf.getXYSize(); xSize = size.x; ySize = size.y; @@ -1616,6 +1672,7 @@ CompactSgf::CompactSgf(Sgf&& sgf) sgfWinner(), hash(sgf.hash) { + isDots = sgf.isDotsGame(); XYSize size = sgf.getXYSize(); xSize = size.x; ySize = size.y; @@ -1666,21 +1723,11 @@ bool CompactSgf::hasRules() const { } Rules CompactSgf::getRulesOrFail() const { - Rules rules = rootNode.getRulesFromRUTagOrFail(); - rules.komi = rootNode.getKomiOrFail(); - return rules; + return getRulesFromSgf(rootNode, xSize, ySize, nullptr); } Rules CompactSgf::getRulesOrFailAllowUnspecified(const Rules& defaultRules) const { - Rules rules; - if(!hasRules()) - rules = defaultRules; - else - rules = rootNode.getRulesFromRUTagOrFail(); - - if(rootNode.hasProperty("KM")) - rules.komi = rootNode.getKomiOrFail(); - return rules; + return getRulesFromSgf(rootNode, xSize, ySize, &defaultRules); } Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function f) const { @@ -1705,7 +1752,7 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function placementMoves; + rootNode.accumPlacements(placementMoves, xSize, ySize); + vector startPosMoves; + bool randomized; + vector remainingMoves; + rules.startPos = Rules::recognizeStartPos(placementMoves, xSize, ySize, startPosMoves, randomized, &remainingMoves); + if (randomized && !rules.startPosIsRandom) { + f("Defined start pos is randomized but RU says it shouldn't"); + } + return rules; } - -void CompactSgf::setupInitialBoardAndHist(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist) const { +BoardHistory CompactSgf::setupInitialBoardAndHist(const Rules& initialRules, Player& nextPla) const { Color plPlayer = rootNode.getPLSpecifiedColor(); if(plPlayer == P_BLACK || plPlayer == P_WHITE) nextPla = plPlayer; @@ -1754,13 +1811,17 @@ void CompactSgf::setupInitialBoardAndHist(const Rules& initialRules, Board& boar if(moves.size() > 0) nextPla = moves[0].pla; - board = Board(xSize,ySize); - bool suc = board.setStonesFailIfNoLibs(placements); - if(!suc) - throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); - hist = BoardHistory(board,nextPla,initialRules,0); - if(hist.initialTurnNumber < board.numStonesOnBoard()) - hist.initialTurnNumber = board.numStonesOnBoard(); + auto board = Board(xSize,ySize,initialRules); + const int numOfStartPosStones = initialRules.getNumOfStartPosStones(); + for (int i = 0; i < placements.size(); i++) { + const Move placement = placements[i]; + if(const bool suc = board.setStoneFailIfNoLibs(placement.loc, placement.pla, i < numOfStartPosStones); !suc) + throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); + } + auto hist = BoardHistory(board,nextPla,initialRules,0); + if (const int numStonesOnBoard = board.numStonesOnBoard(); hist.initialTurnNumber < numStonesOnBoard) + hist.initialTurnNumber = numStonesOnBoard; + return hist; } void CompactSgf::playMovesAssumeLegal(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const { @@ -1790,19 +1851,28 @@ void CompactSgf::playMovesTolerant(Board& board, Player& nextPla, BoardHistory& for(int64_t i = 0; i CompactSgf::setupBoardAndHistAssumeLegal(const Rules& initialRules, Player& nextPla, int64_t turnIdx) + const { + BoardHistory hist = setupInitialBoardAndHist(initialRules, nextPla); + Board boardWithMoves(hist.initialBoard); + playMovesAssumeLegal(boardWithMoves, nextPla, hist, turnIdx); + return std::make_pair(hist, boardWithMoves); } -void CompactSgf::setupBoardAndHistTolerant(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const { - setupInitialBoardAndHist(initialRules, board, nextPla, hist); - playMovesTolerant(board, nextPla, hist, turnIdx, preventEncore); +std::pair CompactSgf::setupBoardAndHistTolerant( + const Rules& initialRules, + Player& nextPla, + int64_t turnIdx, + bool preventEncore) const { + BoardHistory hist = setupInitialBoardAndHist(initialRules, nextPla); + Board boardWithMoves(hist.initialBoard); + playMovesTolerant(boardWithMoves, nextPla, hist, turnIdx, preventEncore); + return std::make_pair(hist, boardWithMoves); } @@ -1924,7 +1994,10 @@ void WriteSgf::writeSgf( int xSize = initialBoard.x_size; int ySize = initialBoard.y_size; - out << "(;FF[4]GM[1]"; + out << "(;FF[4]"; + out << "AP[katago]"; + + out << "GM[" << (rules.isDots ? "40" : "1") << "]"; if(xSize == ySize) out << "SZ[" << xSize << "]"; else @@ -1932,25 +2005,29 @@ void WriteSgf::writeSgf( out << "PB[" << bName << "]"; out << "PW[" << wName << "]"; - if(gameData != NULL) { - out << "HA[" << gameData->handicapForSgf << "]"; + int handicap = 0; + if(gameData != nullptr) { + handicap = gameData->handicapForSgf; } else { BoardHistory histCopy(endHist); //Always use true for computing the handicap value that goes into an sgf histCopy.setAssumeMultipleStartingBlackMovesAreHandicap(true); - out << "HA[" << histCopy.computeNumHandicapStones() << "]"; + handicap = histCopy.computeNumHandicapStones(); + } + if (!rules.isDots || handicap != 0) { // Preserve backward compatibility and always fill `HA` for Go Game + out << "HA[" << handicap << "]"; } out << "KM[" << rules.komi << "]"; - out << "RU[" << (tryNicerRulesString ? rules.toStringNoKomiMaybeNice() : rules.toStringNoKomi()) << "]"; + out << "RU[" << (tryNicerRulesString ? rules.toStringNoSgfDefinedPropertiesMaybeNice() : rules.toString(false)) << "]"; printGameResult(out,endHist,overrideFinishedWhiteScore); bool hasAB = false; for(int y = 0; ywhiteValueTargetsByTurn.size()); } + if (endHist.isPassAliveFinished) { + commentOut << "," << "passAliveFinished=true"; + } + if(extraComments.size() > 0) { if(commentOut.str().length() > 0) commentOut << " "; diff --git a/cpp/dataio/sgf.h b/cpp/dataio/sgf.h index 467a34633..53b855653 100644 --- a/cpp/dataio/sgf.h +++ b/cpp/dataio/sgf.h @@ -39,9 +39,10 @@ struct SgfNode { void accumMoves(std::vector& moves, int xSize, int ySize) const; Color getPLSpecifiedColor() const; - Rules getRulesFromRUTagOrFail() const; + Rules getRulesFromRUTagOrFail(bool isDots) const; Player getSgfWinner() const; float getKomiOrFail() const; + bool getIsDotsGame() const; float getKomiOrDefault(float defaultKomi) const; std::string getPlayerName(Player pla) const; @@ -69,6 +70,7 @@ struct Sgf { static std::vector> loadSgfOrSgfsLogAndIgnoreErrors(const std::string& file, Logger& logger); + bool isDotsGame() const; XYSize getXYSize() const; float getKomiOrFail() const; float getKomiOrDefault(float defaultKomi) const; @@ -216,6 +218,7 @@ struct Sgf { struct CompactSgf { std::string fileName; + bool isDots; SgfNode rootNode; std::vector placements; std::vector moves; @@ -241,12 +244,12 @@ struct CompactSgf { Rules getRulesOrFailAllowUnspecified(const Rules& defaultRules) const; Rules getRulesOrWarn(const Rules& defaultRules, std::function f) const; - void setupInitialBoardAndHist(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist) const; + BoardHistory setupInitialBoardAndHist(const Rules& initialRules, Player& nextPla) const; void playMovesAssumeLegal(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const; - void setupBoardAndHistAssumeLegal(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const; + std::pair setupBoardAndHistAssumeLegal(const Rules& initialRules, Player& nextPla, int64_t turnIdx) const; //These throw a StringError upon illegal move. void playMovesTolerant(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const; - void setupBoardAndHistTolerant(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const; + std::pair setupBoardAndHistTolerant(const Rules& initialRules, Player& nextPla, int64_t turnIdx, bool preventEncore) const; }; namespace WriteSgf { diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index a8b03bd25..a9b0db48a 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -18,9 +18,9 @@ ValueTargets::~ValueTargets() //------------------------------------------------------------------------------------- -SidePosition::SidePosition() +SidePosition::SidePosition(const Rules& rules) :board(), - hist(), + hist(rules), pla(P_BLACK), unreducedNumVisits(), policyTarget(), @@ -36,7 +36,7 @@ SidePosition::SidePosition() playoutDoublingAdvantage(0.0) {} -SidePosition::SidePosition(const Board& b, const BoardHistory& h, Player p, int numNNChangesSoFar) +SidePosition::SidePosition(const Board& b, const BoardHistory& h, const Player p, const int numNNChangesSoFar) :board(b), hist(h), pla(p), @@ -59,15 +59,15 @@ SidePosition::~SidePosition() //------------------------------------------------------------------------------------- -FinishedGameData::FinishedGameData() +FinishedGameData::FinishedGameData(const Rules& rules) :bName(), wName(), bIdx(0), wIdx(0), - startBoard(), - startHist(), - endHist(), + startBoard(rules), + startHist(rules), + endHist(rules), startPla(P_BLACK), gameHash(), @@ -265,7 +265,9 @@ static const int GLOBAL_TARGET_NUM_CHANNELS = 64; static const int VALUE_SPATIAL_TARGET_NUM_CHANNELS = 5; static const int QVALUE_SPATIAL_TARGET_NUM_CHANNELS = 3; -TrainingWriteBuffers::TrainingWriteBuffers(int iVersion, int maxRws, int numBChannels, int numFChannels, int xLen, int yLen, bool includeMetadata) +TrainingWriteBuffers::TrainingWriteBuffers( + const int iVersion, int maxRws, int numBChannels, int numFChannels, int xLen, int yLen, + const bool includeMetadata) :inputsVersion(iVersion), maxRows(maxRws), numBinaryChannels(numBChannels), @@ -298,7 +300,7 @@ void TrainingWriteBuffers::clear() { } //Copy floats that are all 0-1 into bits, packing 8 to a byte, big-endian-style within each byte. -static void packBits(const float* binaryFloats, int len, uint8_t* bits) { +static void packBits(const float* binaryFloats, const int len, uint8_t* bits) { for(int i = 0; i < len; i += 8) { if(i + 8 <= len) { bits[i >> 3] = @@ -320,18 +322,22 @@ static void packBits(const float* binaryFloats, int len, uint8_t* bits) { } } -static void zeroPolicyTarget(int policySize, int16_t* target) { +static void zeroPolicyTarget(const int policySize, int16_t* target) { for(int pos = 0; pos& policyTargetMoves, int policySize, int dataXLen, int dataYLen, int boardXSize, int16_t* target) { +static void fillPolicyTarget(const vector& policyTargetMoves, + const int policySize, + const int dataXLen, + const int dataYLen, + const int boardXSize, int16_t* target) { zeroPolicyTarget(policySize,target); size_t size = policyTargetMoves.size(); for(size_t i = 0; i& policyTargetMoves, //Clamps a value to integer in [-120,120] to pack down to 8 bits. //Randomizes to make sure the expectation is exactly correct. -static int8_t clampToRadius120(float x, Rand& rand) { +static int8_t clampToRadius120(const float x, Rand& rand) { //We need to pack this down to 8 bits, so map into [-120,120]. //Randomize to ensure the expectation is exactly correct. int low = (int)floor(x); @@ -356,7 +362,7 @@ static int8_t clampToRadius120(float x, Rand& rand) { if(lambda == 0.0f) return (int8_t)low; else return (int8_t)(rand.nextBool(lambda) ? high : low); } -static int16_t clampToRadius32000(float x, Rand& rand) { +static int16_t clampToRadius32000(const float x, Rand& rand) { //We need to pack this down to 16 bits, so clamp into an integer [-32000,32000]. //Randomize to ensure the expectation is exactly correct. int low = (int)floor(x); @@ -369,7 +375,12 @@ static int16_t clampToRadius32000(float x, Rand& rand) { else return (int16_t)(rand.nextBool(lambda) ? high : low); } -static void fillQValueTarget(const vector& whiteQValueTargets, Player nextPlayer, int policySize, int dataXLen, int dataYLen, int boardXSize, int16_t* cPosTarget, Rand& rand) { +static void fillQValueTarget(const vector& whiteQValueTargets, + const Player nextPlayer, + const int policySize, + const int dataXLen, + const int dataYLen, + const int boardXSize, int16_t* cPosTarget, Rand& rand) { for(int i = 0; i < QVALUE_SPATIAL_TARGET_NUM_CHANNELS * policySize; i++) { cPosTarget[i] = 0; } @@ -395,7 +406,10 @@ static void fillQValueTarget(const vector& whiteQValueTargets, } } -static void fillValueTDTargets(const vector& whiteValueTargetsByTurn, int idx, Player nextPlayer, double nowFactor, float* buf) { +static void fillValueTDTargets(const vector& whiteValueTargetsByTurn, + const int idx, + const Player nextPlayer, + const double nowFactor, float* buf) { double winValue = 0.0; double lossValue = 0.0; double noResultValue = 0.0; @@ -490,36 +504,13 @@ void TrainingWriteBuffers::addRow( bool inputsUseNHWC = false; float* rowBin = binaryInputNCHWUnpacked; float* rowGlobal = globalInputNC.data + curRows * numGlobalChannels; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V3 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V3 == numGlobalChannels); - NNInputs::fillRowV3(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 4) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V4 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V4 == numGlobalChannels); - NNInputs::fillRowV4(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 5) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V5 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V5 == numGlobalChannels); - NNInputs::fillRowV5(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 6) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V6 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V6 == numGlobalChannels); - NNInputs::fillRowV6(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 7) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V7 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V7 == numGlobalChannels); - NNInputs::fillRowV7(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else - ASSERT_UNREACHABLE; - //Pack bools bitwise into uint8_t + assert(NNInputs::getNumberOfSpatialFeatures(inputsVersion, hist.rules.isDots) == numBinaryChannels); + assert(NNInputs::getNumberOfGlobalFeatures(inputsVersion, hist.rules.isDots) == numGlobalChannels); + + NNInputs::fillRowVN(inputsVersion, board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); + + // Pack bools bitwise into uint8_t uint8_t* rowBinPacked = binaryInputNCHWPacked.data + curRows * numBinaryChannels * packedBoardArea; for(int c = 0; c engage for Dots game + //Target weight for the whole row rowGlobal[25] = targetWeight; @@ -601,8 +594,6 @@ void TrainingWriteBuffers::addRow( rowGlobal[22] = (float)sum; } - //Unused - rowGlobal[23] = 0.0f; rowGlobal[24] = (float)(1.0f - tdValueTargetWeight); rowGlobal[30] = (float)policySurprise; rowGlobal[31] = (float)policyEntropy; @@ -694,7 +685,11 @@ void TrainingWriteBuffers::addRow( rowScoreDistr[scoreDistrMid] = 50; } else { - assert(finalFullArea != NULL); + if (!hist.rules.isDots) { + assert(finalFullArea != NULL); + } else { + assert(finalFullArea == NULL); + } assert(finalBoard != NULL); //Ownership weight scales by value weight @@ -716,9 +711,11 @@ void TrainingWriteBuffers::addRow( Loc loc = Location::getLoc(x,y,board.x_size); if(finalOwnership[loc] == nextPlayer) rowOwnership[pos] = 1; else if(finalOwnership[loc] == opp) rowOwnership[pos] = -1; - //Mark full area points that ended up not being owned - if(finalFullArea[loc] != C_EMPTY && finalOwnership[loc] == C_EMPTY) - rowOwnership[pos+posArea] = (finalFullArea[loc] == nextPlayer ? 1 : -1); + if (!hist.rules.isDots) { + //Mark full area points that ended up not being owned + if(finalFullArea[loc] != C_EMPTY && finalOwnership[loc] == C_EMPTY) + rowOwnership[pos+posArea] = (finalFullArea[loc] == nextPlayer ? 1 : -1); + } } } @@ -770,10 +767,12 @@ void TrainingWriteBuffers::addRow( for(int x = 0; x #include -#include -#include +#include +#include #include -#include "../core/rand.h" - using namespace std; //STATIC VARS----------------------------------------------------------------------------- bool Board::IS_ZOBRIST_INITALIZED = false; -Hash128 Board::ZOBRIST_SIZE_X_HASH[MAX_LEN+1]; -Hash128 Board::ZOBRIST_SIZE_Y_HASH[MAX_LEN+1]; +Hash128 Board::ZOBRIST_SIZE_X_HASH[MAX_LEN_X+1]; +Hash128 Board::ZOBRIST_SIZE_Y_HASH[MAX_LEN_Y+1]; Hash128 Board::ZOBRIST_BOARD_HASH[MAX_ARR_SIZE][4]; Hash128 Board::ZOBRIST_PLAYER_HASH[4]; Hash128 Board::ZOBRIST_KO_LOC_HASH[MAX_ARR_SIZE]; @@ -45,16 +43,27 @@ int Location::getY(Loc loc, int x_size) { return (loc / (x_size+1)) - 1; } -void Location::getAdjacentOffsets(short adj_offsets[8], int x_size) -{ - adj_offsets[0] = -(x_size+1); - adj_offsets[1] = -1; - adj_offsets[2] = 1; - adj_offsets[3] = (x_size+1); - adj_offsets[4] = -(x_size+1)-1; - adj_offsets[5] = -(x_size+1)+1; - adj_offsets[6] = (x_size+1)-1; - adj_offsets[7] = (x_size+1)+1; +void Location::getAdjacentOffsets(short adj_offsets[8], int x_size, bool isDots) { + int stride = x_size + 1; + if (isDots) { + adj_offsets[LEFT_TOP_INDEX] = -stride - 1; + adj_offsets[TOP_INDEX] = -stride; + adj_offsets[RIGHT_TOP_INDEX] = -stride + 1; + adj_offsets[RIGHT_INDEX] = +1; + adj_offsets[RIGHT_BOTTOM_INDEX] = +stride + 1; + adj_offsets[BOTTOM_INDEX] = +stride; + adj_offsets[LEFT_BOTTOM_INDEX] = +stride - 1; + adj_offsets[LEFT_INDEX] = -1; + } else { + adj_offsets[0] = -stride; + adj_offsets[1] = -1; + adj_offsets[2] = 1; + adj_offsets[3] = stride; + adj_offsets[4] = -stride-1; + adj_offsets[5] = -stride+1; + adj_offsets[6] = stride-1; + adj_offsets[7] = stride+1; + } } bool Location::isAdjacent(Loc loc0, Loc loc1, int x_size) @@ -63,7 +72,7 @@ bool Location::isAdjacent(Loc loc0, Loc loc1, int x_size) } Loc Location::getMirrorLoc(Loc loc, int x_size, int y_size) { - if(loc == Board::NULL_LOC || loc == Board::PASS_LOC) + if(loc == Board::NULL_LOC || loc == Board::PASS_LOC || loc == Board::RESIGN_LOC) return loc; return getLoc(x_size-1-getX(loc,x_size),y_size-1-getY(loc,x_size),x_size); } @@ -90,56 +99,74 @@ bool Location::isNearCentral(Loc loc, int x_size, int y_size) { return x >= (x_size-1)/2-1 && x <= x_size/2+1 && y >= (y_size-1)/2-1 && y <= y_size/2+1; } - -#define FOREACHADJ(BLOCK) {int ADJOFFSET = -(x_size+1); {BLOCK}; ADJOFFSET = -1; {BLOCK}; ADJOFFSET = 1; {BLOCK}; ADJOFFSET = x_size+1; {BLOCK}}; #define ADJ0 (-(x_size+1)) #define ADJ1 (-1) #define ADJ2 (1) #define ADJ3 (x_size+1) -//CONSTRUCTORS AND INITIALIZATION---------------------------------------------------------- +// CONSTRUCTORS AND INITIALIZATION---------------------------------------------------------- -Board::Board() -{ - init(DEFAULT_LEN,DEFAULT_LEN); +Board::Base::Base(Player newPla, + const std::vector& rollbackLocations, + const std::vector& rollbackStates, + const bool isReal +) { + pla = newPla; + rollback_locations = rollbackLocations; + rollback_states = rollbackStates; + is_real = isReal; } -Board::Board(int x, int y) -{ - init(x,y); +Board::Board() : Board(Rules::DEFAULT_GO) {} + +Board::Board(const Rules& newRules) { + init(newRules.isDots ? DEFAULT_LEN_X_DOTS : DEFAULT_LEN_X, newRules.isDots ? DEFAULT_LEN_Y_DOTS : DEFAULT_LEN_Y, newRules); } +Board::Board(const int x, const int y, const Rules& newRules) { + init(x, y, newRules); +} -Board::Board(const Board& other) -{ +Board::Board(const Board& other) { x_size = other.x_size; y_size = other.y_size; + rules = other.rules; memcpy(colors, other.colors, sizeof(Color)*MAX_ARR_SIZE); - memcpy(chain_data, other.chain_data, sizeof(ChainData)*MAX_ARR_SIZE); - memcpy(chain_head, other.chain_head, sizeof(Loc)*MAX_ARR_SIZE); - memcpy(next_in_chain, other.next_in_chain, sizeof(Loc)*MAX_ARR_SIZE); - ko_loc = other.ko_loc; + + chain_data = other.chain_data; + chain_head = other.chain_head; + next_in_chain = other.next_in_chain; + // empty_list = other.empty_list; pos_hash = other.pos_hash; numBlackCaptures = other.numBlackCaptures; numWhiteCaptures = other.numWhiteCaptures; - + blackScoreIfWhiteGrounds = other.blackScoreIfWhiteGrounds; + whiteScoreIfBlackGrounds = other.whiteScoreIfBlackGrounds; + numLegalMovesIfSuiAllowed = other.numLegalMovesIfSuiAllowed; + start_pos_moves = other.start_pos_moves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); + visited_data.resize(other.visited_data.size(), false); } -void Board::init(int xS, int yS) +void Board::init(const int xS, const int yS, const Rules& initRules) { assert(IS_ZOBRIST_INITALIZED); - if(xS < 0 || yS < 0 || xS > MAX_LEN || yS > MAX_LEN) + if(xS < 0 || yS < 0 || xS > MAX_LEN_X || yS > MAX_LEN_Y) throw StringError("Board::init - invalid board size"); x_size = xS; y_size = yS; + rules = initRules; - for(int i = 0; i < MAX_ARR_SIZE; i++) + for(int i = 0; i < MAX_ARR_SIZE; i++) { colors[i] = C_WALL; + if (rules.isDots) { + setGrounded(i); + } + } for(int y = 0; y < y_size; y++) { @@ -155,8 +182,19 @@ void Board::init(int xS, int yS) pos_hash = ZOBRIST_SIZE_X_HASH[x_size] ^ ZOBRIST_SIZE_Y_HASH[y_size]; numBlackCaptures = 0; numWhiteCaptures = 0; + blackScoreIfWhiteGrounds = 0; + whiteScoreIfBlackGrounds = 0; + numLegalMovesIfSuiAllowed = xS * yS; - Location::getAdjacentOffsets(adj_offsets,x_size); + if (!rules.isDots) { + chain_data.resize(MAX_ARR_SIZE); + chain_head.resize(MAX_ARR_SIZE); + next_in_chain.resize(MAX_ARR_SIZE); + } else { + visited_data.resize(getMaxArrSize(x_size, y_size), false); + } + + Location::getAdjacentOffsets(adj_offsets, x_size, isDots()); } void Board::initHash() @@ -208,9 +246,11 @@ void Board::initHash() //Reseed the random number generator so that these size hashes are also //not affected by the size of the board we compile with rand.init("Board::initHash() for ZOBRIST_SIZE hashes"); - for(int i = 0; i(colors[loc] & ACTIVE_MASK); +} + +Color Board::getPlacedColor(const Loc loc) const { + const State state = getState(loc); + return rules.isDots ? getPlacedDotColor(state) : state; +} double Board::sqrtBoardArea() const { if (x_size == y_size) { @@ -266,8 +314,14 @@ int Board::getNumLiberties(Loc loc) const //Check if moving here would be a self-capture bool Board::isSuicide(Loc loc, Player pla) const { - if(loc == PASS_LOC) + if (loc == PASS_LOC) return false; + if (loc == RESIGN_LOC) + return true; // It's seems reasonable to treat resigning as suicide + + if (rules.isDots) { + return isSuicideDots(loc, pla); + } Player opp = getOpp(pla); FOREACHADJ( @@ -293,6 +347,10 @@ bool Board::isSuicide(Loc loc, Player pla) const //Check if moving here is would be an illegal self-capture bool Board::isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const { + if (rules.isDots) { + return !isMultiStoneSuicideLegal && isSuicideDots(loc, pla); + } + Player opp = getOpp(pla); FOREACHADJ( Loc adj = loc + ADJOFFSET; @@ -317,6 +375,7 @@ bool Board::isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) //Returns a fast lower bound on the number of liberties a new stone placed here would have void Board::getBoundNumLibertiesAfterPlay(Loc loc, Player pla, int& lowerBound, int& upperBound) const { + assert(!isDots()); Player opp = getOpp(pla); int numImmediateLibs = 0; //empty spaces adjacent @@ -354,6 +413,7 @@ void Board::getBoundNumLibertiesAfterPlay(Loc loc, Player pla, int& lowerBound, //Returns the number of liberties a new stone placed here would have, or max if it would be >= max. int Board::getNumLibertiesAfterPlay(Loc loc, Player pla, int max) const { + assert(!isDots()); Player opp = getOpp(pla); int numLibs = 0; @@ -445,32 +505,22 @@ bool Board::isKoBanned(Loc loc) const } bool Board::isOnBoard(Loc loc) const { - return loc >= 0 && loc < MAX_ARR_SIZE && colors[loc] != C_WALL; + return loc >= 0 && loc < MAX_ARR_SIZE && getColor(loc) != C_WALL; } //Check if moving here is illegal. -bool Board::isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const +bool Board::isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal, const bool ignoreKo) const { if(pla != P_BLACK && pla != P_WHITE) return false; - return loc == PASS_LOC || ( - loc >= 0 && - loc < MAX_ARR_SIZE && - (colors[loc] == C_EMPTY) && - !isKoBanned(loc) && - !isIllegalSuicide(loc, pla, isMultiStoneSuicideLegal) - ); -} - -//Check if moving here is illegal, ignoring simple ko -bool Board::isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const -{ - if(pla != P_BLACK && pla != P_WHITE) - return false; - return loc == PASS_LOC || ( + if (loc == RESIGN_LOC) { + return true; + } + return loc == PASS_LOC || loc == RESIGN_LOC || ( loc >= 0 && loc < MAX_ARR_SIZE && - (colors[loc] == C_EMPTY) && + getColor(loc) == C_EMPTY && + (rules.isDots || ignoreKo || !isKoBanned(loc)) && !isIllegalSuicide(loc, pla, isMultiStoneSuicideLegal) ); } @@ -478,6 +528,8 @@ bool Board::isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal //Check if this location contains a simple eye for the specified player. bool Board::isSimpleEye(Loc loc, Player pla) const { + assert(!rules.isDots); + if(colors[loc] != C_EMPTY) return false; @@ -509,9 +561,14 @@ bool Board::isSimpleEye(Loc loc, Player pla) const return true; } -bool Board::wouldBeCapture(Loc loc, Player pla) const { - if(colors[loc] != C_EMPTY) +bool Board::wouldBeCapture(const Loc loc, const Player pla) const { + if(getColor(loc) != C_EMPTY) return false; + + if (rules.isDots) { + return wouldBeCaptureDots(loc, pla); + } + Player opp = getOpp(pla); FOREACHADJ( Loc adj = loc + ADJOFFSET; @@ -525,8 +582,11 @@ bool Board::wouldBeCapture(Loc loc, Player pla) const { return false; } - bool Board::wouldBeKoCapture(Loc loc, Player pla) const { + if (isDots()) { + return false; // Ko is not relevant for Dots + } + if(colors[loc] != C_EMPTY) return false; //Check that surounding points are are all opponent owned and exactly one of them is capturable @@ -581,16 +641,16 @@ Loc Board::getKoCaptureLoc(Loc loc, Player pla) const { bool Board::isAdjacentToPla(Loc loc, Player pla) const { FOREACHADJ( Loc adj = loc + ADJOFFSET; - if(colors[adj] == pla) + if(getColor(adj) == pla) return true; ); return false; } bool Board::isAdjacentOrDiagonalToPla(Loc loc, Player pla) const { - for(int i = 0; i<8; i++) { + for(int i = 0; i < 8; i++) { Loc adj = loc + adj_offsets[i]; - if(colors[adj] == pla) + if(getColor(adj) == pla) return true; } return false; @@ -636,39 +696,66 @@ bool Board::isNonPassAliveSelfConnection(Loc loc, Player pla, Color* passAliveAr return false; } -bool Board::isEmpty() const { - for(int y = 0; y < y_size; y++) { - for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] != C_EMPTY) - return false; - } - } - return true; +bool Board::isStartPos() const { + int startBoardNumBlackStones, startBoardNumWhiteStones; + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, false); + return startBoardNumBlackStones == 0 && startBoardNumWhiteStones == 0; } int Board::numStonesOnBoard() const { - int num = 0; - for(int y = 0; y < y_size; y++) { - for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] == C_BLACK || colors[loc] == C_WHITE) - num += 1; - } - } - return num; + int startBoardNumBlackStones, startBoardNumWhiteStones; + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, true); + return startBoardNumBlackStones + startBoardNumWhiteStones; } int Board::numPlaStonesOnBoard(Player pla) const { - int num = 0; + int startBoardNumBlackStones, startBoardNumWhiteStones; + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, true); + return pla == C_BLACK ? startBoardNumBlackStones : startBoardNumWhiteStones; +} + +vector Board::getCurrentMoves( + int& startBoardNumBlackStones, + int& startBoardNumWhiteStones, + const bool includeStartLocs) const { + + vector stones; + startBoardNumBlackStones = 0; + startBoardNumWhiteStones = 0; + + // Ignore start pos moves that are generated according to rules + set startLocs; + for(auto move: start_pos_moves) { + startLocs.insert(move.loc); + if (includeStartLocs) { + // Fill start pos stones as in priority + stones.emplace_back(move); + } + } + for(int y = 0; y < y_size; y++) { for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] == pla) - num += 1; + if(const Loc loc = Location::getLoc(x, y, x_size)) { + bool isStartPosLoc = startLocs.count(loc) > 0; + if (includeStartLocs || !isStartPosLoc) { + if(const Color color = getPlacedColor(loc); color == C_BLACK) { + startBoardNumBlackStones += 1; + if (!isStartPosLoc) { + stones.emplace_back(loc, C_BLACK); + } + } + else if(color == C_WHITE) { + startBoardNumWhiteStones += 1; + if (!isStartPosLoc) { + stones.emplace_back(loc, C_WHITE); + } + } + } + } } } - return num; + + return stones; } bool Board::setStone(Loc loc, Color color) @@ -696,25 +783,32 @@ bool Board::setStone(Loc loc, Color color) return true; } -bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { - if(loc < 0 || loc >= MAX_ARR_SIZE || colors[loc] == C_WALL) +bool Board::setStoneFailIfNoLibs(Loc loc, Color color, const bool startPos) { + const Color colorAtLoc = getColor(loc); + if(loc < 0 || loc >= MAX_ARR_SIZE || colorAtLoc == C_WALL) return false; if(color != C_BLACK && color != C_WHITE && color != C_EMPTY) return false; Loc oldKoLoc = ko_loc; - if(colors[loc] == color) + if(colorAtLoc == color) {} - else if(colors[loc] == C_EMPTY) { + else if(colorAtLoc == C_EMPTY) { if(isSuicide(loc,color) || wouldBeCapture(loc,color)) return false; playMoveAssumeLegal(loc,color); + if (startPos) { + start_pos_moves.emplace_back(loc, color); + } } else if(color == C_EMPTY) removeSingleStone(loc); else { - assert(colors[loc] == getOpp(color)); + assert(colorAtLoc == getOpp(color)); removeSingleStone(loc); + if (startPos) { + start_pos_moves.emplace_back(loc, color); + } if(isSuicide(loc,color) || wouldBeCapture(loc,color)) { playMoveAssumeLegal(loc,getOpp(color)); ko_loc = oldKoLoc; @@ -727,25 +821,31 @@ bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { return true; } -bool Board::setStonesFailIfNoLibs(std::vector placements) { +void Board::setStartPos(Rand& rand) { + const vector startPos = Rules::generateStartPos(rules.startPos, rules.startPosIsRandom ? &rand : nullptr, x_size, y_size); + const bool success = setStonesFailIfNoLibs(startPos, true); + assert(success); +} + +bool Board::setStonesFailIfNoLibs(const std::vector& placements, const bool startPos) { std::set locs; for(const Move& placement: placements) { if(locs.find(placement.loc) != locs.end()) return false; locs.insert(placement.loc); } + //First empty out all locations that we plan to set. //This guarantees avoiding any intermediate liberty issues. for(const Move& placement: placements) { - bool suc = setStoneFailIfNoLibs(placement.loc, C_EMPTY); - if(!suc) + if(bool suc = setStoneFailIfNoLibs(placement.loc, C_EMPTY, startPos); !suc) return false; } //Now set all the stones we wanted. for(const Move& placement: placements) { - bool suc = setStoneFailIfNoLibs(placement.loc, placement.pla); - if(!suc) + if(bool suc = setStoneFailIfNoLibs(placement.loc, placement.pla, startPos); !suc) { return false; + } } return true; } @@ -753,7 +853,7 @@ bool Board::setStonesFailIfNoLibs(std::vector placements) { //Attempts to play the specified move. Returns true if successful, returns false if the move was illegal. bool Board::playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal) { - if(isLegal(loc,pla,isMultiStoneSuicideLegal)) + if(isLegal(loc, pla, isMultiStoneSuicideLegal, false)) { playMoveAssumeLegal(loc,pla); return true; @@ -762,47 +862,54 @@ bool Board::playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal) } //Plays the specified move, assuming it is legal, and returns a MoveRecord for the move -Board::MoveRecord Board::playMoveRecorded(Loc loc, Player pla) -{ - MoveRecord record; - record.loc = loc; - record.pla = pla; - record.ko_loc = ko_loc; - record.capDirs = 0; +Board::MoveRecord Board::playMoveRecorded(const Loc loc, const Player pla) { + if (rules.isDots) { + return playMoveRecordedDots(loc, pla); + } + + uint8_t capDirs = 0; - if(loc != PASS_LOC) { + if(loc != PASS_LOC && loc != RESIGN_LOC) { Player opp = getOpp(pla); { int adj = loc + ADJ0; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 0); } + capDirs |= (((uint8_t)1) << 0); } { int adj = loc + ADJ1; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 1); } + capDirs |= (((uint8_t)1) << 1); } { int adj = loc + ADJ2; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 2); } + capDirs |= (((uint8_t)1) << 2); } { int adj = loc + ADJ3; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 3); } + capDirs |= (((uint8_t)1) << 3); } - if(record.capDirs == 0 && isSuicide(loc,pla)) - record.capDirs = 0x10; + if(capDirs == 0 && isSuicide(loc,pla)) + capDirs = 0x10; } + const Loc rollback_ko_loc = ko_loc; + playMoveAssumeLegal(loc, pla); - return record; + + return MoveRecord(loc, pla, rollback_ko_loc, capDirs); } //Undo the move given by record. Moves MUST be undone in the order they were made. //Undos will NOT typically restore the precise representation in the board to the way it was. The heads of chains //might change, the order of the circular lists might change, etc. -void Board::undo(Board::MoveRecord record) +void Board::undo(MoveRecord& record) { + if (rules.isDots) { + undoDots(record); + return; + } + ko_loc = record.ko_loc; Loc loc = record.loc; - if(loc == PASS_LOC) + if(loc == PASS_LOC || loc == RESIGN_LOC) return; //Re-fill stones in all captured directions @@ -918,7 +1025,7 @@ void Board::undo(Board::MoveRecord record) } Hash128 Board::getPosHashAfterMove(Loc loc, Player pla) const { - if(loc == PASS_LOC) + if(loc == PASS_LOC || loc == RESIGN_LOC) return pos_hash; assert(loc != NULL_LOC); @@ -998,10 +1105,14 @@ Hash128 Board::getPosHashAfterMove(Loc loc, Player pla) const { } //Plays the specified move, assuming it is legal. -void Board::playMoveAssumeLegal(Loc loc, Player pla) -{ - //Pass? - if(loc == PASS_LOC) +void Board::playMoveAssumeLegal(Loc loc, Player pla) { + if (rules.isDots) { + playMoveAssumeLegalDots(loc, pla); + return; + } + + // Pass or resign? + if(loc == PASS_LOC || loc == RESIGN_LOC) { ko_loc = NULL_LOC; return; @@ -1218,6 +1329,9 @@ int Board::removeChain(Loc loc) //Remove a single stone, even a stone part of a larger group. void Board::removeSingleStone(Loc loc) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } Player pla = colors[loc]; //Save the entire chain's stone locations @@ -1432,7 +1546,7 @@ int Location::euclideanDistanceSquared(Loc loc0, Loc loc1, int x_size) { return dx*dx + dy*dy; } -//TACTICAL STUFF-------------------------------------------------------------------- +// TACTICAL STUFF-------------------------------------------------------------------- //Helper, find liberties of group at loc. Fills in buf, returns the number of liberties. //bufStart is where to start checking to avoid duplicates. bufIdx is where to start actually writing. @@ -1529,6 +1643,10 @@ bool Board::hasLibertyGainingCaptures(Loc loc) const { } bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, vector& workingMoves) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } + if(loc < 0 || loc >= MAX_ARR_SIZE) return false; if(colors[loc] != C_BLACK && colors[loc] != C_WHITE) @@ -1553,12 +1671,12 @@ bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, //Attacker: A suicide move cannot reduce the defender's liberties //Defender: A suicide move cannot gain liberties bool isMultiStoneSuicideLegal = false; - if(isLegal(move0,opp,isMultiStoneSuicideLegal)) { + if(isLegal(move0, opp, isMultiStoneSuicideLegal, false)) { MoveRecord record = playMoveRecorded(move0,opp); move0Works = searchIsLadderCaptured(loc,true,buf); undo(record); } - if(isLegal(move1,opp,isMultiStoneSuicideLegal)) { + if(isLegal(move1, opp, isMultiStoneSuicideLegal, false)) { MoveRecord record = playMoveRecorded(move1,opp); move1Works = searchIsLadderCaptured(loc,true,buf); undo(record); @@ -1576,6 +1694,10 @@ bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, } bool Board::searchIsLadderCaptured(Loc loc, bool defenderFirst, vector& buf) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } + if(loc < 0 || loc >= MAX_ARR_SIZE) return false; if(colors[loc] != C_BLACK && colors[loc] != C_WHITE) @@ -1780,7 +1902,7 @@ bool Board::searchIsLadderCaptured(Loc loc, bool defenderFirst, vector& buf //Illegal move - treat it the same as a failed move, but don't return up a level so that we //loop again and just try the next move. bool isMultiStoneSuicideLegal = false; - if(!isLegal(move,p,isMultiStoneSuicideLegal)) { + if(!isLegal(move, p, isMultiStoneSuicideLegal, false)) { returnValue = isDefender; returnedFromDeeper = false; // if(print) cout << "illegal " << endl; @@ -1807,6 +1929,11 @@ void Board::calculateArea( bool unsafeBigTerritories, bool isMultiStoneSuicideLegal ) const { + if (rules.isDots) { + calculateOwnershipAndWhiteScore(result, C_EMPTY); + return; + } + std::fill(result,result+MAX_ARR_SIZE,C_EMPTY); calculateAreaForPla(P_BLACK,safeBigTerritories,unsafeBigTerritories,isMultiStoneSuicideLegal,result); calculateAreaForPla(P_WHITE,safeBigTerritories,unsafeBigTerritories,isMultiStoneSuicideLegal,result); @@ -1830,6 +1957,8 @@ void Board::calculateIndependentLifeArea( bool keepStones, bool isMultiStoneSuicideLegal ) const { + assert(!isDots()); + //First, just compute basic area. Color basicArea[MAX_ARR_SIZE]; std::fill(result,result+MAX_ARR_SIZE,C_EMPTY); @@ -1886,6 +2015,7 @@ void Board::calculateAreaForPla( bool isMultiStoneSuicideLegal, Color* result ) const { + assert(!isDots()); Color opp = getOpp(pla); //https://senseis.xmp.net/?BensonsAlgorithm @@ -1913,7 +2043,7 @@ void Board::calculateAreaForPla( //A region is vital for a pla group if all its spaces are adjacent to that pla group. //All lists are concatenated together, the most we can have is bounded by (MAX_LEN * MAX_LEN+1) / 2 //independent regions, each one vital for at most 4 pla groups, add some extra just in case. - static constexpr int maxRegions = (MAX_LEN * MAX_LEN + 1)/2 + 1; + static constexpr int maxRegions = (MAX_LEN_X * MAX_LEN_Y + 1)/2 + 1; static constexpr int vitalForPlaHeadsListsMaxLen = maxRegions * 4; Loc vitalForPlaHeadsLists[vitalForPlaHeadsListsMaxLen]; int vitalForPlaHeadsListsTotal = 0; @@ -2319,23 +2449,26 @@ void Board::checkConsistency() const { Hash128 tmp_pos_hash = ZOBRIST_SIZE_X_HASH[x_size] ^ ZOBRIST_SIZE_Y_HASH[y_size]; int emptyCount = 0; for(Loc loc = 0; loc < MAX_ARR_SIZE; loc++) { + const Color color = getColor(loc); int x = Location::getX(loc,x_size); int y = Location::getY(loc,x_size); if(x < 0 || x >= x_size || y < 0 || y >= y_size) { - if(colors[loc] != C_WALL) + if(color != C_WALL) throw StringError(errLabel + "Non-WALL value outside of board legal area"); } else { - if(colors[loc] == C_BLACK || colors[loc] == C_WHITE) { - if(!chainLocChecked[loc]) - checkChainConsistency(loc); - // if(empty_list.contains(loc)) - // throw StringError(errLabel + "Empty list contains filled location"); + if(color == C_BLACK || color == C_WHITE) { + if (!rules.isDots) { + if(!chainLocChecked[loc]) + checkChainConsistency(loc); + // if(empty_list.contains(loc)) + // throw StringError(errLabel + "Empty list contains filled location"); + } - tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][colors[loc]]; + tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][color]; tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][C_EMPTY]; } - else if(colors[loc] == C_EMPTY) { + else if(color== C_EMPTY) { // if(!empty_list.contains(loc)) // throw StringError(errLabel + "Empty list doesn't contain empty location"); emptyCount += 1; @@ -2360,23 +2493,33 @@ void Board::checkConsistency() const { // throw StringError(errLabel + "Empty list index for loc in index i is not i"); // } - if(ko_loc != NULL_LOC) { - int x = Location::getX(ko_loc,x_size); - int y = Location::getY(ko_loc,x_size); - if(x < 0 || x >= x_size || y < 0 || y >= y_size) - throw StringError(errLabel + "Invalid simple ko loc"); - if(getNumImmediateLiberties(ko_loc) != 0) - throw StringError(errLabel + "Simple ko loc has immediate liberties"); + if (!rules.isDots) { + if(ko_loc != NULL_LOC) { + int x = Location::getX(ko_loc,x_size); + int y = Location::getY(ko_loc,x_size); + if(x < 0 || x >= x_size || y < 0 || y >= y_size) + throw StringError(errLabel + "Invalid simple ko loc"); + if(getNumImmediateLiberties(ko_loc) != 0) + throw StringError(errLabel + "Simple ko loc has immediate liberties"); + } } short tmpAdjOffsets[8]; - Location::getAdjacentOffsets(tmpAdjOffsets,x_size); - for(int i = 0; i<8; i++) + Location::getAdjacentOffsets(tmpAdjOffsets, x_size, isDots()); + for(int i = 0; i < 8; i++) if(tmpAdjOffsets[i] != adj_offsets[i]) throw StringError(errLabel + "Corrupted adj_offsets array"); + + for (const bool visited : visited_data) { + if (visited) { + throw StringError("Visited data always should be invalidated"); + } + } } -bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const { +bool Board::isEqualForTesting(const Board& other, const bool checkNumCaptures, + const bool checkSimpleKo, + const bool checkRules) const { checkConsistency(); other.checkConsistency(); if(x_size != other.x_size) @@ -2385,16 +2528,34 @@ bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool ch return false; if(checkSimpleKo && ko_loc != other.ko_loc) return false; - if(checkNumCaptures && numBlackCaptures != other.numBlackCaptures) - return false; - if(checkNumCaptures && numWhiteCaptures != other.numWhiteCaptures) + if(checkNumCaptures && ( + numBlackCaptures != other.numBlackCaptures || + numWhiteCaptures != other.numWhiteCaptures || + blackScoreIfWhiteGrounds != other.blackScoreIfWhiteGrounds || + whiteScoreIfBlackGrounds != other.whiteScoreIfBlackGrounds + )) { return false; + } if(pos_hash != other.pos_hash) return false; for(int i = 0; i 25*25) - return toStringMach(loc,x_size); - if(loc == Board::PASS_LOC) - return string("pass"); - if(loc == Board::NULL_LOC) - return string("null"); - const char* xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; - int x = getX(loc,x_size); - int y = getY(loc,x_size); +string Location::toString(const Loc loc, const int x_size, const int y_size, const bool isDots) { + if (loc == Board::PASS_LOC) + return isDots ? "ground" : "pass"; + if (loc == Board::RESIGN_LOC) + return RESIGN_STR; + if (loc == Board::NULL_LOC) + return "null"; + const int x = getX(loc, x_size); + const int y = getY(loc, x_size); if(x >= x_size || x < 0 || y < 0 || y >= y_size) - return toStringMach(loc,x_size); + return toStringMach(loc, x_size, isDots); - char buf[128]; - if(x <= 24) - sprintf(buf,"%c%d",xChar[x],y_size-y); - else - sprintf(buf,"%c%c%d",xChar[x/25-1],xChar[x%25],y_size-y); - return string(buf); + return (isDots ? (std::to_string(x + 1) + "-") : Global::intToCoord(x)) + std::to_string(y_size - y); } -string Location::toString(Loc loc, const Board& b) { - return toString(loc,b.x_size,b.y_size); +string Location::toString(const Loc loc, const Board& b) { + return toString(loc, b.x_size, b.y_size, b.rules.isDots); } -string Location::toStringMach(Loc loc, const Board& b) { - return toStringMach(loc,b.x_size); -} - -static bool tryParseLetterCoordinate(char c, int& x) { - if(c >= 'A' && c <= 'H') - x = c-'A'; - else if(c >= 'a' && c <= 'h') - x = c-'a'; - else if(c >= 'J' && c <= 'Z') - x = c-'A'-1; - else if(c >= 'j' && c <= 'z') - x = c-'a'-1; - else - return false; - return true; +string Location::toStringMach(const Loc loc, const Board& b) { + return toStringMach(loc, b.x_size, b.isDots()); } -bool Location::tryOfString(const string& str, int x_size, int y_size, Loc& result) { +bool Location::tryOfString(const string& str, const int x_size, const int y_size, Loc& result) { string s = Global::trim(str); if(s.length() < 2) return false; - if(Global::isEqualCaseInsensitive(s,string("pass")) || Global::isEqualCaseInsensitive(s,string("pss"))) { + if( + Global::isEqualCaseInsensitive(s, string("pass")) || Global::isEqualCaseInsensitive(s, string("pss")) || + Global::isEqualCaseInsensitive(s, string("ground"))) { result = Board::PASS_LOC; return true; } + + if (Global::isEqualCaseInsensitive(s, RESIGN_STR)) { + result = Board::RESIGN_LOC; + return true; + } + if(s[0] == '(') { - if(s[s.length()-1] != ')') + if(s[s.length() - 1] != ')') return false; - s = s.substr(1,s.length()-2); - vector pieces = Global::split(s,','); - if(pieces.size() != 2) + + s = s.substr(1, s.length() - 2); + const int commaIndex = s.find(','); + if(commaIndex == std::string::npos) return false; + int x; + if(!Global::tryStringToInt(s.substr(0, commaIndex), x)) + return false; + int y; - bool sucX = Global::tryStringToInt(pieces[0],x); - bool sucY = Global::tryStringToInt(pieces[1],y); - if(!sucX || !sucY) + if(!Global::tryStringToInt(s.substr(commaIndex + 1), y)) return false; - result = Location::getLoc(x,y,x_size); + + result = getLoc(x, y, x_size); return true; } - else { + + if(const auto dashPos = s.find('-'); dashPos != std::string::npos) { int x; - if(!tryParseLetterCoordinate(s[0],x)) + if(!Global::tryStringToInt(s.substr(0, dashPos), x)) return false; - - //Extended format - if((s[1] >= 'A' && s[1] <= 'Z') || (s[1] >= 'a' && s[1] <= 'z')) { - int x1; - if(!tryParseLetterCoordinate(s[1],x1)) - return false; - x = (x+1) * 25 + x1; - s = s.substr(2,s.length()-2); - } - else { - s = s.substr(1,s.length()-1); - } - int y; - bool sucY = Global::tryStringToInt(s,y); - if(!sucY) - return false; - y = y_size - y; - if(x < 0 || y < 0 || x >= x_size || y >= y_size) + if(!Global::tryStringToInt(s.substr(dashPos + 1), y)) return false; - result = Location::getLoc(x,y,x_size); + + result = getLoc(x - 1, y_size - y, x_size); return true; } + + const auto firstDigitIndex = s.find_first_of("0123456789"); + if(firstDigitIndex == std::string::npos) + return false; + + int x; + if(!Global::tryCoordToInt(s.substr(0, firstDigitIndex), x)) + return false; + + int y; + if(!Global::tryStringToInt(s.substr(firstDigitIndex), y)) + return false; + y = y_size - y; + + if(x < 0 || y < 0 || x >= x_size || y >= y_size) + return false; + + result = getLoc(x, y, x_size); + return true; } bool Location::tryOfStringAllowNull(const string& str, int x_size, int y_size, Loc& result) { @@ -2614,55 +2782,66 @@ vector Location::parseSequence(const string& str, const Board& board) { return locs; } -void Board::printBoard(ostream& out, const Board& board, Loc markLoc, const vector* hist) { - if(hist != NULL) +void Board::printBoard( + ostream& out, + const Board& board, + const Loc markLoc, + const vector* hist, + const bool printHash) { + if(hist != nullptr) out << "MoveNum: " << hist->size() << " "; - out << "HASH: " << board.pos_hash << "\n"; - bool showCoords = board.x_size <= 50 && board.y_size <= 50; + if(printHash) { + out << "HASH: " << board.pos_hash; + } + if(hist != nullptr || printHash) { + out << endl; + } + const bool showCoords = board.isDots() || (board.x_size <= 50 && board.y_size <= 50); if(showCoords) { - const char* xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; - out << " "; - for(int x = 0; x < board.x_size; x++) { - if(x <= 24) { - out << " "; - out << xChar[x]; + if(board.isDots()) { + out << " "; + for(int x = 0; x < board.x_size; x++) { + out << std::left << std::setw(2) << std::to_string(x + 1) << ' '; } - else { - out << "A" << xChar[x-25]; + } else { + out << " "; + for(int x = 0; x < board.x_size; x++) { + out << std::right << std::setw(2) << Global::intToCoord(x); } } out << "\n"; } - for(int y = 0; y < board.y_size; y++) - { + for(int y = 0; y < board.y_size; y++) { if(showCoords) { - char buf[16]; - sprintf(buf,"%2d",board.y_size-y); - out << buf << ' '; + out << std::right << std::setw(2) << board.y_size - y << ' '; } - for(int x = 0; x < board.x_size; x++) - { - Loc loc = Location::getLoc(x,y,board.x_size); - char s = PlayerIO::colorToChar(board.colors[loc]); - if(board.colors[loc] == C_EMPTY && markLoc == loc) + for(int x = 0; x < board.x_size; x++) { + const Loc loc = Location::getLoc(x, y, board.x_size); + const State state = board.getState(loc); + const char s = PlayerIO::stateToChar(state, board.rules.isDots); + if(getActiveColor(state) == C_EMPTY && markLoc == loc) out << '@'; else out << s; bool histMarked = false; - if(hist != NULL) { - size_t start = hist->size() >= 3 ? hist->size()-3 : 0; - for(size_t i = 0; start+i < hist->size(); i++) { - if((*hist)[start+i].loc == loc) { - out << (1+i); + if(hist != nullptr) { + size_t start = hist->size() >= 3 ? hist->size() - 3 : 0; + for(size_t i = 0; start + i < hist->size(); i++) { + if((*hist)[start + i].loc == loc) { + out << (1 + i); histMarked = true; break; } } } - if(x < board.x_size-1 && !histMarked) + if(!histMarked && board.isDots()) { + out << ' '; + } + + if(x < board.x_size - 1 && (!histMarked || board.isDots())) out << ' '; } out << "\n"; @@ -2681,19 +2860,16 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { Loc loc = Location::getLoc(x,y,board.x_size); - s += PlayerIO::colorToChar(board.colors[loc]); + s += PlayerIO::colorToChar(board.getColor(loc)); } s += lineDelimiter; } return s; } -Board Board::parseBoard(int xSize, int ySize, const string& s) { - return parseBoard(xSize,ySize,s,'\n'); -} - -Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimiter) { - Board board(xSize,ySize); +Board Board::parseBoard(const int xSize, + const int ySize, const string& s, const Rules& rules, const char lineDelimiter) { + Board board(xSize,ySize,rules); vector lines = Global::split(Global::trim(s),lineDelimiter); //Throw away coordinate labels line if it exists @@ -2742,24 +2918,42 @@ Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimite return board; } +std::string Board::toString() const { + std::ostringstream oss; + printBoard(oss, *this, NULL_LOC, nullptr); + return oss.str(); +} + nlohmann::json Board::toJson(const Board& board) { nlohmann::json data; + if (board.rules.isDots) { + data[DOTS_KEY] = true; + } data["xSize"] = board.x_size; data["ySize"] = board.y_size; - data["stones"] = Board::toStringSimple(board,'|'); - data["koLoc"] = Location::toString(board.ko_loc,board); + data["stones"] = toStringSimple(board,'|'); + if (!board.isDots()) { + data["koLoc"] = Location::toString(board.ko_loc,board); + } data["numBlackCaptures"] = board.numBlackCaptures; data["numWhiteCaptures"] = board.numWhiteCaptures; + if (board.isDots()) { + data[BLACK_SCORE_IF_WHITE_GROUNDS_KEY] = board.blackScoreIfWhiteGrounds; + data[WHITE_SCORE_IF_BLACK_GROUNDS_KEY] = board.whiteScoreIfBlackGrounds; + } return data; } Board Board::ofJson(const nlohmann::json& data) { + bool dots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - Board board = Board::parseBoard(xSize,ySize,data["stones"].get(),'|'); - board.setSimpleKoLoc(Location::ofStringAllowNull(data["koLoc"].get(),board)); + Board board = parseBoard(xSize, ySize, data["stones"].get(), Rules(dots), '|'); + board.setSimpleKoLoc(Location::ofStringAllowNull(data.value("koLoc", "null"),board)); board.numBlackCaptures = data["numBlackCaptures"].get(); board.numWhiteCaptures = data["numWhiteCaptures"].get(); + board.blackScoreIfWhiteGrounds = data.value(BLACK_SCORE_IF_WHITE_GROUNDS_KEY, 0); + board.whiteScoreIfBlackGrounds = data.value(WHITE_SCORE_IF_BLACK_GROUNDS_KEY, 0); return board; } @@ -2805,9 +2999,13 @@ bool Board::countEmptyHelper(bool* emptyCounted, Loc initialLoc, int& count, int } bool Board::simpleRepetitionBoundGt(Loc loc, int bound) const { - if(loc == NULL_LOC || loc == PASS_LOC) + if(loc == NULL_LOC || loc == PASS_LOC || loc == RESIGN_LOC) return false; + if (rules.isDots) { + return false; // TODO: implement for Dots? + } + int count = 0; if(colors[loc] != C_EMPTY) { @@ -2840,3 +3038,14 @@ bool Board::simpleRepetitionBoundGt(Loc loc, int bound) const { return false; } + +Board::MoveRecord::MoveRecord(const Loc initLoc, const Player initPla, const Loc init_ko_loc, const uint8_t initCapDirs) { + loc = initLoc; + pla = initPla; + ko_loc = init_ko_loc; + capDirs = initCapDirs; + + previousState = C_EMPTY; + bases = {}; + emptyBaseInvalidateLocations = {}; +} \ No newline at end of file diff --git a/cpp/game/board.h b/cpp/game/board.h index 4fdb2a259..fccdfc162 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -10,50 +10,63 @@ #include "../core/global.h" #include "../core/hash.h" #include "../external/nlohmann_json/json.hpp" +#include "rules.h" +#include "../core/rand.h" #ifndef COMPILE_MAX_BOARD_LEN -#define COMPILE_MAX_BOARD_LEN 19 +#define COMPILE_MAX_BOARD_LEN 39 #endif +#define FOREACHADJ(BLOCK) {int ADJOFFSET = -(x_size+1); {BLOCK}; ADJOFFSET = -1; {BLOCK}; ADJOFFSET = 1; {BLOCK}; ADJOFFSET = x_size+1; {BLOCK}}; + //TYPES AND CONSTANTS----------------------------------------------------------------- +static constexpr int LEFT_TOP_INDEX = 0; +static constexpr int TOP_INDEX = 1; +static constexpr int RIGHT_TOP_INDEX = 2; +static constexpr int RIGHT_INDEX = 3; +static constexpr int RIGHT_BOTTOM_INDEX = 4; +static constexpr int BOTTOM_INDEX = 5; +static constexpr int LEFT_BOTTOM_INDEX = 6; +static constexpr int LEFT_INDEX = 7; + struct Board; -//Player -typedef int8_t Player; -static constexpr Player P_BLACK = 1; -static constexpr Player P_WHITE = 2; +typedef int8_t State; -//Color of a point on the board -typedef int8_t Color; -static constexpr Color C_EMPTY = 0; -static constexpr Color C_BLACK = 1; -static constexpr Color C_WHITE = 2; -static constexpr Color C_WALL = 3; -static constexpr int NUM_BOARD_COLORS = 4; +static constexpr int PLAYER_BITS_COUNT = 2; +static constexpr State ACTIVE_MASK = (1 << PLAYER_BITS_COUNT) - 1; static inline Color getOpp(Color c) {return c ^ 3;} +Color getActiveColor(State state); + +Color getPlacedDotColor(State s); + +Color getEmptyTerritoryColor(State s); + +bool isGrounded(State state); + +bool isTerritory(State s); + //Conversions for players and colors namespace PlayerIO { char colorToChar(Color c); - std::string playerToStringShort(Player p); - std::string playerToString(Player p); + char stateToChar(State s, bool isDots); + std::string playerToStringShort(Color p, bool isDots = false); + std::string playerToString(Player p, bool isDots); bool tryParsePlayer(const std::string& s, Player& pla); Player parsePlayer(const std::string& s); } -//Location of a point on the board -//(x,y) is represented as (x+1) + (y+1)*(x_size+1) -typedef short Loc; namespace Location { Loc getLoc(int x, int y, int x_size); int getX(Loc loc, int x_size); int getY(Loc loc, int x_size); - void getAdjacentOffsets(short adj_offsets[8], int x_size); + void getAdjacentOffsets(short adj_offsets[8], int x_size, bool isDots); bool isAdjacent(Loc loc0, Loc loc1, int x_size); Loc getMirrorLoc(Loc loc, int x_size, int y_size); Loc getCenterLoc(int x_size, int y_size); @@ -62,10 +75,12 @@ namespace Location bool isNearCentral(Loc loc, int x_size, int y_size); int distance(Loc loc0, Loc loc1, int x_size); int euclideanDistanceSquared(Loc loc0, Loc loc1, int x_size); + int getGetBigJumpInitialIndex(Loc loc0, Loc loc1, int x_size); + Loc getNextLocCW(Loc loc0, Loc loc1, int x_size); - std::string toString(Loc loc, int x_size, int y_size); + std::string toString(Loc loc, int x_size, int y_size, bool isDots); std::string toString(Loc loc, const Board& b); - std::string toStringMach(Loc loc, int x_size); + std::string toStringMach(Loc loc, int x_size, bool isDots); std::string toStringMach(Loc loc, const Board& b); bool tryOfString(const std::string& str, int x_size, int y_size, Loc& result); @@ -80,15 +95,25 @@ namespace Location Loc ofStringAllowNull(const std::string& str, const Board& b); std::vector parseSequence(const std::string& str, const Board& b); -} -//Simple structure for storing moves. Not used below, but this is a convenient place to define it. -STRUCT_NAMED_PAIR(Loc,loc,Player,pla,Move); + Loc xm1y(Loc loc); + Loc xm1ym1(Loc loc, int x_size); + Loc xym1(Loc loc, int x_size); + Loc xp1ym1(Loc loc, int x_size); + Loc xp1y(Loc loc); + Loc xp1yp1(Loc loc, int x_size); + Loc xyp1(Loc loc, int x_size); + Loc xm1yp1(Loc loc, int x_size); +} //Fast lightweight board designed for playouts and simulations, where speed is essential. //Simple ko rule only. //Does not enforce player turn order. +constexpr static int getMaxArrSize(const int x_size, const int y_size) { + return (x_size+1)*(y_size+2)+1; +} + struct Board { //Initialization------------------------------ @@ -98,20 +123,37 @@ struct Board //Board parameters and Constants---------------------------------------- - static constexpr int MAX_LEN = COMPILE_MAX_BOARD_LEN; //Maximum edge length allowed for the board - static constexpr int DEFAULT_LEN = std::min(MAX_LEN,19); //Default edge length for board if unspecified - static constexpr int MAX_PLAY_SIZE = MAX_LEN * MAX_LEN; //Maximum number of playable spaces - static constexpr int MAX_ARR_SIZE = (MAX_LEN+1)*(MAX_LEN+2)+1; //Maximum size of arrays needed - - //Location used to indicate an invalid spot on the board. + static constexpr int MAX_LEN_X = //Maximum x edge length allowed for the board +#ifdef COMPILE_MAX_BOARD_LEN_X + COMPILE_MAX_BOARD_LEN_X; +#else + COMPILE_MAX_BOARD_LEN; +#endif + static constexpr int MAX_LEN_Y = //Maximum y edge length allowed for the board +#ifdef COMPILE_MAX_BOARD_LEN_Y + COMPILE_MAX_BOARD_LEN_Y; +#else + COMPILE_MAX_BOARD_LEN; +#endif + static constexpr int MAX_LEN = std::max(MAX_LEN_X, MAX_LEN_Y); //Maximum edge length allowed for the board + static constexpr int DEFAULT_LEN_X = std::min(MAX_LEN_X,19); //Default x edge length for board if unspecified + static constexpr int DEFAULT_LEN_Y = std::min(MAX_LEN_Y,19); //Default y edge length for board if unspecified + static constexpr int DEFAULT_LEN_X_DOTS = std::min(MAX_LEN_X, 39); + static constexpr int DEFAULT_LEN_Y_DOTS = std::min(MAX_LEN_Y, 32); + static constexpr int MAX_PLAY_SIZE = MAX_LEN_X * MAX_LEN_Y; //Maximum number of playable spaces + static constexpr int MAX_ARR_SIZE = getMaxArrSize(MAX_LEN_X, MAX_LEN_Y); //Maximum size of arrays needed + + // Location used to indicate an invalid spot on the board. static constexpr Loc NULL_LOC = 0; - //Location used to indicate a pass move is desired. + // Location used to indicate a pass or grounding (Dots game) move is desired. static constexpr Loc PASS_LOC = 1; + // Location used to indicate resigning move. + static constexpr Loc RESIGN_LOC = 2; //Zobrist Hashing------------------------------ static bool IS_ZOBRIST_INITALIZED; - static Hash128 ZOBRIST_SIZE_X_HASH[MAX_LEN+1]; - static Hash128 ZOBRIST_SIZE_Y_HASH[MAX_LEN+1]; + static Hash128 ZOBRIST_SIZE_X_HASH[MAX_LEN_X+1]; + static Hash128 ZOBRIST_SIZE_Y_HASH[MAX_LEN_Y+1]; static Hash128 ZOBRIST_BOARD_HASH[MAX_ARR_SIZE][4]; static Hash128 ZOBRIST_BOARD_HASH2[MAX_ARR_SIZE][4]; static Hash128 ZOBRIST_PLAYER_HASH[4]; @@ -147,23 +189,74 @@ struct Board /* int size_; */ /* }; */ + struct Base { + std::vector rollback_locations; + std::vector rollback_states; + Player pla{}; + bool is_real{}; + + Base() = default; + Base(Player newPla, const std::vector& rollbackLocations, const std::vector& rollbackStates, bool isReal); + }; + //Move data passed back when moves are made to allow for undos struct MoveRecord { Player pla; Loc loc; Loc ko_loc; uint8_t capDirs; //First 4 bits indicate directions of capture, fifth bit indicates suicide + + // Move data for Dots game + State previousState; + std::vector bases; + std::vector emptyBaseInvalidateLocations; + std::vector groundingLocations; + + MoveRecord() = default; + + // Constructor for Go game + MoveRecord( + Loc initLoc, + Player initPla, + Loc init_ko_loc, + uint8_t initCapDirs + ); + + // Constructor for Dots game + MoveRecord( + Loc newLoc, + Player newPla, + State newPreviousState, + const std::vector& newBases, + const std::vector& newEmptyBaseInvalidateLocations, + const std::vector& newGroundingLocations + ); }; //Constructors--------------------------------- Board(); //Create Board of size (DEFAULT_LEN,DEFAULT_LEN) - Board(int x, int y); //Create Board of size (x,y) + explicit Board(const Rules& newRules); + Board(int x, int y, const Rules& newRules); // Create Board of size (x,y) with the specified Rules Board(const Board& other); Board& operator=(const Board&) = default; //Functions------------------------------------ + [[nodiscard]] Color getColor(Loc loc) const; + [[nodiscard]] Color getPlacedColor(Loc loc) const; + [[nodiscard]] State getState(Loc loc) const; + void setState(Loc loc, State state); + bool isDots() const; + + template void forEachAdjacent(const Loc loc, Func&& f) const { + const int stride = x_size + 1; + f(loc - stride); + f(loc - 1); + f(loc + 1); + f(loc + stride); + } + double sqrtBoardArea() const; //Gets the number of stones of the chain at loc. Precondition: location must be black or white. @@ -183,10 +276,8 @@ struct Board bool isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; //Check if moving here is illegal due to simple ko bool isKoBanned(Loc loc) const; - //Check if moving here is legal, ignoring simple ko - bool isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; //Check if moving here is legal. Equivalent to isLegalIgnoringKo && !isKoBanned - bool isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; + bool isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal, bool ignoreKo) const; //Check if this location is on the board bool isOnBoard(Loc loc) const; //Check if this location contains a simple eye for the specified player. @@ -203,11 +294,13 @@ struct Board bool isAdjacentToChain(Loc loc, Loc chain) const; //Does this connect two pla distinct groups that are not both pass-alive and not within opponent pass-alive area either? bool isNonPassAliveSelfConnection(Loc loc, Player pla, Color* passAliveArea) const; - //Is this board empty? - bool isEmpty() const; + // Is this board empty? + bool isStartPos() const; //Count the number of stones on the board int numStonesOnBoard() const; int numPlaStonesOnBoard(Player pla) const; + std::vector + getCurrentMoves(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, bool includeStartLocs) const; //Get a hash that combines the position of the board with simple ko prohibition and a player to move. Hash128 getSitHashWithSimpleKo(Player pla) const; @@ -223,15 +316,20 @@ struct Board //Returns false if location or color were out of range. bool setStone(Loc loc, Color color); + // Set the start pos and use the provided random in case of randomization is used + // It should be called strictly before handicap placement + void setStartPos(Rand& rand); //Sets the specified stone, including overwriting existing stones, but only if doing so will //not result in any captures or zero liberty groups. //Returns false if location or color were out of range, or if would cause a zero liberty group. //In case of failure, will restore the position, but may result in chain ids or ordering in the board changing. - bool setStoneFailIfNoLibs(Loc loc, Color color); + //If startPos is true, adds the move to start pos moves to distinguish between start pos and handicap stones + bool setStoneFailIfNoLibs(Loc loc, Color color, bool startPos = false); //Same, but sets multiple stones, and only requires that the final configuration contain no zero-liberty groups. //If it does contain a zero liberty group, fails and returns false and leaves the board in an arbitrarily changed but valid state. //Also returns false if any location is specified more than once. - bool setStonesFailIfNoLibs(std::vector placements); + //If startPos is true, adds the placements to start pos moves to distinguish between start pos and handicap stones + bool setStonesFailIfNoLibs(const std::vector& placements, bool startPos = false); //Attempts to play the specified move. Returns true if successful, returns false if the move was illegal. bool playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal); @@ -239,13 +337,13 @@ struct Board //Plays the specified move, assuming it is legal. void playMoveAssumeLegal(Loc loc, Player pla); - //Plays the specified move, assuming it is legal, and returns a MoveRecord for the move + // Plays the specified move, assuming it is legal, and returns a MoveRecord for the move MoveRecord playMoveRecorded(Loc loc, Player pla); //Undo the move given by record. Moves MUST be undone in the order they were made. //Undos will NOT typically restore the precise representation in the board to the way it was. The heads of chains //might change, the order of the circular lists might change, etc. - void undo(MoveRecord record); + void undo(MoveRecord& record); //Get what the position hash would be if we were to play this move and resolve captures and suicides. //Assumes the move is on an empty location. @@ -270,6 +368,7 @@ struct Board //If unsafeBigTerritories, also marks for each pla empty regions bordered by pla stones and no opp stones, regardless. //All other points are marked as C_EMPTY. //[result] must be a buffer of size MAX_ARR_SIZE and will get filled with the result + // For Dots game it just calculates grounding void calculateArea( Color* result, bool nonPassAliveStones, @@ -278,8 +377,9 @@ struct Board bool isMultiStoneSuicideLegal ) const; + int calculateOwnershipAndWhiteScore(Color* result, Color groundingPlayer) const; - //Calculates the area (including non pass alive stones, safe and unsafe big territories) + // Calculates the area (including non pass alive stones, safe and unsafe big territories) //However, strips out any "seki" regions. //Seki regions are that are adjacent to any remaining empty regions. //If keepTerritories, then keeps the surrounded territories in seki regions, only strips points for stones. @@ -293,16 +393,18 @@ struct Board bool isMultiStoneSuicideLegal ) const; + void calculateOneMoveCaptureAndBasePositionsForDots(std::vector& captures, std::vector& bases) const; + //Run some basic sanity checks on the board state, throws an exception if not consistent, for testing/debugging void checkConsistency() const; //For the moment, only used in testing since it does extra consistency checks. //If we need a version to be used in "prod", we could make an efficient version maybe as operator==. - bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const; + bool isEqualForTesting(const Board& other, bool checkNumCaptures = true, bool checkSimpleKo = true, bool checkRules = true) const; - static Board parseBoard(int xSize, int ySize, const std::string& s); - static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter); - static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); - static std::string toStringSimple(const Board& board, char lineDelimiter); + static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules = Rules::DEFAULT_GO, char lineDelimiter = '\n'); + std::string toString() const; + static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist, bool printHash = true); + static std::string toStringSimple(const Board& board, char lineDelimiter = '\n'); static nlohmann::json toJson(const Board& board); static Board ofJson(const nlohmann::json& data); @@ -310,13 +412,9 @@ struct Board int x_size; //Horizontal size of board int y_size; //Vertical size of board + Rules rules; Color colors[MAX_ARR_SIZE]; //Color of each location on the board. - //Every chain of stones has one of its stones arbitrarily designated as the head. - ChainData chain_data[MAX_ARR_SIZE]; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. - Loc chain_head[MAX_ARR_SIZE]; //Where is the head of this chain? Undefined if EMPTY or WALL - Loc next_in_chain[MAX_ARR_SIZE]; //Location of next stone in chain. Circular linked list. Undefined if EMPTY or WALL - Loc ko_loc; //A simple ko capture was made here, making it illegal to replay here next move /* PointList empty_list; //List of all empty locations on board */ @@ -326,10 +424,74 @@ struct Board int numBlackCaptures; //Number of b stones captured, informational and used by board history when clearing pos int numWhiteCaptures; //Number of w stones captured, informational and used by board history when clearing pos - short adj_offsets[8]; //Indices 0-3: Offsets to add for adjacent points. Indices 4-7: Offsets for diagonal points. 2 and 3 are +x and +y. + // Useful for fast calculation of the game result and finishing the game + int blackScoreIfWhiteGrounds; + int whiteScoreIfBlackGrounds; + + // Offsets to add to get clockwise traverse + short adj_offsets[8]; + + int numLegalMovesIfSuiAllowed; + + //Every chain of stones has one of its stones arbitrarily designated as the head. + std::vector chain_data; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. + std::vector chain_head; //Where is the head of this chain? Undefined if EMPTY or WALL + std::vector next_in_chain; //Location of next stone in chain. Circular linked list. Undefined if EMPTY or WALL + std::vector start_pos_moves; //Moves that are played at the very beginning of the game private: - void init(int xS, int yS); + + // Dots game data + mutable std::vector closureOrInvalidateLocsBuffer = std::vector(); + mutable std::vector territoryLocationsBuffer = std::vector(); + mutable std::vector walkStack = std::vector(); + mutable std::vector visited_data = std::vector(); + + // Dots game functions + [[nodiscard]] bool wouldBeCaptureDots(Loc loc, Player pla) const; + [[nodiscard]] bool isSuicideDots(Loc loc, Player pla) const; + void playMoveAssumeLegalDots(Loc loc, Player pla); + MoveRecord playMoveRecordedDots(Loc loc, Player pla); + MoveRecord tryPlayMoveRecordedDots(Loc loc, Player pla, bool isSuicideLegal); + void undoDots(MoveRecord& moveRecord); + std::vector fillGrounding(Loc loc); + void captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp, std::vector& bases, bool& isGrounded); + void tryCapture( + Loc loc, + Player pla, + const std::array& unconnectedLocations, + int unconnectedLocationsSize, + bool& atLeastOneRealBaseIsGrounded, + std::vector& bases); + void ground(Player pla, std::vector& emptyBaseInvalidatePositions, std::vector& bases); + std::array getUnconnectedLocations(Loc loc, Player pla, int& size) const; + void checkAndAddUnconnectedLocation( + std::array& unconnectedLocationsBuffer, + int& size, + Player checkPla, + Player currentPla, + Loc addLoc1, + Loc addLoc2) const; + void tryGetCounterClockwiseClosure(Loc initialLoc, Loc startLoc, Player pla) const; + Base buildBase(const std::vector& closure, Player pla); + void getTerritoryLocations(Player pla, Loc firstLoc, bool grounding, bool& createRealBase) const; + Base createBaseAndUpdateStates(Player basePla, bool isReal); + void updateScoreAndHashForTerritory(Loc loc, State state, Player basePla, bool rollback); + void invalidateAdjacentEmptyTerritoryIfNeeded(Loc loc); + void makeMoveAndCalculateCapturesAndBases( + Player pla, + Loc loc, + std::vector& captures, + std::vector& bases) const; + + void setGrounded(Loc loc); + void clearGrounded(Loc loc); + bool isVisited(Loc loc) const; + void setVisited(Loc loc) const; + void clearVisited(Loc loc) const; + void clearVisited(const std::vector& locations) const; + + void init(int xS, int yS, const Rules& initRules); int countHeuristicConnectionLibertiesX2(Loc loc, Player pla) const; bool isLibertyOf(Loc loc, Loc head) const; void mergeChains(Loc loc1, Loc loc2); diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 5b3c35a9b..6a638f071 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -23,54 +23,66 @@ static Hash128 getKoHashAfterMoveNonEncore(const Rules& rules, Hash128 posHashAf // return posHashAfterMove ^ koRecapBlockHashAfterMove; // } - -BoardHistory::BoardHistory() - :rules(), - moveHistory(), - preventEncoreHistory(), - koHashHistory(), - firstTurnIdxWithKoHistory(0), - initialBoard(), - initialPla(P_BLACK), - initialEncorePhase(0), - initialTurnNumber(0), - assumeMultipleStartingBlackMovesAreHandicap(false), - whiteHasMoved(false), - overrideNumHandicapStones(-1), - recentBoards(), - currentRecentBoardIdx(0), - presumedNextMovePla(P_BLACK), - consecutiveEndingPasses(0), - hashesBeforeBlackPass(),hashesBeforeWhitePass(), - encorePhase(0), - numTurnsThisPhase(0), - numApproxValidTurnsThisPhase(0), - numConsecValidTurnsThisGame(0), - koRecapBlockHash(), - koCapturesInEncore(), - whiteBonusScore(0.0f), - whiteHandicapBonusScore(0.0f), - hasButton(false), - isPastNormalPhaseEnd(false), - isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), - isScored(false),isNoResult(false),isResignation(false) -{ - std::fill(wasEverOccupiedOrPlayed, wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, false); - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); +BoardHistory::BoardHistory() : BoardHistory(Rules::DEFAULT_GO) {} + +BoardHistory::BoardHistory(const Rules& newRules) + : rules(newRules), + moveHistory(), + preventEncoreHistory(), + koHashHistory(), + firstTurnIdxWithKoHistory(0), + initialBoard(newRules), + initialPla(P_BLACK), + initialEncorePhase(0), + initialTurnNumber(0), + assumeMultipleStartingBlackMovesAreHandicap(false), + whiteHasMoved(false), + overrideNumHandicapStones(-1), + currentRecentBoardIdx(0), + presumedNextMovePla(P_BLACK), + consecutiveEndingPasses(0), + hashesBeforeBlackPass(), + hashesBeforeWhitePass(), + encorePhase(0), + numTurnsThisPhase(0), + numApproxValidTurnsThisPhase(0), + numConsecValidTurnsThisGame(0), + koRecapBlockHash(), + koCapturesInEncore(), + whiteBonusScore(0.0f), + whiteHandicapBonusScore(0.0f), + hasButton(false), + isPastNormalPhaseEnd(false), + isGameFinished(false), + winner(C_EMPTY), + finalWhiteMinusBlackScore(0.0f), + isScored(false), + isNoResult(false), + isResignation(false), + isPassAliveFinished(false) { + for(int i = 0; i < NUM_RECENT_BOARDS; i++) { + recentBoards.emplace_back(newRules); + } + if(!newRules.isDots) { + wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); + superKoBanned.resize(Board::MAX_ARR_SIZE, false); + koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); + secondEncoreStartColors.resize(Board::MAX_ARR_SIZE, C_EMPTY); + } } BoardHistory::~BoardHistory() {} +BoardHistory::BoardHistory(const Board& board) : BoardHistory(board, P_BLACK, board.rules, 0) {} + BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int ePhase) :rules(r), moveHistory(), preventEncoreHistory(), koHashHistory(), firstTurnIdxWithKoHistory(0), - initialBoard(), + initialBoard(rules), initialPla(), initialEncorePhase(0), initialTurnNumber(0), @@ -93,12 +105,17 @@ BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int e hasButton(false), isPastNormalPhaseEnd(false), isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), - isScored(false),isNoResult(false),isResignation(false) + isScored(false),isNoResult(false),isResignation(false),isPassAliveFinished(false) { - std::fill(wasEverOccupiedOrPlayed, wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, false); - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); + for(int i = 0; i < NUM_RECENT_BOARDS; i++) { + recentBoards.emplace_back(rules); + } + if (!rules.isDots) { + wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); + superKoBanned.resize(Board::MAX_ARR_SIZE, false); + koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); + secondEncoreStartColors.resize(Board::MAX_ARR_SIZE, C_EMPTY); + } clear(board,pla,rules,ePhase); } @@ -116,7 +133,6 @@ BoardHistory::BoardHistory(const BoardHistory& other) assumeMultipleStartingBlackMovesAreHandicap(other.assumeMultipleStartingBlackMovesAreHandicap), whiteHasMoved(other.whiteHasMoved), overrideNumHandicapStones(other.overrideNumHandicapStones), - recentBoards(), currentRecentBoardIdx(other.currentRecentBoardIdx), presumedNextMovePla(other.presumedNextMovePla), consecutiveEndingPasses(other.consecutiveEndingPasses), @@ -132,13 +148,13 @@ BoardHistory::BoardHistory(const BoardHistory& other) hasButton(other.hasButton), isPastNormalPhaseEnd(other.isPastNormalPhaseEnd), isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), - isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) + isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation),isPassAliveFinished(other.isPassAliveFinished) { - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + recentBoards = other.recentBoards; + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; + koRecapBlocked = other.koRecapBlocked; + secondEncoreStartColors = other.secondEncoreStartColors; } @@ -158,11 +174,11 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; consecutiveEndingPasses = other.consecutiveEndingPasses; hashesBeforeBlackPass = other.hashesBeforeBlackPass; hashesBeforeWhitePass = other.hashesBeforeWhitePass; @@ -170,10 +186,10 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) numTurnsThisPhase = other.numTurnsThisPhase; numApproxValidTurnsThisPhase = other.numApproxValidTurnsThisPhase; numConsecValidTurnsThisGame = other.numConsecValidTurnsThisGame; - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); + koRecapBlocked = other.koRecapBlocked; koRecapBlockHash = other.koRecapBlockHash; koCapturesInEncore = other.koCapturesInEncore; - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + secondEncoreStartColors = other.secondEncoreStartColors; whiteBonusScore = other.whiteBonusScore; whiteHandicapBonusScore = other.whiteHandicapBonusScore; hasButton = other.hasButton; @@ -184,6 +200,7 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) isScored = other.isScored; isNoResult = other.isNoResult; isResignation = other.isResignation; + isPassAliveFinished = other.isPassAliveFinished; return *this; } @@ -201,7 +218,6 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept assumeMultipleStartingBlackMovesAreHandicap(other.assumeMultipleStartingBlackMovesAreHandicap), whiteHasMoved(other.whiteHasMoved), overrideNumHandicapStones(other.overrideNumHandicapStones), - recentBoards(), currentRecentBoardIdx(other.currentRecentBoardIdx), presumedNextMovePla(other.presumedNextMovePla), consecutiveEndingPasses(other.consecutiveEndingPasses), @@ -217,13 +233,13 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept hasButton(other.hasButton), isPastNormalPhaseEnd(other.isPastNormalPhaseEnd), isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), - isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) + isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation),isPassAliveFinished(other.isPassAliveFinished) { - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + recentBoards = other.recentBoards; + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; + koRecapBlocked = other.koRecapBlocked; + secondEncoreStartColors = other.secondEncoreStartColors; } BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept @@ -240,11 +256,11 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; consecutiveEndingPasses = other.consecutiveEndingPasses; hashesBeforeBlackPass = std::move(other.hashesBeforeBlackPass); hashesBeforeWhitePass = std::move(other.hashesBeforeWhitePass); @@ -252,10 +268,10 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept numTurnsThisPhase = other.numTurnsThisPhase; numApproxValidTurnsThisPhase = other.numApproxValidTurnsThisPhase; numConsecValidTurnsThisGame = other.numConsecValidTurnsThisGame; - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); + koRecapBlocked = other.koRecapBlocked; koRecapBlockHash = other.koRecapBlockHash; koCapturesInEncore = std::move(other.koCapturesInEncore); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + secondEncoreStartColors = other.secondEncoreStartColors; whiteBonusScore = other.whiteBonusScore; whiteHandicapBonusScore = other.whiteHandicapBonusScore; hasButton = other.hasButton; @@ -266,6 +282,7 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept isScored = other.isScored; isNoResult = other.isNoResult; isResignation = other.isResignation; + isPassAliveFinished = other.isPassAliveFinished; return *this; } @@ -292,22 +309,12 @@ void BoardHistory::clear(const Board& board, Player pla, const Rules& r, int ePh currentRecentBoardIdx = 0; presumedNextMovePla = pla; - - for(int y = 0; y= 0 && encorePhase <= 2); - if(encorePhase > 0) - assert(rules.scoringRule == Rules::SCORING_TERRITORY); - //Update the few parameters that depend on encore - if(encorePhase == 2) - std::copy(board.colors, board.colors+Board::MAX_ARR_SIZE, secondEncoreStartColors); - else - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); - - //Push hash for the new board state - koHashHistory.push_back(getKoHash(rules,board,pla,encorePhase,koRecapBlockHash)); - - if(rules.scoringRule == Rules::SCORING_TERRITORY) { - //Chill 1 point for every move played + if (!rules.isDots) { for(int y = 0; y= 0 && encorePhase <= 2); + if(encorePhase > 0) + assert(rules.scoringRule == Rules::SCORING_TERRITORY); + //Update the few parameters that depend on encore + if(encorePhase == 2) + std::copy_n(board.colors, Board::MAX_ARR_SIZE, secondEncoreStartColors.begin()); + else + std::fill(secondEncoreStartColors.begin(), secondEncoreStartColors.end(), C_EMPTY); + + //Push hash for the new board state + koHashHistory.push_back(getKoHash(rules,board,pla,encorePhase,koRecapBlockHash)); + + if(rules.scoringRule == Rules::SCORING_TERRITORY) { + //Chill 1 point for every move played + for(int y = 0; y(computeWhiteHandicapBonus()); } void BoardHistory::setOverrideNumHandicapStones(int n) { overrideNumHandicapStones = n; - whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); + whiteHandicapBonusScore = static_cast(computeWhiteHandicapBonus()); } -static int numHandicapStonesOnBoardHelper(const Board& board, int blackNonPassTurnsToStart) { - int startBoardNumBlackStones = 0; - int startBoardNumWhiteStones = 0; - for(int y = 0; y 0) out << "Game phase: " << encorePhase << endl; out << "Rules: " << rules.toJsonString() << endl; if(whiteHandicapBonusScore != 0) out << "Handicap bonus score: " << whiteHandicapBonusScore << endl; - out << "B stones captured: " << board.numBlackCaptures << endl; - out << "W stones captured: " << board.numWhiteCaptures << endl; + + const auto firstPlayerName = PlayerIO::playerToString(P_BLACK, isDots); + const auto secondPlayerName = PlayerIO::playerToString(P_WHITE, isDots); + if (!isDots) { + out << firstPlayerName << " stones captured: " << board.numBlackCaptures << endl; + out << secondPlayerName << " stones captured: " << board.numWhiteCaptures << endl; + } else { + out << firstPlayerName << " score: " << board.numWhiteCaptures << endl; + out << secondPlayerName << " score: " << board.numBlackCaptures << endl; + } + if (isGameFinished) { + out << "Game is finished, winner: " << PlayerIO::playerToString(winner, isDots) << ", score: " << finalWhiteMinusBlackScore << ", resign: " << boolalpha << isResignation << endl; + } } void BoardHistory::printDebugInfo(ostream& out, const Board& board) const { out << board << endl; - out << "Initial pla " << PlayerIO::playerToString(initialPla) << endl; - out << "Encore phase " << encorePhase << endl; - out << "Turns this phase " << numTurnsThisPhase << endl; - out << "Approx valid turns this phase " << numApproxValidTurnsThisPhase << endl; - out << "Approx consec valid turns this game " << numConsecValidTurnsThisGame << endl; + const bool isDots = board.rules.isDots; + if (!isDots) { + out << "Initial pla " << PlayerIO::playerToString(initialPla, rules.isDots) << endl; + out << "Encore phase " << encorePhase << endl; + out << "Turns this phase " << numTurnsThisPhase << endl; + out << "Approx valid turns this phase " << numApproxValidTurnsThisPhase << endl; + out << "Approx consec valid turns this game " << numConsecValidTurnsThisGame << endl; + } else { + assert(0 == encorePhase); + } out << "Rules " << rules << endl; - out << "Ko recap block hash " << koRecapBlockHash << endl; + if (!isDots) { + out << "Ko recap block hash " << koRecapBlockHash << endl; + } else { + assert(Hash128() == koRecapBlockHash); + } out << "White bonus score " << whiteBonusScore << endl; - out << "White handicap bonus score " << whiteHandicapBonusScore << endl; - out << "Has button " << hasButton << endl; - out << "Presumed next pla " << PlayerIO::playerToString(presumedNextMovePla) << endl; - out << "Past normal phase end " << isPastNormalPhaseEnd << endl; - out << "Game result " << isGameFinished << " " << PlayerIO::playerToString(winner) << " " - << finalWhiteMinusBlackScore << " " << isScored << " " << isNoResult << " " << isResignation << endl; + if (!isDots) { + out << "White handicap bonus score " << whiteHandicapBonusScore << endl; + out << "Has button " << hasButton << endl; + } else { + assert(0.0f == whiteHandicapBonusScore); + assert(!hasButton); + } + out << "Presumed next pla " << PlayerIO::playerToString(presumedNextMovePla, rules.isDots) << endl; + if (!isDots) { + out << "Past normal phase end " << isPastNormalPhaseEnd << endl; + } else { + assert(0 == isPastNormalPhaseEnd); + } + out << "Game result " << isGameFinished << " " << PlayerIO::playerToString(winner, rules.isDots) << " " + << finalWhiteMinusBlackScore << " " << isScored << " " << isNoResult << " " << isResignation; + if (isPassAliveFinished) { + out << " " << isPassAliveFinished; + } + out << endl; out << "Last moves "; - for(int i = 0; isize() moves of koHashHistory. //ALSO counts the most recent ko hash! bool BoardHistory::koHashOccursInHistory(Hash128 koHash, const KoHashTable* rootKoHashTable) const { + assert(!rules.isDots); + size_t start = 0; size_t koHashHistorySize = koHashHistory.size(); if(rootKoHashTable != NULL && @@ -574,6 +627,8 @@ float BoardHistory::currentSelfKomi(Player pla, double drawEquivalentWinsForWhit } int BoardHistory::countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots == board.isDots() && !rules.isDots); + int score = 0; if(rules.taxRule == Rules::TAX_NONE) { bool nonPassAliveStones = true; @@ -615,6 +670,8 @@ int BoardHistory::countAreaScoreWhiteMinusBlack(const Board& board, Color area[B //ALSO makes area color the points that were not pass alive but were scored for a side. int BoardHistory::countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots == board.isDots() && !rules.isDots); + int score = 0; bool keepTerritories; bool keepStones; @@ -664,28 +721,23 @@ int BoardHistory::countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Col return score; } -void BoardHistory::setFinalScoreAndWinner(float score) { +void BoardHistory::setFinalScoreAndWinner(const float score) { finalWhiteMinusBlackScore = score; - if(finalWhiteMinusBlackScore > 0.0f) + if(finalWhiteMinusBlackScore > Global::FLOAT_EPS) winner = C_WHITE; - else if(finalWhiteMinusBlackScore < 0.0f) + else if(finalWhiteMinusBlackScore < -Global::FLOAT_EPS) winner = C_BLACK; else winner = C_EMPTY; } -void BoardHistory::getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { - if(rules.scoringRule == Rules::SCORING_AREA) - countAreaScoreWhiteMinusBlack(board,area); - else if(rules.scoringRule == Rules::SCORING_TERRITORY) - countTerritoryAreaScoreWhiteMinusBlack(board,area); - else - ASSERT_UNREACHABLE; -} - void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) { - int boardScore; - if(rules.scoringRule == Rules::SCORING_AREA) + assert(rules.isDots == board.isDots()); + + int boardScore = 0; + if(rules.isDots) + boardScore = countDotsScoreWhiteMinusBlack(board,area); + else if(rules.scoringRule == Rules::SCORING_AREA) boardScore = countAreaScoreWhiteMinusBlack(board,area); else if(rules.scoringRule == Rules::SCORING_TERRITORY) boardScore = countTerritoryAreaScoreWhiteMinusBlack(board,area); @@ -693,11 +745,12 @@ void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ ASSERT_UNREACHABLE; if(hasButton) { + assert(!board.rules.isDots); hasButton = false; whiteBonusScore += (presumedNextMovePla == P_WHITE ? 0.5f : -0.5f); } - setFinalScoreAndWinner(boardScore + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); isScored = true; isNoResult = false; isResignation = false; @@ -711,42 +764,71 @@ void BoardHistory::endAndScoreGameNow(const Board& board) { } void BoardHistory::endGameIfAllPassAlive(const Board& board) { - int boardScore = 0; - bool nonPassAliveStones = false; - bool safeBigTerritories = false; - bool unsafeBigTerritories = false; - Color area[Board::MAX_ARR_SIZE]; - board.calculateArea( - area, - nonPassAliveStones, safeBigTerritories, unsafeBigTerritories, rules.multiStoneSuicideLegal - ); + assert(rules.isDots == board.isDots()); + + if (rules.isDots) { + if (const float whiteScoreAfterGrounding = whiteScoreIfGroundingAlive(board); whiteScoreAfterGrounding != std::numeric_limits::quiet_NaN()) { + setFinalScoreAndWinner(whiteScoreAfterGrounding); + isScored = true; + isNoResult = false; + isResignation = false; + isGameFinished = true; + isPastNormalPhaseEnd = false; + isPassAliveFinished = true; + } + } else { + Color area[Board::MAX_ARR_SIZE]; + int boardScore = 0; - for(int y = 0; y 0) { - if(moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { + if(moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) { if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc] && board.getChainSize(moveLoc) == 1 && board.getNumLiberties(moveLoc) == 1) return true; Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); @@ -788,7 +879,7 @@ bool BoardHistory::isLegal(const Board& board, Loc moveLoc, Player movePla) cons if(board.isKoBanned(moveLoc)) return false; } - if(!board.isLegalIgnoringKo(moveLoc,movePla,rules.multiStoneSuicideLegal)) + if(!board.isLegal(moveLoc, movePla, rules.multiStoneSuicideLegal, true)) return false; if(superKoBanned[moveLoc]) return false; @@ -797,7 +888,10 @@ bool BoardHistory::isLegal(const Board& board, Loc moveLoc, Player movePla) cons } bool BoardHistory::isPassForKo(const Board& board, Loc moveLoc, Player movePla) const { - if(encorePhase > 0 && moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { + assert(rules.isDots == board.isDots()); + if (rules.isDots) return false; + + if(encorePhase > 0 && moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) { if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc] && board.getChainSize(moveLoc) == 1 && board.getNumLiberties(moveLoc) == 1) return true; @@ -857,6 +951,10 @@ bool BoardHistory::wouldBeSpightlikeEndingPass(Player movePla, Hash128 koHashBef } bool BoardHistory::passWouldEndPhase(const Board& board, Player movePla) const { + assert(rules.isDots == board.isDots()); + if (rules.isDots) return false; + + // TODO: probably add the assert? assert(!board.isDots()); Hash128 koHashBeforeMove = getKoHash(rules, board, movePla, encorePhase, koRecapBlockHash); if(newConsecutiveEndingPassesAfterPass() >= 2 || wouldBeSpightlikeEndingPass(movePla,koHashBeforeMove)) @@ -865,6 +963,9 @@ bool BoardHistory::passWouldEndPhase(const Board& board, Player movePla) const { } bool BoardHistory::passWouldEndGame(const Board& board, Player movePla) const { + if (board.rules.isDots) { + return true; // Pass in Dots game is grounding move that always ends the game + } return passWouldEndPhase(board,movePla) && ( rules.scoringRule == Rules::SCORING_AREA || (rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase >= 2) @@ -880,7 +981,8 @@ bool BoardHistory::shouldSuppressEndGameFromFriendlyPass(const Board& board, Pla bool BoardHistory::isFinalPhase() const { return - rules.scoringRule == Rules::SCORING_AREA + rules.isDots + || rules.scoringRule == Rules::SCORING_AREA || (rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase >= 2); } @@ -888,8 +990,9 @@ bool BoardHistory::isLegalTolerant(const Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + constexpr bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; return true; } @@ -897,8 +1000,9 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; makeBoardMoveAssumeLegal(board,moveLoc,movePla,NULL); return true; @@ -907,17 +1011,14 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla == presumedNextMovePla && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; makeBoardMoveAssumeLegal(board,moveLoc,movePla,NULL,preventEncore); return true; } -void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable) { - makeBoardMoveAssumeLegal(board,moveLoc,movePla,rootKoHashTable,false); -} - void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore) { Hash128 posHashBeforeMove = board.pos_hash; @@ -929,7 +1030,7 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); } - bool moveIsIllegal = !isLegal(board,moveLoc,movePla); + const bool moveIsIllegal = !isLegal(board,moveLoc,movePla); //And if somehow we're making a move after the game was ended, just clear those values and continue. isGameFinished = false; @@ -939,90 +1040,108 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo isScored = false; isNoResult = false; isResignation = false; + isPassAliveFinished = false; //Update consecutiveEndingPasses and button bool isSpightlikeEndingPass = false; - if(moveLoc != Board::PASS_LOC) - consecutiveEndingPasses = 0; - else if(hasButton) { - assert(encorePhase == 0 && rules.hasButton); - hasButton = false; - whiteBonusScore += (movePla == P_WHITE ? 0.5f : -0.5f); - consecutiveEndingPasses = 0; - //Taking the button clears all ko hash histories (this is equivalent to not clearing them and treating buttonless - //state as different than buttonful state) - hashesBeforeBlackPass.clear(); - hashesBeforeWhitePass.clear(); - koHashHistory.clear(); - //The first turn idx with history will be the one RESULTING from this move. - firstTurnIdxWithKoHistory = moveHistory.size()+1; - } - else { - //Passes clear ko history in the main phase with spight ko rules and in the encore - //This lifts bans in spight ko rules and lifts 3-fold-repetition checking in the encore for no-resultifying infinite cycles - //They also clear in simple ko rules for the purpose of no-resulting long cycles. Long cycles with passes do not no-result. - if(phaseHasSpightlikeEndingAndPassHistoryClearing()) { + bool wasPassForKo = false; + + if (moveLoc == Board::RESIGN_LOC) { + setWinnerByResignation(getOpp(movePla)); + } else if (rules.isDots) { + //Dots game + board.playMoveAssumeLegal(moveLoc, movePla); + if (moveLoc == Board::PASS_LOC) { + isScored = true; + isNoResult = false; + isResignation = false; + isGameFinished = true; + isPastNormalPhaseEnd = false; + const auto whiteMinusBlackScore = static_cast(board.numBlackCaptures - board.numWhiteCaptures); + setFinalScoreAndWinner(whiteMinusBlackScore + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + } + } else { + if(moveLoc != Board::PASS_LOC) + consecutiveEndingPasses = 0; + else if(hasButton) { + assert(encorePhase == 0 && rules.hasButton); + hasButton = false; + whiteBonusScore += (movePla == P_WHITE ? 0.5f : -0.5f); + consecutiveEndingPasses = 0; + //Taking the button clears all ko hash histories (this is equivalent to not clearing them and treating buttonless + //state as different than buttonful state) + hashesBeforeBlackPass.clear(); + hashesBeforeWhitePass.clear(); koHashHistory.clear(); //The first turn idx with history will be the one RESULTING from this move. firstTurnIdxWithKoHistory = moveHistory.size()+1; - //Does not clear hashesBeforeBlackPass or hashesBeforeWhitePass. Passes lift ko bans, but - //still repeated positions after pass end the game or phase, which these arrays are used to check. } + else { + //Passes clear ko history in the main phase with spight ko rules and in the encore + //This lifts bans in spight ko rules and lifts 3-fold-repetition checking in the encore for no-resultifying infinite cycles + //They also clear in simple ko rules for the purpose of no-resulting long cycles. Long cycles with passes do not no-result. + if(phaseHasSpightlikeEndingAndPassHistoryClearing()) { + koHashHistory.clear(); + //The first turn idx with history will be the one RESULTING from this move. + firstTurnIdxWithKoHistory = moveHistory.size()+1; + //Does not clear hashesBeforeBlackPass or hashesBeforeWhitePass. Passes lift ko bans, but + //still repeated positions after pass end the game or phase, which these arrays are used to check. + } - Hash128 koHashBeforeThisMove = getKoHash(rules,board,movePla,encorePhase,koRecapBlockHash); - consecutiveEndingPasses = newConsecutiveEndingPassesAfterPass(); - //Check if we have a game-ending pass BEFORE updating hashesBeforeBlackPass and hashesBeforeWhitePass - isSpightlikeEndingPass = wouldBeSpightlikeEndingPass(movePla,koHashBeforeThisMove); + Hash128 koHashBeforeThisMove = getKoHash(rules,board,movePla,encorePhase,koRecapBlockHash); + consecutiveEndingPasses = newConsecutiveEndingPassesAfterPass(); + //Check if we have a game-ending pass BEFORE updating hashesBeforeBlackPass and hashesBeforeWhitePass + isSpightlikeEndingPass = wouldBeSpightlikeEndingPass(movePla,koHashBeforeThisMove); - //Update hashesBeforeBlackPass and hashesBeforeWhitePass - if(movePla == P_BLACK) - hashesBeforeBlackPass.push_back(koHashBeforeThisMove); - else if(movePla == P_WHITE) - hashesBeforeWhitePass.push_back(koHashBeforeThisMove); - else - ASSERT_UNREACHABLE; - } - - //Handle pass-for-ko moves in the encore. Pass for ko lifts a ko recapture block and does nothing else. - bool wasPassForKo = false; - if(encorePhase > 0 && moveLoc != Board::PASS_LOC) { - if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc]) { - setKoRecapBlocked(moveLoc,false); - wasPassForKo = true; - //Clear simple ko loc just in case - //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. - board.clearSimpleKoLoc(); + //Update hashesBeforeBlackPass and hashesBeforeWhitePass + if(movePla == P_BLACK) + hashesBeforeBlackPass.push_back(koHashBeforeThisMove); + else if(movePla == P_WHITE) + hashesBeforeWhitePass.push_back(koHashBeforeThisMove); + else + ASSERT_UNREACHABLE; } - else { - Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); - if(koCaptureLoc != Board::NULL_LOC && koRecapBlocked[koCaptureLoc] && board.colors[koCaptureLoc] == getOpp(movePla)) { - setKoRecapBlocked(koCaptureLoc,false); + + //Handle pass-for-ko moves in the encore. Pass for ko lifts a ko recapture block and does nothing else. + if(encorePhase > 0 && moveLoc != Board::PASS_LOC) { + if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc]) { + setKoRecapBlocked(moveLoc,false); wasPassForKo = true; //Clear simple ko loc just in case //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. board.clearSimpleKoLoc(); } - } - } - //Otherwise handle regular moves - if(!wasPassForKo) { - board.playMoveAssumeLegal(moveLoc,movePla); - - if(encorePhase > 0) { - //Update ko recapture blocks and record that this was a ko capture - if(board.ko_loc != Board::NULL_LOC) { - setKoRecapBlocked(moveLoc,true); - koCapturesInEncore.push_back(EncoreKoCapture(posHashBeforeMove,moveLoc,movePla)); - //Clear simple ko loc now that we've absorbed the ko loc information into the korecap blocks - //Once we have that, the simple ko loc plays no further role in game state or legality - board.clearSimpleKoLoc(); + else { + Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); + if(koCaptureLoc != Board::NULL_LOC && koRecapBlocked[koCaptureLoc] && board.colors[koCaptureLoc] == getOpp(movePla)) { + setKoRecapBlocked(koCaptureLoc,false); + wasPassForKo = true; + //Clear simple ko loc just in case + //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. + board.clearSimpleKoLoc(); + } } - //Unmark all ko recap blocks not on stones - for(int y = 0; y 0) { + //Update ko recapture blocks and record that this was a ko capture + if(board.ko_loc != Board::NULL_LOC) { + setKoRecapBlocked(moveLoc,true); + koCapturesInEncore.push_back(EncoreKoCapture(posHashBeforeMove,moveLoc,movePla)); + //Clear simple ko loc now that we've absorbed the ko loc information into the korecap blocks + //Once we have that, the simple ko loc plays no further role in game state or legality + board.clearSimpleKoLoc(); + } + //Unmark all ko recap blocks not on stones + for(int y = 0; y 0) { - //During the encore, only one capture of each ko in a given position by a given player - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - for(size_t i = 0; i 0) { + //During the encore, only one capture of each ko in a given position by a given player + std::fill(superKoBanned.begin(), superKoBanned.end(), false); + for(size_t i = 0; i= 2 || isSpightlikeEndingPass) { - if(rules.scoringRule == Rules::SCORING_AREA) { - assert(encorePhase <= 0); - endAndScoreGameNow(board); + //Handicap bonus score + if(movePla == P_WHITE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) + whiteHasMoved = true; + if(assumeMultipleStartingBlackMovesAreHandicap && !whiteHasMoved && movePla == P_BLACK && rules.whiteHandicapBonusRule != Rules::WHB_ZERO) { + whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); } - else if(rules.scoringRule == Rules::SCORING_TERRITORY) { - if(encorePhase >= 2) + + //Phase transitions and game end + if(consecutiveEndingPasses >= 2 || isSpightlikeEndingPass) { + if(rules.scoringRule == Rules::SCORING_AREA) { + assert(encorePhase <= 0); endAndScoreGameNow(board); - else { - if(preventEncore) { - isPastNormalPhaseEnd = true; - //Cap at 1 - do include just the single pass here by itself since the single pass by itself - //absent any history of passes before that should be valid still. - numApproxValidTurnsThisPhase = std::min(numApproxValidTurnsThisPhase,1); - numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); - } + } + else if(rules.scoringRule == Rules::SCORING_TERRITORY) { + if(encorePhase >= 2) + endAndScoreGameNow(board); else { - encorePhase += 1; - numTurnsThisPhase = 0; - numApproxValidTurnsThisPhase = 0; - if(encorePhase == 2) - std::copy(board.colors, board.colors+Board::MAX_ARR_SIZE, secondEncoreStartColors); - - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - consecutiveEndingPasses = 0; - hashesBeforeBlackPass.clear(); - hashesBeforeWhitePass.clear(); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - koRecapBlockHash = Hash128(); - koCapturesInEncore.clear(); - - koHashHistory.clear(); - koHashHistory.push_back(getKoHash(rules,board,getOpp(movePla),encorePhase,koRecapBlockHash)); - //The first ko hash history is the one for the move we JUST appended to the move history earlier. - firstTurnIdxWithKoHistory = moveHistory.size(); + if(preventEncore) { + isPastNormalPhaseEnd = true; + //Cap at 1 - do include just the single pass here by itself since the single pass by itself + //absent any history of passes before that should be valid still. + numApproxValidTurnsThisPhase = std::min(numApproxValidTurnsThisPhase,1); + numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); + } + else { + encorePhase += 1; + numTurnsThisPhase = 0; + numApproxValidTurnsThisPhase = 0; + if(encorePhase == 2) + std::copy_n(board.colors, Board::MAX_ARR_SIZE, secondEncoreStartColors.begin()); + + std::fill(superKoBanned.begin(), superKoBanned.end(), false); + consecutiveEndingPasses = 0; + hashesBeforeBlackPass.clear(); + hashesBeforeWhitePass.clear(); + std::fill(koRecapBlocked.begin(), koRecapBlocked.end(), false); + koRecapBlockHash = Hash128(); + koCapturesInEncore.clear(); + + koHashHistory.clear(); + koHashHistory.push_back(getKoHash(rules,board,getOpp(movePla),encorePhase,koRecapBlockHash)); + //The first ko hash history is the one for the move we JUST appended to the move history earlier. + firstTurnIdxWithKoHistory = moveHistory.size(); + } } } + else + ASSERT_UNREACHABLE; } - else - ASSERT_UNREACHABLE; - } - //Break long cycles with no-result - if(moveLoc != Board::PASS_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { - if(numberOfKoHashOccurrencesInHistory(koHashHistory[koHashHistory.size()-1], rootKoHashTable) >= 3) { - isNoResult = true; - isGameFinished = true; + //Break long cycles with no-result + if(moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { + if(numberOfKoHashOccurrencesInHistory(koHashHistory[koHashHistory.size()-1], rootKoHashTable) >= 3) { + isNoResult = true; + isGameFinished = true; + } } } - } bool BoardHistory::hasBlackPassOrWhiteFirst() const { //First move was made by white this game, on an empty board. - if(initialBoard.isEmpty() && moveHistory.size() > 0 && moveHistory[0].pla == P_WHITE) + if(initialBoard.isStartPos() && moveHistory.size() > 0 && moveHistory[0].pla == P_WHITE) return true; //Black passed exactly once or white doublemoved int numBlackPasses = 0; @@ -1181,6 +1303,9 @@ Hash128 BoardHistory::getSituationAndSimpleKoHash(const Board& board, Player nex //Note that board.pos_hash also incorporates the size of the board. Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; + if (board.isDots()) { + assert(board.ko_loc == Board::NULL_LOC); + } if(board.ko_loc != Board::NULL_LOC) hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; return hash; @@ -1190,6 +1315,9 @@ Hash128 BoardHistory::getSituationAndSimpleKoAndPrevPosHash(const Board& board, //Note that board.pos_hash also incorporates the size of the board. Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; + if (board.isDots()) { + assert(board.ko_loc == Board::NULL_LOC); + } if(board.ko_loc != Board::NULL_LOC) hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; @@ -1209,37 +1337,39 @@ Hash128 BoardHistory::getSituationRulesAndKoHash(const Board& board, const Board Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; - assert(hist.encorePhase >= 0 && hist.encorePhase <= 2); - hash ^= Board::ZOBRIST_ENCORE_HASH[hist.encorePhase]; - - if(hist.encorePhase == 0) { - if(board.ko_loc != Board::NULL_LOC) - hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; - for(int y = 0; y= 0 && hist.encorePhase <= 2); + hash ^= Board::ZOBRIST_ENCORE_HASH[hist.encorePhase]; + + if(hist.encorePhase == 0) { + if(board.ko_loc != Board::NULL_LOC) + hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; + for(int y = 0; y recentBoards; int currentRecentBoardIdx; Player presumedNextMovePla; //Did this board location ever have a stone there before, or was it ever played? //(Also includes locations of suicides) - bool wasEverOccupiedOrPlayed[Board::MAX_ARR_SIZE]; + std::vector wasEverOccupiedOrPlayed; //Locations where the next player is not allowed to play due to superko - bool superKoBanned[Board::MAX_ARR_SIZE]; + std::vector superKoBanned; //Number of consecutive passes made that count for ending the game or phase int consecutiveEndingPasses; @@ -68,7 +68,7 @@ struct BoardHistory { int numConsecValidTurnsThisGame; //Ko-recapture-block locations for territory scoring in encore - bool koRecapBlocked[Board::MAX_ARR_SIZE]; + std::vector koRecapBlocked; Hash128 koRecapBlockHash; //Hash contribution from ko-recap-block locations in encore. //Used to implement once-only rules for ko captures in encore @@ -76,7 +76,7 @@ struct BoardHistory { std::vector koCapturesInEncore; //State of the grid as of the start of encore phase 2 for territory scoring - Color secondEncoreStartColors[Board::MAX_ARR_SIZE]; + std::vector secondEncoreStartColors; //Amount that should be added to komi float whiteBonusScore; @@ -100,9 +100,13 @@ struct BoardHistory { bool isNoResult; //True if this game is supposed to be ended but it was by resignation rather than an actual end position bool isResignation; + //True if this game is supposed to be ended early + bool isPassAliveFinished; BoardHistory(); + explicit BoardHistory(const Rules& newRules); ~BoardHistory(); + BoardHistory(const Board& board); BoardHistory(const Board& board, Player pla, const Rules& rules, int encorePhase); @@ -158,8 +162,7 @@ struct BoardHistory { //even if the move violates superko or encore ko recapture prohibitions, or is past when the game is ended. //This allows for robustness when this code is being used for analysis or with external data sources. //preventEncore artifically prevents any move from entering or advancing the encore phase when using territory scoring. - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable); - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore); + void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore = false); //Make a move with legality checking, but be mostly tolerant and allow moves that can still be handled but that may not technically //be legal. This is intended for reading moves from SGFs and such where maybe we're getting moves that were played in a different //ruleset than ours. Returns true if successful, false if was illegal even unter tolerant rules. @@ -168,11 +171,17 @@ struct BoardHistory { bool isLegalTolerant(const Board& board, Loc moveLoc, Player movePla) const; //Slightly expensive, check if the entire game is all pass-alive-territory, and if so, declare the game finished + // For Dots game it's Grounding alive void endGameIfAllPassAlive(const Board& board); + void endGameIfNoLegalMoves(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); + // Effective draw is when there are no ungrounded dots on the field (disregarding Komi) + // We can consider grounding in this case because the further game typically doesn't make sense. + bool winOrEffectiveDrawByGrounding(const Board& board, Player pla, bool considerDraw = true) const; + // Return > 0 if white wins by grounding, < 0 if black wins by grounding, 0 if there are no ungrounded dots and Nan otherwise + float whiteScoreIfGroundingAlive(const Board& board) const; void endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]); - void getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; void setWinnerByResignation(Player pla); @@ -200,6 +209,7 @@ struct BoardHistory { private: bool koHashOccursInHistory(Hash128 koHash, const KoHashTable* rootKoHashTable) const; void setKoRecapBlocked(Loc loc, bool b); + int countDotsScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; int countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; int countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; void setFinalScoreAndWinner(float score); diff --git a/cpp/game/common.h b/cpp/game/common.h new file mode 100644 index 000000000..833e3c0cd --- /dev/null +++ b/cpp/game/common.h @@ -0,0 +1,52 @@ +#ifndef GAME_COMMON_H +#define GAME_COMMON_H +#include +#include "../core/global.h" + +const std::string DOTS_KEY = "dots"; +const std::string DATA_LEN_X_KEY = "dataBoardLenX"; +const std::string DATA_LEN_Y_KEY = "dataBoardLenY"; +const std::string DOTS_CAPTURE_EMPTY_BASE_KEY = "dotsCaptureEmptyBase"; +const std::string DOTS_CAPTURE_EMPTY_BASES_KEY = "dotsCaptureEmptyBases"; +const std::string START_POS_KEY = "startPos"; +const std::string START_POS_RANDOM_KEY = "startPosIsRandom"; +const std::string START_POSES_KEY = "startPoses"; +const std::string START_POSES_ARE_RANDOM_KEY = "startPosesAreRandom"; + +const std::string BLACK_SCORE_IF_WHITE_GROUNDS_KEY = "blackScoreIfWhiteGrounds"; +const std::string WHITE_SCORE_IF_BLACK_GROUNDS_KEY = "whiteScoreIfBlackGrounds"; + +const std::string PLAYER1 = "Player1"; +const std::string PLAYER2 = "Player2"; +const std::string PLAYER1_SHORT = "P1"; +const std::string PLAYER2_SHORT = "P2"; + +const std::string RESIGN_STR = "resign"; + +// Player +typedef int8_t Player; +static constexpr Player P_BLACK = 1; +static constexpr Player P_WHITE = 2; + +//Color of a point on the board +typedef int8_t Color; +static constexpr Color C_EMPTY = 0; +static constexpr Color C_BLACK = 1; +static constexpr Color C_WHITE = 2; +static constexpr Color C_WALL = 3; +static constexpr int NUM_BOARD_COLORS = 4; + +typedef int8_t State; + +//Location of a point on the board +//(x,y) is represented as (x+1) + (y+1)*(x_size+1) +typedef short Loc; + +//Simple structure for storing moves. This is a convenient place to define it. +STRUCT_NAMED_PAIR(Loc,loc,Player,pla,Move); + +inline bool movesEqual(const Move& m1, const Move& m2) { + return m1.loc == m2.loc && m1.pla == m2.pla; +} + +#endif diff --git a/cpp/game/dotsboardhistory.cpp b/cpp/game/dotsboardhistory.cpp new file mode 100644 index 000000000..3736e7f8e --- /dev/null +++ b/cpp/game/dotsboardhistory.cpp @@ -0,0 +1,50 @@ +#include "../game/boardhistory.h" + +using namespace std; + +int BoardHistory::countDotsScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots); + + const float whiteScore = whiteScoreIfGroundingAlive(board); + Color groundingPlayer = C_EMPTY; + if (whiteScore >= 0.0f) { // Don't use EPS comparison because in case of zero result, there is a draw and any player can ground + groundingPlayer = C_WHITE; + } else if (whiteScore < 0.0f) { + groundingPlayer = C_BLACK; + } + return board.calculateOwnershipAndWhiteScore(area, groundingPlayer); +} + +bool BoardHistory::winOrEffectiveDrawByGrounding(const Board& board, const Player pla, const bool considerDraw) const { + assert(rules.isDots); + + const float whiteScore = whiteScoreIfGroundingAlive(board); + return (considerDraw && Global::isZero(whiteScore)) || + (pla == P_WHITE && whiteScore > Global::FLOAT_EPS) || + (pla == P_BLACK && whiteScore < -Global::FLOAT_EPS); +} + +float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { + assert(rules.isDots); + + const float fullWhiteScoreIfBlackGrounds = + static_cast(board.whiteScoreIfBlackGrounds) + whiteBonusScore + whiteHandicapBonusScore + rules.komi; + if (fullWhiteScoreIfBlackGrounds < -Global::FLOAT_EPS) { + // Black already won the game + return fullWhiteScoreIfBlackGrounds; + } + + const float fullBlackScoreIfWhiteGrounds = + static_cast(board.blackScoreIfWhiteGrounds) - whiteBonusScore - whiteHandicapBonusScore - rules.komi; + if (fullBlackScoreIfWhiteGrounds < -Global::FLOAT_EPS) { + // White already won the game + return -fullBlackScoreIfWhiteGrounds; + } + + if (Global::isZero(fullWhiteScoreIfBlackGrounds) && Global::isZero(fullBlackScoreIfWhiteGrounds)) { + // Draw by grounding + return 0.0f; + } + + return std::numeric_limits::quiet_NaN(); +} \ No newline at end of file diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp new file mode 100644 index 000000000..544fa3b5b --- /dev/null +++ b/cpp/game/dotsfield.cpp @@ -0,0 +1,951 @@ +#include "board.h" +#include + +#include "../program/play.h" + +using namespace std; + +static constexpr int PLACED_PLAYER_SHIFT = PLAYER_BITS_COUNT; +static constexpr int EMPTY_TERRITORY_SHIFT = PLACED_PLAYER_SHIFT + PLAYER_BITS_COUNT; +static constexpr int TERRITORY_FLAG_SHIFT = EMPTY_TERRITORY_SHIFT + PLAYER_BITS_COUNT; +static constexpr int GROUNDED_FLAG_SHIFT = TERRITORY_FLAG_SHIFT + 1; + +static constexpr State TERRITORY_FLAG = 1 << TERRITORY_FLAG_SHIFT; +static constexpr State GROUNDED_FLAG = static_cast(1 << GROUNDED_FLAG_SHIFT); +static constexpr State INVALIDATE_TERRITORY_MASK = ~(ACTIVE_MASK | ACTIVE_MASK << EMPTY_TERRITORY_SHIFT); + +Loc Location::xm1y(Loc loc) { + return loc - 1; +} + +Loc Location::xm1ym1(Loc loc, int x_size) { + return loc - 1 - (x_size+1); +} + +Loc Location::xym1(Loc loc, int x_size) { + return loc - (x_size+1); +} + +Loc Location::xp1ym1(Loc loc, int x_size) { + return loc + 1 - (x_size+1); +} + +Loc Location::xp1y(Loc loc) { + return loc + 1; +} + +Loc Location::xp1yp1(Loc loc, int x_size) { + return loc + 1 + (x_size+1); +} + +Loc Location::xyp1(Loc loc, int x_size) { + return loc + (x_size+1); +} + +Loc Location::xm1yp1(Loc loc, int x_size) { + return loc - 1 + (x_size+1); +} + +int Location::getGetBigJumpInitialIndex(const Loc loc0, const Loc loc1, const int x_size) { + const int diff = loc1 - loc0; + const int stride = x_size + 1; + + if (diff == -1 || diff == -1 - stride) { + return RIGHT_TOP_INDEX; + } + + if(diff == -stride || diff == +1 - stride) { + return RIGHT_BOTTOM_INDEX; + } + + if(diff == 1 || diff == 1 + stride) { + return LEFT_BOTTOM_INDEX; + } + + if(diff == stride || diff == -1 + stride) { + return LEFT_TOP_INDEX; + } + + return -1; +} + +Loc Location::getNextLocCW(const Loc loc0, const Loc loc1, const int x_size) { + const int diff = loc1 - loc0; + const int stride = x_size + 1; + + if (diff == -1) return xm1ym1(loc0, x_size); + if (diff == -1 - stride) return xym1(loc0, x_size); + if (diff == -stride) return xp1ym1(loc0, x_size); + if (diff == +1 - stride) return xp1y(loc0); + if (diff == +1) return xp1yp1(loc0, x_size); + if (diff == +1 + stride) return xyp1(loc0, x_size); + if (diff == +stride) return xm1yp1(loc0, x_size); + if (diff == -1 + stride) return xm1y(loc0); + + assert(false && "Incorrect locations"); + return 0; +} + +Color getActiveColor(const State state) { + return static_cast(state & ACTIVE_MASK); +} + +bool isTerritory(const State s) { + return (s & TERRITORY_FLAG) == TERRITORY_FLAG; +} + +Color getPlacedDotColor(const State s) { + return static_cast(s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK); +} + +static bool isPlaced(const State s, const Player pla) { + return (s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK) == pla; +} + +static bool isActive(const State s, const Player pla) { + return (s & ACTIVE_MASK) == pla; +} + +static State setTerritoryAndActivePlayer(const State s, const Player pla) { + const int invalidateTerritoryValue = s & INVALIDATE_TERRITORY_MASK; + return static_cast(TERRITORY_FLAG | invalidateTerritoryValue | pla); +} + +Color getEmptyTerritoryColor(const State s) { + return static_cast(s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK); +} + +static bool isWithinEmptyTerritory(const State s, const Player pla) { + return (s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK) == pla; +} + +State Board::getState(const Loc loc) const { + return colors[loc]; +} + +void Board::setState(const Loc loc, const State state) { + colors[loc] = state; +} + +bool Board::isDots() const { + return rules.isDots; +} + +bool isGrounded(const State state) { + return (state & GROUNDED_FLAG) == GROUNDED_FLAG; +} + +static bool isGroundedOrWall(const State state, const Player pla) { + // Use bit tricks for grounding detecting. + // If the active player is C_WALL, then the result is also true. + return (state & GROUNDED_FLAG) == GROUNDED_FLAG && (state & pla) == pla; +} + +void Board::setGrounded(const Loc loc) { + colors[loc] = static_cast(colors[loc] | GROUNDED_FLAG); +} + +void Board::clearGrounded(const Loc loc) { + colors[loc] = static_cast(colors[loc] & ~GROUNDED_FLAG); +} + +bool Board::isVisited(const Loc loc) const { + return visited_data[loc]; +} + +void Board::setVisited(const Loc loc) const { + visited_data[loc] = true; +} + +void Board::clearVisited(const Loc loc) const { + visited_data[loc] = false; +} + +void Board::clearVisited(const vector& locations) const { + for (const Loc& loc : locations) { + clearVisited(loc); + } +} + +int Board::calculateOwnershipAndWhiteScore(Color* result, const Color groundingPlayer) const { + [[maybe_unused]] int whiteCaptures = 0; + [[maybe_unused]] int blackCaptures = 0; + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + const State state = getState(loc); + const Color activeColor = getActiveColor(state); + const Color placedDotColor = getPlacedDotColor(state); + Color ownershipColor = C_EMPTY; + if (activeColor != C_EMPTY) { + if (isGrounded(state) || groundingPlayer == C_EMPTY) { + if (placedDotColor != C_EMPTY && activeColor != placedDotColor) { + ownershipColor = activeColor; + if (placedDotColor == P_BLACK) { + blackCaptures++; + } else { + whiteCaptures++; + } + } + } else { + // If the game is finished by grounding by a player, + // Remove its ungrounded dots to get a more refined ownership and score. + if (groundingPlayer == C_WHITE && placedDotColor == P_WHITE) { + ownershipColor = P_BLACK; + whiteCaptures++; + } else if (groundingPlayer == C_BLACK && placedDotColor == P_BLACK) { + ownershipColor = P_WHITE; + blackCaptures++; + } + } + } + result[loc] = ownershipColor; + } + } + + if (groundingPlayer == C_WHITE) { + // White wins by grounding + assert(blackScoreIfWhiteGrounds == whiteCaptures - blackCaptures); + return -blackScoreIfWhiteGrounds; + } + + if (groundingPlayer == C_BLACK) { + // Black wins by grounding + assert(whiteScoreIfBlackGrounds == blackCaptures - whiteCaptures); + return whiteScoreIfBlackGrounds; + } + + assert(numBlackCaptures == blackCaptures && numWhiteCaptures == whiteCaptures); + return numBlackCaptures - numWhiteCaptures; +} + +Board::MoveRecord::MoveRecord( + const Loc newLoc, + const Player newPla, + const State newPreviousState, + const vector& newBases, + const vector& newEmptyBaseInvalidateLocations, + const vector& newGroundingLocations +) { + ko_loc = NULL_LOC; + capDirs = 0; + + loc = newLoc; + pla = newPla; + previousState = newPreviousState; + bases = newBases; + emptyBaseInvalidateLocations = newEmptyBaseInvalidateLocations; + groundingLocations = newGroundingLocations; +} + +bool Board::isSuicideDots(const Loc loc, const Player pla) const { + const State state = getState(loc); + if (Player opp = getOpp(pla); getActiveColor(state) == C_EMPTY && getEmptyTerritoryColor(state) == opp) { + return !wouldBeCaptureDots(loc, pla); + } + + return false; +} + +bool Board::wouldBeCaptureDots(const Loc loc, const Player pla) const { + // TODO: optimize and get rid of `const_cast` + auto moveRecord = const_cast(this)->tryPlayMoveRecordedDots(loc, pla, false); + + bool result = false; + + if (moveRecord.pla != C_EMPTY) { + for (const Base& base : moveRecord.bases) { + if (base.is_real && base.pla == pla) { + result = true; + break; + } + } + + const_cast(this)->undoDots(moveRecord); + } + + return result; +} + +Board::MoveRecord Board::playMoveRecordedDots(const Loc loc, const Player pla) { + const MoveRecord& result = tryPlayMoveRecordedDots(loc, pla, true); + assert(result.pla == pla); + return result; +} + +void Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { + const State originalState = getState(loc); + + if (loc == RESIGN_LOC) { + } else if (loc == PASS_LOC) { + auto initEmptyBaseInvalidateLocations = vector(); + auto bases = vector(); + ground(pla, initEmptyBaseInvalidateLocations, bases); + } else { + colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); + const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; + pos_hash ^= hashValue; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed--; + } + + bool atLeastOneRealBaseIsGrounded = false; + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, pla, unconnectedLocationsSize); + bool capturing = false; + if (unconnectedLocationsSize >= 2) { + auto bases = vector(); + tryCapture(loc, pla, unconnectedLocations, unconnectedLocationsSize, atLeastOneRealBaseIsGrounded, bases); + capturing = !bases.empty(); + } + + const Color opp = getOpp(pla); + if (!capturing) { + if (getEmptyTerritoryColor(originalState) == opp) { + vector oppBases; + captureWhenEmptyTerritoryBecomesRealBase(loc, opp, oppBases, atLeastOneRealBaseIsGrounded); + } + } else if (isWithinEmptyTerritory(originalState, opp)) { + invalidateAdjacentEmptyTerritoryIfNeeded(loc); + } + + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else if (pla == P_WHITE) { + blackScoreIfWhiteGrounds++; + } + + if (atLeastOneRealBaseIsGrounded) { + fillGrounding(loc); + } else if( + const Player locActivePlayer = getColor(loc); // Can't use pla because of a possible suicidal move + isGroundedOrWall(getState(Location::xm1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xym1(loc, x_size)), locActivePlayer) || + isGroundedOrWall(getState(Location::xp1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xyp1(loc, x_size)), locActivePlayer) + ) { + fillGrounding(loc); + } + } +} + +Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool isSuicideLegal) { + State originalState = getState(loc); + + vector bases; + vector initEmptyBaseInvalidateLocations; + vector newGroundingLocations; + + if (loc == RESIGN_LOC) { + } else if (loc == PASS_LOC) { + ground(pla, initEmptyBaseInvalidateLocations, bases); + } else { + colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); + const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; + pos_hash ^= hashValue; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed--; + } + + bool atLeastOneRealBaseIsGrounded = false; + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, pla, unconnectedLocationsSize); + if (unconnectedLocationsSize >= 2) { + tryCapture(loc, pla, unconnectedLocations, unconnectedLocationsSize, atLeastOneRealBaseIsGrounded, bases); + } + + const Color opp = getOpp(pla); + if (bases.empty()) { + if (getEmptyTerritoryColor(originalState) == opp) { + if (isSuicideLegal) { + captureWhenEmptyTerritoryBecomesRealBase(loc, opp, bases, atLeastOneRealBaseIsGrounded); + } else { + colors[loc] = originalState; + pos_hash ^= hashValue; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed++; + } + return {}; + } + } + } else if (isWithinEmptyTerritory(originalState, opp)) { + invalidateAdjacentEmptyTerritoryIfNeeded(loc); + initEmptyBaseInvalidateLocations = vector(closureOrInvalidateLocsBuffer); + } + + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else if (pla == P_WHITE) { + blackScoreIfWhiteGrounds++; + } + + if (atLeastOneRealBaseIsGrounded) { + newGroundingLocations = fillGrounding(loc); + } else if( + const Player locActivePlayer = getColor(loc); // Can't use pla because of a possible suicidal move + isGroundedOrWall(getState(Location::xm1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xym1(loc, x_size)), locActivePlayer) || + isGroundedOrWall(getState(Location::xp1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xyp1(loc, x_size)), locActivePlayer) + ) { + newGroundingLocations = fillGrounding(loc); + } + } + + return {loc, pla, originalState, bases, initEmptyBaseInvalidateLocations, newGroundingLocations}; +} + +void Board::undoDots(MoveRecord& moveRecord) { + if (moveRecord.loc == RESIGN_LOC) { + return; // Resin doesn't really change the state + } + + const bool isGroundingMove = moveRecord.loc == PASS_LOC; + + for (const Loc& loc : moveRecord.groundingLocations) { + const State state = getState(loc); + const Player mainPla = getActiveColor(state); + if (getPlacedDotColor(state) != C_EMPTY) { + if (mainPla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else { + blackScoreIfWhiteGrounds++; + } + } + clearGrounded(loc); + } + + if (!isGroundingMove) { + if (moveRecord.pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else if (moveRecord.pla == P_WHITE) { + blackScoreIfWhiteGrounds--; + } + } + + const Player emptyTerritoryPlayer = isGroundingMove ? moveRecord.pla : getOpp(moveRecord.pla); + for (const Loc& loc : moveRecord.emptyBaseInvalidateLocations) { + assert(0 == getState(loc)); + setState(loc, static_cast(emptyTerritoryPlayer << EMPTY_TERRITORY_SHIFT)); + } + + for (auto it = moveRecord.bases.rbegin(); it != moveRecord.bases.rend(); ++it) { + for (size_t index = 0; index < it->rollback_locations.size(); index++) { + const State rollbackState = it->rollback_states[index]; + const Loc rollbackLocation = it->rollback_locations[index]; + setState(rollbackLocation, rollbackState); + if (it->is_real) { + updateScoreAndHashForTerritory(rollbackLocation, rollbackState, it->pla, true); + } + } + } + + if (!isGroundingMove) { + setState(moveRecord.loc, moveRecord.previousState); + pos_hash ^= ZOBRIST_BOARD_HASH[moveRecord.loc][moveRecord.pla]; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed++; + } + } +} + +vector Board::fillGrounding(const Loc loc) { + vector groundedLocs; + + walkStack.clear(); + walkStack.push_back(loc); + const Player pla = getColor(loc); + assert(pla != C_EMPTY && pla != C_WALL); + setGrounded(loc); + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else { + blackScoreIfWhiteGrounds--; + } + groundedLocs.push_back(loc); + + while (!walkStack.empty()) { + const Loc currentLoc = walkStack.back(); + walkStack.pop_back(); + + forEachAdjacent(currentLoc, [&](const Loc adj) { + if (const State state = getState(adj); !isGrounded(state) && isActive(state, pla)) { + setGrounded(adj); + if (const Player placedColor = getPlacedDotColor(state); placedColor != C_EMPTY) { + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else { + blackScoreIfWhiteGrounds--; + } + } + groundedLocs.push_back(adj); + walkStack.push_back(adj); + } + }); + } + + return groundedLocs; +} + +void Board::captureWhenEmptyTerritoryBecomesRealBase( + const Loc initLoc, + const Player opp, + vector& bases, + bool& isGrounded) { + Loc loc = initLoc; + + // Searching for an opponent dot that makes a closure that contains the `initialPosition`. + // The closure always exists, otherwise there is an error in previous calculations. + while (loc > 0) { + loc = Location::xm1y(loc); + + // Try to peek an active opposite player dot + if (getColor(loc) != opp) continue; + + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, opp, unconnectedLocationsSize); + if (unconnectedLocationsSize >= 1) { + bases.clear(); + tryCapture(loc, opp, unconnectedLocations, unconnectedLocationsSize, isGrounded, bases); + // The found base always should be real and include the `iniLoc` + for (const Base& oppBase : bases) { + if (oppBase.is_real) { + return; + } + } + } + } + + assert(false && "Opp empty territory should be enclosed by an outer closure"); +} + +void Board::tryCapture( + const Loc loc, + const Player pla, + const std::array& unconnectedLocations, + const int unconnectedLocationsSize, + bool& atLeastOneRealBaseIsGrounded, + std::vector& bases) { + auto currentClosures = vector>(); + + for (int index = 0; index < unconnectedLocationsSize; index++) { + const Loc unconnectedLoc = unconnectedLocations[index]; + + // Optimization: it doesn't make sense to check the latest unconnected dot + // when all previous connections form minimal bases + // because the latest always forms a base with maximal square that should be dropped + if (const size_t closuresSize = currentClosures.size(); + closuresSize > 0 && closuresSize == unconnectedLocations.size() - 1) { + break; + } + + tryGetCounterClockwiseClosure(loc, unconnectedLoc, pla); + + // Sort the given closures in ascending order + if (!closureOrInvalidateLocsBuffer.empty()) { + bool added = false; + + auto newClosure = vector(closureOrInvalidateLocsBuffer); + + for (auto it = currentClosures.begin(); it != currentClosures.end(); ++it) { + if (closureOrInvalidateLocsBuffer.size() < it->size()) { + currentClosures.insert(it, newClosure); + added = true; + break; + } + } + + if (!added) { + currentClosures.emplace_back(newClosure); + } + } + } + + atLeastOneRealBaseIsGrounded = false; + for (const vector& currentClosure: currentClosures) { + Base base = buildBase(currentClosure, pla); + bases.push_back(base); + + if (!atLeastOneRealBaseIsGrounded && base.is_real) { + for (const Loc& closureLoc : currentClosure) { + forEachAdjacent(closureLoc, [&](const Loc adj) { + atLeastOneRealBaseIsGrounded = atLeastOneRealBaseIsGrounded || isGroundedOrWall(getState(adj), base.pla); + }); + if (atLeastOneRealBaseIsGrounded) { + break; + } + } + } + } +} + +void Board::ground(const Player pla, vector& emptyBaseInvalidatePositions, vector& bases) { + const Color opp = getOpp(pla); + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + if (const State state = getState(loc); !isGrounded(state) && isActive(state, pla)) { + bool createRealBase = false; + getTerritoryLocations(pla, loc, true, createRealBase); + assert(createRealBase); + + for (const Loc& territoryLoc : territoryLocationsBuffer) { + invalidateAdjacentEmptyTerritoryIfNeeded(territoryLoc); + for (const Loc& invalidateLoc : closureOrInvalidateLocsBuffer) { + emptyBaseInvalidatePositions.push_back(invalidateLoc); + } + } + + bases.push_back(createBaseAndUpdateStates(opp, true)); + } + } + } +} + +std::array Board::getUnconnectedLocations(const Loc loc, const Player pla, int& size) const { + const Loc xm1y = Location::xm1y(loc); + const Loc xym1 = Location::xym1(loc, x_size); + const Loc xp1y = Location::xp1y(loc); + const Loc xyp1 = Location::xyp1(loc, x_size); + + std::array unconnectedLocationsBuffer; + size = 0; + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xp1y), pla, Location::xp1yp1(loc, x_size), xyp1); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xyp1), pla, Location::xm1yp1(loc, x_size), xm1y); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xm1y), pla, Location::xm1ym1(loc, x_size), xym1); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xym1), pla, Location::xp1ym1(loc, x_size), xp1y); + + return unconnectedLocationsBuffer; +} + +void Board::checkAndAddUnconnectedLocation(std::array& unconnectedLocationsBuffer, int& size, const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { + if (checkPla != currentPla) { + if (getColor(addLoc1) == currentPla) { + unconnectedLocationsBuffer[size++] = addLoc1; + } else if (getColor(addLoc2) == currentPla) { + unconnectedLocationsBuffer[size++] = addLoc2; + } + } +} + +void Board::tryGetCounterClockwiseClosure(const Loc initialLoc, const Loc startLoc, const Player pla) const { + closureOrInvalidateLocsBuffer.clear(); + closureOrInvalidateLocsBuffer.push_back(initialLoc); + setVisited(initialLoc); + closureOrInvalidateLocsBuffer.push_back(startLoc); + setVisited(startLoc); + + Loc currentLoc = startLoc; + Loc nextLoc = initialLoc; + Loc loc; + + do { + const int initialIndex = Location::getGetBigJumpInitialIndex(currentLoc, nextLoc, x_size); + int currentIndex = initialIndex; + + bool breakSearchingLoop = false; + + do { + loc = currentLoc + adj_offsets[currentIndex++]; + if (currentIndex == 8) currentIndex = 0; + + const State state = getState(loc); + const Color activeColor = getActiveColor(state); + if (activeColor == C_WALL) { + // Optimization: there is no need to walk anymore because the border can't enclosure anything + breakSearchingLoop = true; + break; + } + + if(activeColor == pla) { + if(loc == initialLoc) { + breakSearchingLoop = true; + break; + } + + if(isVisited(loc)) { + // Remove trailing dots + Loc lastLoc; + do { + lastLoc = closureOrInvalidateLocsBuffer.back(); + closureOrInvalidateLocsBuffer.pop_back(); + clearVisited(lastLoc); + } while(lastLoc != loc); + } + + closureOrInvalidateLocsBuffer.push_back(loc); + setVisited(loc); + nextLoc = currentLoc; + currentLoc = loc; + break; + } + } while (currentIndex != initialIndex); + + if (breakSearchingLoop) { + break; + } + } + while (true); + + if (loc != initialLoc || closureOrInvalidateLocsBuffer.size() < 4) { + clearVisited(closureOrInvalidateLocsBuffer); + closureOrInvalidateLocsBuffer.clear(); + return; + } + + int square = 0; + const int stride = x_size + 1; + + const Loc prevLoc = closureOrInvalidateLocsBuffer.back(); + // Store the previously calculated coordinates because division is an expensive operation + int prevX = prevLoc % stride; + int prevY = prevLoc / stride; + + for (const Loc& l : closureOrInvalidateLocsBuffer) { + const int x = l % stride; + const int y = l / stride; + + square += prevY * x - y * prevX; + + prevX = x; + prevY = y; + clearVisited(l); + } + + if (square <= 0) { + closureOrInvalidateLocsBuffer.clear(); + } +} + +Board::Base Board::buildBase(const vector& closure, const Player pla) { + for (const Loc& closureLoc : closure) { + setVisited(closureLoc); + } + + const Loc territoryFirstLoc = Location::getNextLocCW(closure.at(1), closure.at(0), x_size); + bool createRealBase; + getTerritoryLocations(pla, territoryFirstLoc, false, createRealBase); + clearVisited(closure); + + return createBaseAndUpdateStates(pla, createRealBase); +} + +void Board::getTerritoryLocations(const Player pla, const Loc firstLoc, const bool grounding, bool& createRealBase) const { + walkStack.clear(); + territoryLocationsBuffer.clear(); + + createRealBase = grounding ? false : rules.dotsCaptureEmptyBases; + const Player opp = getOpp(pla); + + State state = getState(firstLoc); + Color activeColor = getActiveColor(state); + assert(activeColor != C_WALL); + + bool legalLoc = false; + if (grounding) { + createRealBase = true; + // In a rare case it's possible to encounter an empty ungrounded loc that should be de-facto grounded. + // However, currently it's to set up its grounding due to limitations of the grounding algorithm that doesn't traverse diagonals. + // That's why we have to check adj locs on grounding to prevent adding incorrect locs and causing out of bounds exception. + legalLoc = activeColor == pla && !isGroundedOrWall(state, pla); + } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + legalLoc = true; // If no grounding, empty locations can be handled as well + } + + if (legalLoc) { + territoryLocationsBuffer.push_back(firstLoc); + setVisited(firstLoc); + walkStack.push_back(firstLoc); + } + + while (!walkStack.empty()) { + const Loc loc = walkStack.back(); + walkStack.pop_back(); + + forEachAdjacent(loc, [&](const Loc adj) { + if (isVisited(adj)) return; + + state = getState(adj); + activeColor = getActiveColor(state); + + bool isAdjLegal = false; + if (grounding) { + createRealBase = true; + isAdjLegal = activeColor == pla && !isGroundedOrWall(state, pla); + } else { + assert(activeColor != C_WALL); + if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + isAdjLegal = true; // If no grounding, empty locations can be handled as well + } + } + + if (isAdjLegal) { + territoryLocationsBuffer.push_back(adj); + setVisited(adj); + walkStack.push_back(adj); + } + }); + } + + clearVisited(territoryLocationsBuffer); +} + +Board::Base Board::createBaseAndUpdateStates(Player basePla, bool isReal) { + auto rollbackLocations = vector(); + rollbackLocations.reserve(territoryLocationsBuffer.size()); + auto rollbackStates = vector(); + rollbackStates.reserve(territoryLocationsBuffer.size()); + + for (const Loc& territoryLoc : territoryLocationsBuffer) { + State state = getState(territoryLoc); + + if (Player activePlayer = getActiveColor(state); activePlayer != basePla) { + State newState; + if (isReal) { + updateScoreAndHashForTerritory(territoryLoc, state, basePla, false); + newState = setTerritoryAndActivePlayer(state, basePla); + } else { + newState = static_cast(basePla << EMPTY_TERRITORY_SHIFT); + } + + rollbackLocations.push_back(territoryLoc); + rollbackStates.push_back(state); + setState(territoryLoc, newState); + } + } + + return {basePla, rollbackLocations, rollbackStates, isReal}; +} + +void Board::updateScoreAndHashForTerritory(const Loc loc, const State state, const Player basePla, const bool rollback) { + const Color currentColor = getActiveColor(state); + const Player baseOppPla = getOpp(basePla); + + if (isPlaced(state, baseOppPla)) { + // The `getTerritoryPositions` never returns positions inside already owned territory, + // so there is no need to check for the territory flag. + if (basePla == P_BLACK) { + if (!rollback) { + numWhiteCaptures++; + } else { + numWhiteCaptures--; + } + } else { + if (!rollback) { + numBlackCaptures++; + } else { + numBlackCaptures--; + } + } + } else if (isPlaced(state, basePla) && isActive(state, baseOppPla)) { + // No diff for the territory of the current player + if (basePla == P_BLACK) { + if (!rollback) { + numBlackCaptures--; + } else { + numBlackCaptures++; + } + } else { + if (!rollback) { + numWhiteCaptures--; + } else { + numWhiteCaptures++; + } + } + } + + if (currentColor == C_EMPTY) { + if (rules.multiStoneSuicideLegal) { + if (!rollback) { + numLegalMovesIfSuiAllowed--; + } else { + numLegalMovesIfSuiAllowed++; + } + } + pos_hash ^= ZOBRIST_BOARD_HASH[loc][basePla]; + } else if (currentColor == baseOppPla) { + // Simulate unmaking the opponent move and making the player's move + const auto positionsHash = ZOBRIST_BOARD_HASH[loc]; + pos_hash ^= positionsHash[baseOppPla]; + pos_hash ^= positionsHash[basePla]; + } +} + +void Board::invalidateAdjacentEmptyTerritoryIfNeeded(const Loc loc) { + walkStack.clear(); + walkStack.push_back(loc); + closureOrInvalidateLocsBuffer.clear(); + + while(!walkStack.empty()) { + const Loc lastLoc = walkStack.back(); + walkStack.pop_back(); + + FOREACHADJ( + Loc adj = lastLoc + ADJOFFSET; + + if (!isVisited(adj) && getEmptyTerritoryColor(getState(adj)) != C_EMPTY) { + closureOrInvalidateLocsBuffer.push_back(adj); + setState(adj, C_EMPTY); + setVisited(adj); + + walkStack.push_back(adj); + } + ) + } + + clearVisited(closureOrInvalidateLocsBuffer); +} + +void Board::makeMoveAndCalculateCapturesAndBases( + const Player pla, + const Loc loc, + vector& captures, + vector& bases) const { + if(isLegal(loc, pla, rules.multiStoneSuicideLegal, false)) { + MoveRecord moveRecord = const_cast(this)->playMoveRecordedDots(loc, pla); + + for(Base& base: moveRecord.bases) { + if (base.is_real) { + const bool suicide = base.pla != pla; + if (!suicide) { + captures[loc] = static_cast(captures[loc] | base.pla); + } + + for(const Loc& rollbackLoc: base.rollback_locations) { + // Consider empty bases as well + bases[rollbackLoc] = static_cast(bases[rollbackLoc] | base.pla); + } + } + } + + const_cast(this)->undo(moveRecord); + } +} + +void Board::calculateOneMoveCaptureAndBasePositionsForDots(vector& captures, vector& bases) const { + const int fieldSize = (x_size + 1) * (y_size + 1); + captures.resize(fieldSize); + bases.resize(fieldSize); + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + + const State state = getState(loc); + const Color emptyTerritoryColor = getEmptyTerritoryColor(state); + + // It doesn't make sense to calculate capturing when dot placed into own empty territory + if (emptyTerritoryColor != P_BLACK) { + makeMoveAndCalculateCapturesAndBases(P_BLACK, loc, captures, bases); + } + + if (emptyTerritoryColor != P_WHITE) { + makeMoveAndCalculateCapturesAndBases(P_WHITE, loc, captures, bases); + } + } + } +} diff --git a/cpp/game/graphhash.cpp b/cpp/game/graphhash.cpp index 3c5cb0a6a..32add1ffd 100644 --- a/cpp/game/graphhash.cpp +++ b/cpp/game/graphhash.cpp @@ -2,6 +2,7 @@ Hash128 GraphHash::getStateHash(const BoardHistory& hist, Player nextPlayer, double drawEquivalentWinsForWhite) { const Board& board = hist.getRecentBoard(0); + assert(hist.rules.isDots == board.isDots()); Hash128 hash = BoardHistory::getSituationRulesAndKoHash(board, hist, nextPlayer, drawEquivalentWinsForWhite); // Fold in whether a pass ends this phase @@ -12,11 +13,13 @@ Hash128 GraphHash::getStateHash(const BoardHistory& hist, Player nextPlayer, dou if(hist.isGameFinished) hash ^= Board::ZOBRIST_GAME_IS_OVER; - // Fold in consecutive pass count. Probably usually redundant with history tracking. Use some standard LCG constants. - static constexpr uint64_t CONSECPASS_MULT0 = 2862933555777941757ULL; - static constexpr uint64_t CONSECPASS_MULT1 = 3202034522624059733ULL; - hash.hash0 += CONSECPASS_MULT0 * (uint64_t)hist.consecutiveEndingPasses; - hash.hash1 += CONSECPASS_MULT1 * (uint64_t)hist.consecutiveEndingPasses; + if (!hist.rules.isDots) { + // Fold in consecutive pass count. Probably usually redundant with history tracking. Use some standard LCG constants. + static constexpr uint64_t CONSECPASS_MULT0 = 2862933555777941757ULL; + static constexpr uint64_t CONSECPASS_MULT1 = 3202034522624059733ULL; + hash.hash0 += CONSECPASS_MULT0 * (uint64_t)hist.consecutiveEndingPasses; + hash.hash1 += CONSECPASS_MULT1 * (uint64_t)hist.consecutiveEndingPasses; + } return hash; } @@ -43,8 +46,10 @@ Hash128 GraphHash::getGraphHashFromScratch(const BoardHistory& histOrig, Player Hash128 graphHash = Hash128(); for(size_t i = 0; i +#include "../core/rand.h" +#include "board.h" + using namespace std; using json = nlohmann::json; -Rules::Rules() { - //Defaults if not set - closest match to TT rules - koRule = KO_POSITIONAL; - scoringRule = SCORING_AREA; - taxRule = TAX_NONE; - multiStoneSuicideLegal = true; - hasButton = false; - whiteHandicapBonusRule = WHB_ZERO; - friendlyPassOk = false; - komi = 7.5f; -} +const Rules Rules::DEFAULT_DOTS = Rules(true); +const Rules Rules::DEFAULT_GO = Rules(false); + +Rules::Rules() : Rules(false) {} + +Rules::Rules(const int newStartPos, const bool newStartPosIsRandom, const bool newSuicide, const bool newDotsCaptureEmptyBases, const bool newDotsFreeCapturedDots) : + Rules(true, newStartPos, newStartPosIsRandom, 0, 0, 0, newSuicide, false, 0, false, 0.0f, newDotsCaptureEmptyBases, newDotsFreeCapturedDots) {} Rules::Rules( int kRule, @@ -28,60 +27,105 @@ Rules::Rules( int whbRule, bool pOk, float km +) : Rules(false, 0, false, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { +} + +Rules::Rules(const bool initIsDots) : Rules( + initIsDots, + initIsDots ? START_POS_CROSS : 0, + false, + initIsDots ? 0 : KO_POSITIONAL, + initIsDots ? 0 : SCORING_AREA, + initIsDots ? 0 : TAX_NONE, + initIsDots, + false, + initIsDots ? 0 : WHB_ZERO, + false, + initIsDots ? 0.0f : 7.5f, + false, + initIsDots + ) { +} + +Rules::Rules( + const bool newIsDots, + int startPosRule, + const bool newStartPosIsRandom, + int kRule, + int sRule, + int tRule, + bool suic, + bool button, + int whbRule, + bool pOk, + float km, + const bool newDotsCaptureEmptyBases, + const bool newDotsFreeCapturedDots ) - :koRule(kRule), - scoringRule(sRule), - taxRule(tRule), - multiStoneSuicideLegal(suic), - hasButton(button), - whiteHandicapBonusRule(whbRule), - friendlyPassOk(pOk), - komi(km) -{} + : isDots(newIsDots), + + startPos(startPosRule), + startPosIsRandom(newStartPosIsRandom), + komi(km), + multiStoneSuicideLegal(suic), + taxRule(tRule), + whiteHandicapBonusRule(whbRule), -Rules::~Rules() { + koRule(kRule), + scoringRule(sRule), + hasButton(button), + friendlyPassOk(pOk), + + dotsCaptureEmptyBases(newDotsCaptureEmptyBases), + dotsFreeCapturedDots(newDotsFreeCapturedDots) +{ + initializeIfNeeded(); } +Rules::~Rules() = default; + bool Rules::operator==(const Rules& other) const { - return - koRule == other.koRule && - scoringRule == other.scoringRule && - taxRule == other.taxRule && - multiStoneSuicideLegal == other.multiStoneSuicideLegal && - hasButton == other.hasButton && - whiteHandicapBonusRule == other.whiteHandicapBonusRule && - friendlyPassOk == other.friendlyPassOk && - komi == other.komi; + return equals(other, false); } bool Rules::operator!=(const Rules& other) const { - return - koRule != other.koRule || - scoringRule != other.scoringRule || - taxRule != other.taxRule || - multiStoneSuicideLegal != other.multiStoneSuicideLegal || - hasButton != other.hasButton || - whiteHandicapBonusRule != other.whiteHandicapBonusRule || - friendlyPassOk != other.friendlyPassOk || - komi != other.komi; + return !equals(other, false); } -bool Rules::equalsIgnoringKomi(const Rules& other) const { +bool Rules::equalsIgnoringSgfDefinedProps(const Rules& other) const { + return equals(other, true); +} + +bool Rules::equals(const Rules& other, const bool ignoreSgfDefinedProps) const { return + (ignoreSgfDefinedProps ? true : isDots == other.isDots) && + (ignoreSgfDefinedProps ? true : startPos == other.startPos) && + startPosIsRandom == other.startPosIsRandom && koRule == other.koRule && scoringRule == other.scoringRule && taxRule == other.taxRule && multiStoneSuicideLegal == other.multiStoneSuicideLegal && hasButton == other.hasButton && whiteHandicapBonusRule == other.whiteHandicapBonusRule && - friendlyPassOk == other.friendlyPassOk; + friendlyPassOk == other.friendlyPassOk && + (ignoreSgfDefinedProps ? true : komi == other.komi) && + dotsCaptureEmptyBases == other.dotsCaptureEmptyBases && + dotsFreeCapturedDots == other.dotsFreeCapturedDots; } bool Rules::gameResultWillBeInteger() const { - bool komiIsInteger = ((int)komi) == komi; + const bool komiIsInteger = Global::isEqual(std::floor(komi), komi); return komiIsInteger != hasButton; } +Rules Rules::getDefault(const bool isDots) { + return isDots ? DEFAULT_DOTS : DEFAULT_GO; +} + +Rules Rules::getDefaultOrTrompTaylorish(const bool isDots) { + return isDots ? DEFAULT_DOTS : getTrompTaylorish(); +} + Rules Rules::getTrompTaylorish() { Rules rules; rules.koRule = KO_POSITIONAL; @@ -112,6 +156,27 @@ bool Rules::komiIsIntOrHalfInt(float komi) { return std::isfinite(komi) && komi * 2 == (int)(komi * 2); } +set Rules::startPosStrings() { + initializeIfNeeded(); + return { + startPosIdToName[START_POS_EMPTY], + startPosIdToName[START_POS_CROSS], + startPosIdToName[START_POS_CROSS_2], + startPosIdToName[START_POS_CROSS_4], + }; +} + +int Rules::getNumOfStartPosStones() const { + switch (startPos) { + case START_POS_EMPTY: + return 0; + case START_POS_CROSS: return 4; + case START_POS_CROSS_2: return 8; + case START_POS_CROSS_4: return 16; + default: throw std::range_error("Invalid start pos: " + std::to_string(startPos)); + } +} + set Rules::koRuleStrings() { return {"SIMPLE","POSITIONAL","SITUATIONAL","SPIGHT"}; } @@ -125,6 +190,15 @@ set Rules::whiteHandicapBonusRuleStrings() { return {"0","N","N-1"}; } +int Rules::parseStartPos(const string& s) { + initializeIfNeeded(); + if (const auto it = startPosNameToId.find(s); it != startPosNameToId.end()) { + return it->second; + } + + throw IOError("Rules::parseStartPos: Invalid dots start pos rule: " + s); +} + int Rules::parseKoRule(const string& s) { if(s == "SIMPLE") return Rules::KO_SIMPLE; else if(s == "POSITIONAL") return Rules::KO_POSITIONAL; @@ -132,17 +206,20 @@ int Rules::parseKoRule(const string& s) { else if(s == "SPIGHT") return Rules::KO_SPIGHT; else throw IOError("Rules::parseKoRule: Invalid ko rule: " + s); } + int Rules::parseScoringRule(const string& s) { if(s == "AREA") return Rules::SCORING_AREA; else if(s == "TERRITORY") return Rules::SCORING_TERRITORY; else throw IOError("Rules::parseScoringRule: Invalid scoring rule: " + s); } + int Rules::parseTaxRule(const string& s) { if(s == "NONE") return Rules::TAX_NONE; else if(s == "SEKI") return Rules::TAX_SEKI; else if(s == "ALL") return Rules::TAX_ALL; else throw IOError("Rules::parseTaxRule: Invalid tax rule: " + s); } + int Rules::parseWhiteHandicapBonusRule(const string& s) { if(s == "0") return Rules::WHB_ZERO; else if(s == "N") return Rules::WHB_N; @@ -150,6 +227,15 @@ int Rules::parseWhiteHandicapBonusRule(const string& s) { else throw IOError("Rules::parseWhiteHandicapBonusRule: Invalid whiteHandicapBonus rule: " + s); } +string Rules::writeStartPosRule(int startPosRule) { + initializeIfNeeded(); + if (const auto it = startPosIdToName.find(startPosRule); it != startPosIdToName.end()) { + return it->second; + } + + return "UNKNOWN"; +} + string Rules::writeKoRule(int koRule) { if(koRule == Rules::KO_SIMPLE) return string("SIMPLE"); if(koRule == Rules::KO_POSITIONAL) return string("POSITIONAL"); @@ -157,6 +243,7 @@ string Rules::writeKoRule(int koRule) { if(koRule == Rules::KO_SPIGHT) return string("SPIGHT"); return string("UNKNOWN"); } + string Rules::writeScoringRule(int scoringRule) { if(scoringRule == Rules::SCORING_AREA) return string("AREA"); if(scoringRule == Rules::SCORING_TERRITORY) return string("TERRITORY"); @@ -176,38 +263,37 @@ string Rules::writeWhiteHandicapBonusRule(int whiteHandicapBonusRule) { } ostream& operator<<(ostream& out, const Rules& rules) { - out << "ko" << Rules::writeKoRule(rules.koRule) - << "score" << Rules::writeScoringRule(rules.scoringRule) - << "tax" << Rules::writeTaxRule(rules.taxRule) - << "sui" << rules.multiStoneSuicideLegal; - if(rules.hasButton) - out << "button" << rules.hasButton; - if(rules.whiteHandicapBonusRule != Rules::WHB_ZERO) - out << "whb" << Rules::writeWhiteHandicapBonusRule(rules.whiteHandicapBonusRule); - if(rules.friendlyPassOk) - out << "fpok" << rules.friendlyPassOk; - out << "komi" << rules.komi; + out << rules.toString(); return out; } -string Rules::toStringNoKomi() const { +string Rules::toString(const bool includeSgfDefinedProperties) const { ostringstream out; - out << "ko" << Rules::writeKoRule(koRule) - << "score" << Rules::writeScoringRule(scoringRule) - << "tax" << Rules::writeTaxRule(taxRule) - << "sui" << multiStoneSuicideLegal; - if(hasButton) - out << "button" << hasButton; - if(whiteHandicapBonusRule != WHB_ZERO) - out << "whb" << Rules::writeWhiteHandicapBonusRule(whiteHandicapBonusRule); - if(friendlyPassOk) - out << "fpok" << friendlyPassOk; - return out.str(); -} - -string Rules::toString() const { - ostringstream out; - out << (*this); + if (!isDots) { + out << "ko" << writeKoRule(koRule) + << "score" << writeScoringRule(scoringRule) + << "tax" << writeTaxRule(taxRule); + } else { + out << DOTS_CAPTURE_EMPTY_BASE_KEY << dotsCaptureEmptyBases; + } + if (includeSgfDefinedProperties && startPos != START_POS_EMPTY) { + out << START_POS_KEY << writeStartPosRule(startPos); + } + if (startPosIsRandom) { + out << START_POS_RANDOM_KEY << startPosIsRandom; + } + out << "sui" << multiStoneSuicideLegal; + if (!isDots) { + if (hasButton != DEFAULT_GO.hasButton) + out << "button" << hasButton; + if (whiteHandicapBonusRule != DEFAULT_GO.whiteHandicapBonusRule) + out << "whb" << writeWhiteHandicapBonusRule(whiteHandicapBonusRule); + if (friendlyPassOk != DEFAULT_GO.friendlyPassOk) + out << "fpok" << friendlyPassOk; + } + if (includeSgfDefinedProperties) { + out << "komi" << komi; + } return out.str(); } @@ -215,16 +301,27 @@ string Rules::toString() const { //which is the default for parsing and if not otherwise specified json Rules::toJsonHelper(bool omitKomi, bool omitDefaults) const { json ret; - ret["ko"] = writeKoRule(koRule); - ret["scoring"] = writeScoringRule(scoringRule); - ret["tax"] = writeTaxRule(taxRule); + if (isDots) + ret[DOTS_KEY] = true; + if (!omitDefaults || startPos != START_POS_EMPTY) + ret[START_POS_KEY] = writeStartPosRule(startPos); ret["suicide"] = multiStoneSuicideLegal; - if(!omitDefaults || hasButton) - ret["hasButton"] = hasButton; - if(!omitDefaults || whiteHandicapBonusRule != WHB_ZERO) - ret["whiteHandicapBonus"] = writeWhiteHandicapBonusRule(whiteHandicapBonusRule); - if(!omitDefaults || friendlyPassOk != false) - ret["friendlyPassOk"] = friendlyPassOk; + if (!isDots) { + ret["ko"] = writeKoRule(koRule); + ret["scoring"] = writeScoringRule(scoringRule); + ret["tax"] = writeTaxRule(taxRule); + if(!omitDefaults || hasButton != DEFAULT_GO.hasButton) + ret["hasButton"] = hasButton; + if(!omitDefaults || whiteHandicapBonusRule != DEFAULT_GO.whiteHandicapBonusRule) + ret["whiteHandicapBonus"] = writeWhiteHandicapBonusRule(whiteHandicapBonusRule); + if(!omitDefaults || friendlyPassOk != DEFAULT_GO.friendlyPassOk) + ret["friendlyPassOk"] = friendlyPassOk; + } else { + if (!omitDefaults || dotsCaptureEmptyBases != DEFAULT_DOTS.dotsCaptureEmptyBases) + ret[DOTS_CAPTURE_EMPTY_BASE_KEY] = dotsCaptureEmptyBases; + if (!omitDefaults || startPosIsRandom != DEFAULT_DOTS.startPosIsRandom) + ret[START_POS_RANDOM_KEY] = startPosIsRandom; + } if(!omitKomi) ret["komi"] = komi; return ret; @@ -254,26 +351,48 @@ string Rules::toJsonStringNoKomiMaybeOmitStuff() const { return toJsonHelper(true,true).dump(); } -Rules Rules::updateRules(const string& k, const string& v, Rules oldRules) { +Rules Rules::updateRules(const string& k, const string& v, const Rules& oldRules) { Rules rules = oldRules; - string key = Global::trim(k); - string value = Global::trim(Global::toUpper(v)); - if(key == "ko") rules.koRule = Rules::parseKoRule(value); - else if(key == "score") rules.scoringRule = Rules::parseScoringRule(value); - else if(key == "scoring") rules.scoringRule = Rules::parseScoringRule(value); - else if(key == "tax") rules.taxRule = Rules::parseTaxRule(value); - else if(key == "suicide") rules.multiStoneSuicideLegal = Global::stringToBool(value); - else if(key == "hasButton") rules.hasButton = Global::stringToBool(value); - else if(key == "whiteHandicapBonus") rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(value); - else if(key == "friendlyPassOk") rules.friendlyPassOk = Global::stringToBool(value); - else throw IOError("Unknown rules option: " + key); + const string key = Global::trim(k); + const string value = Global::trim(Global::toUpper(v)); + + bool parsed = true; + + if (key == DOTS_KEY) rules.isDots = Global::stringToBool(value); + else if (key == "suicide") rules.multiStoneSuicideLegal = Global::stringToBool(value); + else if (key == START_POS_KEY) rules.startPos = parseStartPos(value); + else if (key == START_POS_RANDOM_KEY) rules.startPosIsRandom = Global::stringToBool(value); + else parsed = false; + + if (!parsed) { + parsed = true; + if (!rules.isDots) { + if(key == "ko") rules.koRule = parseKoRule(value); + else if(key == "score" || key == "scoring") rules.scoringRule = parseScoringRule(value); + else if(key == "tax") rules.taxRule = parseTaxRule(value); + else if(key == "hasButton") rules.hasButton = Global::stringToBool(value); + else if(key == "whiteHandicapBonus") rules.whiteHandicapBonusRule = parseWhiteHandicapBonusRule(value); + else if(key == "friendlyPassOk") rules.friendlyPassOk = Global::stringToBool(value); + else parsed = false; + } else { + if (key == DOTS_CAPTURE_EMPTY_BASE_KEY) rules.dotsCaptureEmptyBases = Global::stringToBool(value); + else parsed = false; + } + } + + if (!parsed) { + throw IOError("Unknown rules option: " + key); + } + return rules; } -static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { - Rules rules; +static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) { + auto rules = Rules(isDots); string lowercased = Global::trim(Global::toLower(sOrig)); - if(lowercased == "japanese" || lowercased == "korean") { + if(lowercased == DOTS_KEY) { + rules = Rules::DEFAULT_DOTS; + } else if(lowercased == "japanese" || lowercased == "korean") { rules.scoringRule = Rules::SCORING_TERRITORY; rules.koRule = Rules::KO_SIMPLE; rules.taxRule = Rules::TAX_SEKI; @@ -379,21 +498,27 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { rules.friendlyPassOk = true; rules.komi = 7.5; } - else if(sOrig.length() > 0 && sOrig[0] == '{') { + else if(!sOrig.empty() && sOrig[0] == '{') { //Default if not specified - rules = Rules::getTrompTaylorish(); + rules = Rules::getDefaultOrTrompTaylorish(isDots); bool komiSpecified = false; bool taxSpecified = false; try { json input = json::parse(sOrig); string s; for(json::iterator iter = input.begin(); iter != input.end(); ++iter) { - string key = iter.key(); - if(key == "ko") + const string& key = iter.key(); + if (key == DOTS_KEY) + rules.isDots = iter.value().get(); + else if (key == START_POS_KEY) + rules.startPos = Rules::parseStartPos(iter.value().get()); + else if (key == START_POS_RANDOM_KEY) + rules.startPosIsRandom = iter.value().get(); + else if (key == DOTS_CAPTURE_EMPTY_BASE_KEY) + rules.dotsCaptureEmptyBases = iter.value().get(); + else if(key == "ko") rules.koRule = Rules::parseKoRule(iter.value().get()); - else if(key == "score") - rules.scoringRule = Rules::parseScoringRule(iter.value().get()); - else if(key == "scoring") + else if(key == "score" || key == "scoring") rules.scoringRule = Rules::parseScoringRule(iter.value().get()); else if(key == "tax") { rules.taxRule = Rules::parseTaxRule(iter.value().get()); taxSpecified = true; @@ -443,7 +568,7 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { }; //Default if not specified - rules = Rules::getTrompTaylorish(); + rules = Rules::getDefaultOrTrompTaylorish(isDots); string s = sOrig; s = Global::trim(s); @@ -528,6 +653,18 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { else throw IOError("Could not parse rules: " + sOrig); continue; } + if(startsWithAndStrip(s,DOTS_CAPTURE_EMPTY_BASE_KEY)) { + if(startsWithAndStrip(s,"1")) rules.dotsCaptureEmptyBases = true; + else if(startsWithAndStrip(s,"0")) rules.dotsCaptureEmptyBases = false; + else throw IOError("Could not parse rules: " + sOrig); + continue; + } + if(startsWithAndStrip(s,START_POS_RANDOM_KEY)) { + if(startsWithAndStrip(s,"1")) rules.startPosIsRandom = true; + else if(startsWithAndStrip(s,"0")) rules.startPosIsRandom = false; + else throw IOError("Could not parse rules: " + sOrig); + continue; + } //Unknown rules format else throw IOError("Could not parse rules: " + sOrig); @@ -546,49 +683,309 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { return rules; } -Rules Rules::parseRules(const string& sOrig) { - return parseRulesHelper(sOrig,true); +Rules Rules::parseRules(const string& sOrig, const bool isDots) { + return parseRulesHelper(sOrig,true,isDots); } -Rules Rules::parseRulesWithoutKomi(const string& sOrig, float komi) { - Rules rules = parseRulesHelper(sOrig,false); + +Rules Rules::parseRulesWithoutKomi(const string& sOrig, const float komi, const bool isDots) { + Rules rules = parseRulesHelper(sOrig,false,isDots); rules.komi = komi; return rules; } -bool Rules::tryParseRules(const string& sOrig, Rules& buf) { +bool Rules::tryParseRules(const string& sOrig, Rules& buf, bool isDots) { Rules rules; - try { rules = parseRulesHelper(sOrig,true); } + try { rules = parseRulesHelper(sOrig,true,isDots); } catch(const StringError&) { return false; } buf = rules; return true; } -bool Rules::tryParseRulesWithoutKomi(const string& sOrig, Rules& buf, float komi) { + +bool Rules::tryParseRulesWithoutKomi(const string& sOrig, Rules& buf, float komi, bool isDots) { Rules rules; - try { rules = parseRulesHelper(sOrig,false); } + try { rules = parseRulesHelper(sOrig,false,isDots); } catch(const StringError&) { return false; } rules.komi = komi; buf = rules; return true; } -string Rules::toStringNoKomiMaybeNice() const { - if(equalsIgnoringKomi(parseRulesHelper("TrompTaylor",false))) +string Rules::toStringNoSgfDefinedPropertiesMaybeNice() const { + if(equalsIgnoringSgfDefinedProps(parseRulesHelper(DOTS_KEY, false, isDots))) + return "Dots"; + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("TrompTaylor",false, isDots))) return "TrompTaylor"; - if(equalsIgnoringKomi(parseRulesHelper("Japanese",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Japanese",false, isDots))) return "Japanese"; - if(equalsIgnoringKomi(parseRulesHelper("Chinese",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Chinese",false, isDots))) return "Chinese"; - if(equalsIgnoringKomi(parseRulesHelper("Chinese-OGS",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Chinese-OGS",false, isDots))) return "Chinese-OGS"; - if(equalsIgnoringKomi(parseRulesHelper("AGA",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("AGA",false, isDots))) return "AGA"; - if(equalsIgnoringKomi(parseRulesHelper("StoneScoring",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("StoneScoring",false, isDots))) return "StoneScoring"; - if(equalsIgnoringKomi(parseRulesHelper("NewZealand",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("NewZealand",false, isDots))) return "NewZealand"; - return toStringNoKomi(); + return toString(false); +} + +static double nextRandomOffset(Rand& rand) { + return rand.nextDouble(4, 7); +} + +static int nextRandomOffsetX(Rand& rand, int x_size) { + return static_cast(round(nextRandomOffset(rand) / 39.0 * x_size)); +} + +static int nextRandomOffsetY(Rand& rand, int y_size) { + return static_cast(round(nextRandomOffset(rand) / 32.0 * y_size)); +} + +std::vector Rules::generateStartPos(const int startPos, Rand* rand, const int x_size, const int y_size) { + std::vector moves; + switch (startPos) { + case START_POS_EMPTY: + break; + case START_POS_CROSS: + if (x_size >= 2 && y_size >= 2) { + // Obey notago implementation for odd height + addCross((x_size + 1) / 2 - 1, y_size / 2 - 1, x_size, false, moves); + } + break; + case START_POS_CROSS_2: + if (x_size >= 4 && y_size >= 2) { + const int middleX = (x_size + 1) / 2; + const int middleY = y_size / 2 - 1; // Obey notago implementation for odd height + addCross(middleX - 2, middleY, x_size, false, moves); + addCross(middleX, middleY, x_size, true, moves); + } + break; + case START_POS_CROSS_4: + if (x_size >= 4 && y_size >= 4) { + int offsetX1; + int offsetY1; + int offsetX2; + int offsetY2; + int offsetX3; + int offsetY3; + int offsetX4; + int offsetY4; + + if (rand == nullptr) { + if (x_size == 39 && y_size == 32) { + offsetX1 = 11; + offsetY1 = 10; + } else { + offsetX1 = (x_size - 3) / 3; + offsetY1 = (y_size - 3) / 3; + } + // Consider the index and size of the cross + const int sideOffsetX = x_size - 1 - (offsetX1 + 1); + const int sideOffsetY = y_size - 1 - (offsetY1 + 1); + + offsetX2 = sideOffsetX; + offsetY2 = offsetY1; + offsetX3 = sideOffsetX; + offsetY3 = sideOffsetY; + offsetX4 = offsetX1; + offsetY4 = sideOffsetY; + } else { + const int middleX = x_size / 2 - 1; + const int middleY = y_size / 2 - 1; + + offsetX1 = middleX - nextRandomOffsetX(*rand, x_size); + offsetY1 = middleY - nextRandomOffsetY(*rand, y_size); + offsetX2 = middleX + nextRandomOffsetX(*rand, x_size); + offsetY2 = middleY - nextRandomOffsetY(*rand, y_size); + offsetX3 = middleX + nextRandomOffsetX(*rand, x_size); + offsetY3 = middleY + nextRandomOffsetY(*rand, y_size); + offsetX4 = middleX - nextRandomOffsetX(*rand, x_size); + offsetY4 = middleY + nextRandomOffsetY(*rand, y_size); + } + + addCross(offsetX1, offsetY1, x_size, false, moves); + addCross(offsetX2, offsetY2, x_size, false, moves); + addCross(offsetX3, offsetY3, x_size, false, moves); + addCross(offsetX4, offsetY4, x_size, false, moves); + } + break; + default: + throw std::range_error("Unsupported " + START_POS_KEY + ": " + to_string(startPos)); + } + + return moves; +} + +void Rules::addCross(const int x, const int y, const int x_size, const bool rotate90, std::vector& moves) { + Player pla; + Player opp; + + // The end move should always be P_WHITE to keep a consistent moves order + if (!rotate90) { + pla = P_BLACK; + opp = P_WHITE; + } else { + pla = P_WHITE; + opp = P_BLACK; + } + + const auto tailMove = Move(Location::getLoc(x, y, x_size), pla); + + if (!rotate90) { + moves.push_back(tailMove); + } + + moves.emplace_back(Location::getLoc(x + 1, y, x_size), opp); + moves.emplace_back(Location::getLoc(x + 1, y + 1, x_size), pla); + moves.emplace_back(Location::getLoc(x, y + 1, x_size), opp); + + if (rotate90) { + moves.push_back(tailMove); + } } +int Rules::recognizeStartPos( + const vector& placementMoves, + const int x_size, + const int y_size, + vector& startPosMoves, + bool& randomized, + vector* remainingMoves) { + + startPosMoves = vector(); + randomized = false; // Empty or unknown start pos is static by default + if (remainingMoves != nullptr) { + *remainingMoves = placementMoves; + } + int resultStartPos = START_POS_EMPTY; + + if(placementMoves.empty()) return resultStartPos; + + const int stride = x_size + 1; + auto placement = vector(stride * (y_size + 2), C_EMPTY); + + for (const auto move : placementMoves) { + placement[move.loc] = move.pla; + } + + for (const auto move : placementMoves) { + const int x = Location::getX(move.loc, x_size); + const int y = Location::getY(move.loc, x_size); + + const int xy = Location::getLoc(x, y, x_size); + const Player pla = placement[xy]; + if (pla == C_EMPTY) continue; + + const Player opp = getOpp(pla); + if (opp == C_EMPTY) continue; + + if (x + 1 > x_size) continue; + const int xp1y = Location::getLoc(x + 1, y, x_size); + if (placement[xp1y] != opp) continue; + + if (y + 1 > y_size) continue; + const int xp1yp1 = Location::getLoc(x + 1, y + 1, x_size); + if (placement[xp1yp1] != pla) continue; + + const int xyp1 = Location::getLoc(x, y + 1, x_size); + if (placement[xyp1] != opp) continue; + + // Match order of generator (white move is always last) + if (pla == P_BLACK) { + startPosMoves.emplace_back(xy, pla); + } + + startPosMoves.emplace_back(xp1y, opp); + startPosMoves.emplace_back(xp1yp1, pla); + startPosMoves.emplace_back(xyp1, opp); + + if (pla != P_BLACK) { + startPosMoves.emplace_back(xy, pla); + } + + // Clear the placement because the recognized cross is already stored + placement[xy] = C_EMPTY; + placement[xp1y] = C_EMPTY; + placement[xp1yp1] = C_EMPTY; + placement[xyp1] = C_EMPTY; + } + + // Try to match strictly and set up randomized if failed. + // Also, refine start pos and remaining moves. + auto finishRecognition = [&](const int expectedStartPos) -> void { + auto staticStartPosMoves = generateStartPos(expectedStartPos, nullptr, x_size, y_size); + + assert(remainingMoves != nullptr || placementMoves.size() == startPosMoves.size()); + const auto startPosMovesSize = staticStartPosMoves.size(); + assert(startPosMovesSize <= startPosMoves.size()); + + vector refinedStartPosMoves; + + // TODO: generalize (currently it works only for crosses) + for(int startPosInd = 0; startPosInd < startPosMoves.size(); startPosInd += 4) { + if (startPosInd + 3 >= startPosMoves.size()) continue; + for (int staticStartPosInd = 0; staticStartPosInd < staticStartPosMoves.size(); staticStartPosInd += 4) { + if (staticStartPosInd + 3 >= staticStartPosMoves.size()) continue; + + if (!movesEqual(startPosMoves[startPosInd], staticStartPosMoves[staticStartPosInd])) continue; + if (!movesEqual(startPosMoves[startPosInd + 1], staticStartPosMoves[staticStartPosInd + 1])) continue; + if (!movesEqual(startPosMoves[startPosInd + 2], staticStartPosMoves[staticStartPosInd + 2])) continue; + if (!movesEqual(startPosMoves[startPosInd + 3], staticStartPosMoves[staticStartPosInd + 3])) continue; + + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 1]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 2]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 3]); + + staticStartPosMoves.erase(staticStartPosMoves.begin() + staticStartPosInd, staticStartPosMoves.begin() + staticStartPosInd + 4); + break; + } + } + + for (const auto recognizedMove : startPosMoves) { + if (refinedStartPosMoves.size() == startPosMovesSize) { + break; + } + if (!std::any_of(refinedStartPosMoves.begin(), refinedStartPosMoves.end(), [&](const Move& m) { + return movesEqual(m, recognizedMove); + })) { + refinedStartPosMoves.emplace_back(recognizedMove); + } + } + + startPosMoves = refinedStartPosMoves; + + if (remainingMoves != nullptr) { + for (const auto recognizedMove : startPosMoves) { + [[maybe_unused]] bool remainingMoveIsRemoved = false; + for(auto it = remainingMoves->begin(); it != remainingMoves->end(); ++it) { + if (movesEqual(*it, recognizedMove)) { + remainingMoves->erase(it); + remainingMoveIsRemoved = true; + break; + } + } + assert(remainingMoveIsRemoved); + } + } + + randomized = !staticStartPosMoves.empty(); + resultStartPos = expectedStartPos; + }; + + if (const auto recognizedCrossesMovesSize = startPosMoves.size(); recognizedCrossesMovesSize < 4) { + finishRecognition(START_POS_EMPTY); + } else if (recognizedCrossesMovesSize < 8) { + finishRecognition(START_POS_CROSS); + } else if (recognizedCrossesMovesSize < 16) { + finishRecognition(START_POS_CROSS_2); + } else { + finishRecognition(START_POS_CROSS_4); + } + + return resultStartPos; +} const Hash128 Rules::ZOBRIST_KO_RULE_HASH[4] = { Hash128(0x3cc7e0bf846820f6ULL, 0x1fb7fbde5fc6ba4eULL), //Based on sha256 hash of Rules::KO_SIMPLE @@ -618,3 +1015,9 @@ const Hash128 Rules::ZOBRIST_BUTTON_HASH = //Based on sha256 hash of Rules::ZO const Hash128 Rules::ZOBRIST_FRIENDLY_PASS_OK_HASH = //Based on sha256 hash of Rules::ZOBRIST_FRIENDLY_PASS_OK_HASH Hash128(0x0113655998ef0a25ULL, 0x99c9d04ecd964874ULL); +const Hash128 Rules::ZOBRIST_DOTS_GAME_HASH = + Hash128(0xcdbfab9c91da83a9ULL, 0x6c2f198b2742181full); + +const Hash128 Rules::ZOBRIST_DOTS_CAPTURE_EMPTY_BASES_HASH = + Hash128(0x469afde424e960deull, 0x59f6138cebc753afull); + diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 1fffab834..343f2b936 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -1,47 +1,68 @@ #ifndef GAME_RULES_H_ #define GAME_RULES_H_ +#include "common.h" #include "../core/global.h" #include "../core/hash.h" +#include "../core/rand.h" #include "../external/nlohmann_json/json.hpp" struct Rules { + const static Rules DEFAULT_DOTS; + const static Rules DEFAULT_GO; - static const int KO_SIMPLE = 0; - static const int KO_POSITIONAL = 1; - static const int KO_SITUATIONAL = 2; - static const int KO_SPIGHT = 3; - int koRule; + bool isDots; - static const int SCORING_AREA = 0; - static const int SCORING_TERRITORY = 1; - int scoringRule; + static constexpr int START_POS_EMPTY = 0; + static constexpr int START_POS_CROSS = 1; + static constexpr int START_POS_CROSS_2 = 2; + static constexpr int START_POS_CROSS_4 = 3; + int startPos; + + // Enables random shuffling of start pos. Currently, it works only for CROSS_4 + bool startPosIsRandom; + + float komi; + //Min and max acceptable komi in various places involving user input validation + static constexpr float MIN_USER_KOMI = -150.0f; + static constexpr float MAX_USER_KOMI = 150.0f; + bool multiStoneSuicideLegal; // Works as just suicide in Dots Game static const int TAX_NONE = 0; static const int TAX_SEKI = 1; static const int TAX_ALL = 2; int taxRule; - bool multiStoneSuicideLegal; - bool hasButton; - static const int WHB_ZERO = 0; static const int WHB_N = 1; static const int WHB_N_MINUS_ONE = 2; int whiteHandicapBonusRule; + static const int KO_SIMPLE = 0; + static const int KO_POSITIONAL = 1; + static const int KO_SITUATIONAL = 2; + static const int KO_SPIGHT = 3; + int koRule; + + static const int SCORING_AREA = 0; + static const int SCORING_TERRITORY = 1; + int scoringRule; + + bool hasButton; + //Mostly an informational value - doesn't affect the actual implemented rules, but GTP or Analysis may, at a //high level, use this info to adjust passing behavior - whether it's okay to pass without capturing dead stones. //Only relevant for area scoring. bool friendlyPassOk; - float komi; - //Min and max acceptable komi in various places involving user input validation - static constexpr float MIN_USER_KOMI = -150.0f; - static constexpr float MAX_USER_KOMI = 150.0f; + bool dotsCaptureEmptyBases; + bool dotsFreeCapturedDots; // TODO: Implement later Rules(); + // Constructor for Dots + Rules(int newStartPos, bool newStartPosIsRandom, bool newSuicide, bool newDotsCaptureEmptyBases, bool newDotsFreeCapturedDots); + // Constructor for Go Rules( int koRule, int scoringRule, @@ -52,14 +73,18 @@ struct Rules { bool friendlyPassOk, float komi ); + explicit Rules(bool initIsDots); ~Rules(); bool operator==(const Rules& other) const; bool operator!=(const Rules& other) const; - bool equalsIgnoringKomi(const Rules& other) const; - bool gameResultWillBeInteger() const; + [[nodiscard]] bool equalsIgnoringSgfDefinedProps(const Rules& other) const; + [[nodiscard]] bool equals(const Rules& other, bool ignoreSgfDefinedProps) const; + [[nodiscard]] bool gameResultWillBeInteger() const; + static Rules getDefault(bool isDots); + static Rules getDefaultOrTrompTaylorish(bool isDots); static Rules getTrompTaylorish(); static Rules getSimpleTerritory(); @@ -67,28 +92,51 @@ struct Rules { static std::set scoringRuleStrings(); static std::set taxRuleStrings(); static std::set whiteHandicapBonusRuleStrings(); + static int parseStartPos(const std::string& s); static int parseKoRule(const std::string& s); static int parseScoringRule(const std::string& s); static int parseTaxRule(const std::string& s); static int parseWhiteHandicapBonusRule(const std::string& s); + static std::string writeStartPosRule(int startPosRule); static std::string writeKoRule(int koRule); static std::string writeScoringRule(int scoringRule); static std::string writeTaxRule(int taxRule); static std::string writeWhiteHandicapBonusRule(int whiteHandicapBonusRule); static bool komiIsIntOrHalfInt(float komi); - - static Rules parseRules(const std::string& str); - static Rules parseRulesWithoutKomi(const std::string& str, float komi); - static bool tryParseRules(const std::string& str, Rules& buf); - static bool tryParseRulesWithoutKomi(const std::string& str, Rules& buf, float komi); - - static Rules updateRules(const std::string& key, const std::string& value, Rules priorRules); + static std::set startPosStrings(); + int getNumOfStartPosStones() const; + + static Rules parseRules(const std::string& sOrig, bool isDots = false); + static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi, bool isDots = false); + static bool tryParseRules(const std::string& sOrig, Rules& buf, bool isDots); + static bool tryParseRulesWithoutKomi(const std::string& sOrig, Rules& buf, float komi, bool isDots); + + static Rules updateRules(const std::string& key, const std::string& value, const Rules& oldRules); + + static std::vector generateStartPos(int startPos, Rand* rand, int x_size, int y_size); + /** + * @param placementMoves placement moves that we are trying to recognize. + * @param x_size size of field + * @param y_size size of field + * @param startPosMoves moves for a recognized pattern. It's empty if recognition is failed. + * @param randomized if we recognize a start pos, but it doesn't match the strict position, set it up to `true` + * @param remainingMoves represents moves that remain after start pos recognition (@param placementMoves - @param startPosMoves), + * If it's null (default), then it's assumed that all placement moves are used in the recognized start pos. + * @return recognized type of start pos. If the @param placementMoves don't match any known patter, then + * it returns empty pos with @param remainingMoves == @param placementMoves + */ + static int recognizeStartPos( + const std::vector& placementMoves, + int x_size, + int y_size, + std::vector& startPosMoves, + bool& randomized, + std::vector* remainingMoves = nullptr); friend std::ostream& operator<<(std::ostream& out, const Rules& rules); - std::string toString() const; - std::string toStringNoKomi() const; - std::string toStringNoKomiMaybeNice() const; + std::string toString(bool includeSgfDefinedProperties = true) const; + std::string toStringNoSgfDefinedPropertiesMaybeNice() const; std::string toJsonString() const; std::string toJsonStringNoKomi() const; std::string toJsonStringNoKomiMaybeOmitStuff() const; @@ -102,8 +150,45 @@ struct Rules { static const Hash128 ZOBRIST_MULTI_STONE_SUICIDE_HASH; static const Hash128 ZOBRIST_BUTTON_HASH; static const Hash128 ZOBRIST_FRIENDLY_PASS_OK_HASH; + static const Hash128 ZOBRIST_DOTS_GAME_HASH; + static const Hash128 ZOBRIST_DOTS_CAPTURE_EMPTY_BASES_HASH; private: + // General constructor + Rules( + bool newIsDots, + int startPosRule, + bool newStartPosIsRandom, + int kRule, + int sRule, + int tRule, + bool suic, + bool button, + int whbRule, + bool pOk, + float km, + bool newDotsCaptureEmptyBases, + bool newDotsFreeCapturedDots + ); + + static inline std::map startPosIdToName; + static inline std::map startPosNameToId; + + static void initializeIfNeeded() { + if (startPosIdToName.empty()) { + startPosIdToName[START_POS_EMPTY] = "EMPTY"; + startPosIdToName[START_POS_CROSS] = "CROSS"; + startPosIdToName[START_POS_CROSS_2] = "CROSS_2"; + startPosIdToName[START_POS_CROSS_4] = "CROSS_4"; + startPosNameToId["EMPTY"] = START_POS_EMPTY; + startPosNameToId["CROSS"] = START_POS_CROSS; + startPosNameToId["CROSS_2"] = START_POS_CROSS_2; + startPosNameToId["CROSS_4"] = START_POS_CROSS_4; + } + } + + static void addCross(int x, int y, int x_size, bool rotate90, std::vector& moves); + nlohmann::json toJsonHelper(bool omitKomi, bool omitDefaults) const; }; diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 5ae003cec..fa40382e5 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -1557,6 +1557,9 @@ ModelPostProcessParams::~ModelPostProcessParams() ModelDesc::ModelDesc() : modelVersion(-1), + maxLenX(Board::MAX_LEN_X), + maxLenY(Board::MAX_LEN_Y), + isDotsGame(false), numInputChannels(0), numInputGlobalChannels(0), numInputMetaChannels(0), @@ -1564,14 +1567,31 @@ ModelDesc::ModelDesc() numValueChannels(0), numScoreValueChannels(0), numOwnershipChannels(0), - metaEncoderVersion(0), - postProcessParams() -{} + metaEncoderVersion(0) {} -ModelDesc::ModelDesc(istream& in, const string& sha256_, bool binaryFloats) { +ModelDesc::ModelDesc(istream& in, const string& sha256_, const bool binaryFloats) { in >> name; sha256 = sha256_; - in >> modelVersion; + + string str; + in >> str; + + if (const auto pieces = Global::split(str, ','); pieces.size() > 1) { + modelVersion = Global::stringToInt(pieces[0]); + maxLenX = Global::stringToInt(pieces[1]); + maxLenY = Global::stringToInt(pieces[2]); + const auto gamePieces = Global::split(pieces[3], ';'); + if (gamePieces.size() > 1) { + throw StringError("KataGo currently supports only single-game mode. The `" + name + "` is a mixed model for games: " + Global::concat(gamePieces, ", ")); + } + isDotsGame = Global::isEqualCaseInsensitive(gamePieces[0], DOTS_KEY); + } else { + modelVersion = Global::stringToInt(str); + maxLenX = -1; // It's unclear the exact training size of old modules. Typically, it's 19, but not always + maxLenY = -1; + isDotsGame = false; + } + if(in.fail()) throw StringError("Model failed to parse name or version. Is this a valid model file? You probably specified the wrong file."); @@ -1722,10 +1742,13 @@ ModelDesc::ModelDesc(ModelDesc&& other) { *this = std::move(other); } -ModelDesc& ModelDesc::operator=(ModelDesc&& other) { +ModelDesc& ModelDesc::operator=(ModelDesc&& other) noexcept { name = std::move(other.name); sha256 = std::move(other.sha256); modelVersion = other.modelVersion; + maxLenX = other.maxLenX; + maxLenY = other.maxLenY; + isDotsGame = other.isDotsGame; numInputChannels = other.numInputChannels; numInputGlobalChannels = other.numInputGlobalChannels; numInputMetaChannels = other.numInputMetaChannels; diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index b44d00c59..993b39557 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -347,6 +347,9 @@ struct ModelDesc { std::string name; std::string sha256; int modelVersion; + int maxLenX; // Currently unused, but initialize it just in case for the future + int maxLenY; + bool isDotsGame; int numInputChannels; int numInputGlobalChannels; int numInputMetaChannels; @@ -365,13 +368,13 @@ struct ModelDesc { ModelDesc(); ~ModelDesc(); - ModelDesc(std::istream& in, const std::string& sha256, bool binaryFloats); + ModelDesc(std::istream& in, const std::string& sha256_, bool binaryFloats); ModelDesc(ModelDesc&& other); ModelDesc(const ModelDesc&) = delete; ModelDesc& operator=(const ModelDesc&) = delete; - ModelDesc& operator=(ModelDesc&& other); + ModelDesc& operator=(ModelDesc&& other) noexcept ; void iterConvLayers(std::function f) const; int maxConvChannels(int convXSize, int convYSize) const; diff --git a/cpp/neuralnet/modelversion.cpp b/cpp/neuralnet/modelversion.cpp index e8ece86b1..4770da969 100644 --- a/cpp/neuralnet/modelversion.cpp +++ b/cpp/neuralnet/modelversion.cpp @@ -22,60 +22,115 @@ //15 = V7 features, Extra nonlinearity for pass output //16 = V7 features, Q value predictions in the policy head -static void fail(int modelVersion) { - throw StringError("NNModelVersion: Model version not currently implemented or supported: " + Global::intToString(modelVersion)); +static void fail(const int modelVersion, const bool dotsGame) { + throw StringError("NNModelVersion: Model version not currently implemented or supported: " + Global::intToString(modelVersion) + + (dotsGame ? " (Dots game)" : "")); } -static_assert(NNModelVersion::oldestModelVersionImplemented == 3, ""); -static_assert(NNModelVersion::oldestInputsVersionImplemented == 3, ""); -static_assert(NNModelVersion::latestModelVersionImplemented == 16, ""); -static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); +static_assert(NNModelVersion::oldestModelVersionImplemented == 3); +static_assert(NNModelVersion::oldestInputsVersionImplemented == 3); +static_assert(NNModelVersion::latestModelVersionImplemented == 16); +static_assert(NNModelVersion::latestInputsVersionImplemented == 7); -int NNModelVersion::getInputsVersion(int modelVersion) { - if(modelVersion >= 8 && modelVersion <= 16) - return 7; - else if(modelVersion == 7) - return 6; - else if(modelVersion == 6) - return 5; - else if(modelVersion == 5) - return 4; - else if(modelVersion == 3 || modelVersion == 4) - return 3; +int NNModelVersion::getInputsVersion(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return 3; + } + break; + case 5: + if (!dotsGame) { + return 4; + } + break; + case 6: + if (!dotsGame) { + return 5; + } + break; + case 7: + if (!dotsGame) { + return 6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return 7; + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } -int NNModelVersion::getNumSpatialFeatures(int modelVersion) { - if(modelVersion >= 8 && modelVersion <= 16) - return NNInputs::NUM_FEATURES_SPATIAL_V7; - else if(modelVersion == 7) - return NNInputs::NUM_FEATURES_SPATIAL_V6; - else if(modelVersion == 6) - return NNInputs::NUM_FEATURES_SPATIAL_V5; - else if(modelVersion == 5) - return NNInputs::NUM_FEATURES_SPATIAL_V4; - else if(modelVersion == 3 || modelVersion == 4) - return NNInputs::NUM_FEATURES_SPATIAL_V3; +int NNModelVersion::getNumSpatialFeatures(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V3; + } + break; + case 5: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V4; + } + break; + case 6: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V5; + } + break; + case 7: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return NNInputs::NUM_FEATURES_SPATIAL_V7; // Use NUM_FEATURES_SPATIAL_V7_DOTS if it's value is changed + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } -int NNModelVersion::getNumGlobalFeatures(int modelVersion) { - if(modelVersion >= 8 && modelVersion <= 16) - return NNInputs::NUM_FEATURES_GLOBAL_V7; - else if(modelVersion == 7) - return NNInputs::NUM_FEATURES_GLOBAL_V6; - else if(modelVersion == 6) - return NNInputs::NUM_FEATURES_GLOBAL_V5; - else if(modelVersion == 5) - return NNInputs::NUM_FEATURES_GLOBAL_V4; - else if(modelVersion == 3 || modelVersion == 4) - return NNInputs::NUM_FEATURES_GLOBAL_V3; +int NNModelVersion::getNumGlobalFeatures(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V3; + } + break; + case 5: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V4; + } + break; + case 6: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V5; + } + break; + case 7: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return NNInputs::NUM_FEATURES_GLOBAL_V7; // Use NUM_FEATURES_GLOBAL_V7_DOTS if it's value is changed + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } diff --git a/cpp/neuralnet/modelversion.h b/cpp/neuralnet/modelversion.h index 5961b7bd7..3fe754089 100644 --- a/cpp/neuralnet/modelversion.h +++ b/cpp/neuralnet/modelversion.h @@ -12,12 +12,12 @@ namespace NNModelVersion { constexpr int oldestInputsVersionImplemented = 3; // Which V* feature version from NNInputs does a given model version consume? - int getInputsVersion(int modelVersion); + int getInputsVersion(int modelVersion, bool dotsGame); // Convenience functions, feeds forward the number of features and the size of // the row vector that the net takes as input - int getNumSpatialFeatures(int modelVersion); - int getNumGlobalFeatures(int modelVersion); + int getNumSpatialFeatures(int modelVersion, bool dotsGame); + int getNumGlobalFeatures(int modelVersion, bool dotsGame); // SGF metadata encoder input versions int getNumInputMetaChannels(int metaEncoderVersion); diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index 595fe78dc..6240f917c 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -66,7 +66,8 @@ NNEvaluator::NNEvaluator( const vector& gpuIdxByServerThr, const string& rSeed, bool doRandomize, - int defaultSymmetry + int defaultSymmetry, + const bool newDotsGame ) :modelName(mName), modelFileName(mFileName), @@ -106,12 +107,13 @@ NNEvaluator::NNEvaluator( currentDoRandomize(doRandomize), currentDefaultSymmetry(defaultSymmetry), currentBatchSize(maxBatchSz), - queryQueue() + queryQueue(), + dotsGame(newDotsGame) { - if(nnXLen > NNPos::MAX_BOARD_LEN) - throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); - if(nnYLen > NNPos::MAX_BOARD_LEN) - throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); + if(nnXLen > NNPos::MAX_BOARD_LEN_X) + throw StringError("Maximum supported nnEval board size x is " + Global::intToString(NNPos::MAX_BOARD_LEN_X)); + if(nnYLen > NNPos::MAX_BOARD_LEN_Y) + throw StringError("Maximum supported nnEval board size y is " + Global::intToString(NNPos::MAX_BOARD_LEN_Y)); if(maxBatchSize <= 0) throw StringError("maxBatchSize is negative: " + Global::intToString(maxBatchSize)); if(gpuIdxByServerThread.size() != numThreads) @@ -137,7 +139,6 @@ NNEvaluator::NNEvaluator( const ModelDesc& desc = NeuralNet::getModelDesc(loadedModel); internalModelName = desc.name; modelVersion = desc.modelVersion; - inputsVersion = NNModelVersion::getInputsVersion(modelVersion); numInputMetaChannels = desc.numInputMetaChannels; postProcessParams = desc.postProcessParams; computeContext = NeuralNet::createComputeContext( @@ -149,8 +150,8 @@ NNEvaluator::NNEvaluator( else { internalModelName = "random"; modelVersion = NNModelVersion::defaultModelVersion; - inputsVersion = NNModelVersion::getInputsVersion(modelVersion); } + inputsVersion = NNModelVersion::getInputsVersion(modelVersion, newDotsGame); //Reserve a decent amount above the batch size so that allocation is unlikely. queryQueue.reserve(maxBatchSize * 4 * gpuIdxByServerThread.size()); @@ -284,6 +285,9 @@ int NNEvaluator::getNNYLen() const { int NNEvaluator::getModelVersion() const { return modelVersion; } +bool NNEvaluator::getDotsGame() const { + return dotsGame; +} double NNEvaluator::getTrunkSpatialConvDepth() const { return NeuralNet::getModelDesc(loadedModel).getTrunkSpatialConvDepth(); } @@ -566,7 +570,7 @@ void NNEvaluator::serve( } } - NeuralNet::getOutput(gpuHandle, buf.inputBuffers, numRows, resultBufs.data(), outputBuf); + NeuralNet::getOutput(gpuHandle, buf.inputBuffers, numRows, resultBufs.data(), outputBuf, dotsGame); assert(outputBuf.size() == numRows); m_numRowsProcessed.fetch_add(numRows, std::memory_order_relaxed); @@ -776,29 +780,17 @@ void NNEvaluator::evaluate( buf.boardYSizeForServer = board.y_size; if(!debugSkipNeuralNet) { - const int rowSpatialLen = NNModelVersion::getNumSpatialFeatures(modelVersion) * nnXLen * nnYLen; + const int rowSpatialLen = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame) * nnXLen * nnYLen; if(buf.rowSpatialBuf.size() < rowSpatialLen) buf.rowSpatialBuf.resize(rowSpatialLen); - const int rowGlobalLen = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int rowGlobalLen = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); if(buf.rowGlobalBuf.size() < rowGlobalLen) buf.rowGlobalBuf.resize(rowGlobalLen); const int rowMetaLen = numInputMetaChannels; if(buf.rowMetaBuf.size() < rowMetaLen) buf.rowMetaBuf.resize(rowMetaLen); - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - NNInputs::fillRowV3(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 4) - NNInputs::fillRowV4(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 5) - NNInputs::fillRowV5(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 6) - NNInputs::fillRowV6(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 7) - NNInputs::fillRowV7(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else - ASSERT_UNREACHABLE; + NNInputs::fillRowVN(inputsVersion, board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); if(rowMetaLen > 0) { if(sgfMeta == NULL) @@ -868,7 +860,18 @@ void NNEvaluator::evaluate( assert(nextPlayer == history.presumedNextMovePla); for(int i = 0; i= 13 && ySize >= 13) { @@ -882,6 +885,7 @@ void NNEvaluator::evaluate( } } + legalCount = 0; for(int i = 0; i& gpuIdxByServerThread, const std::string& randSeed, bool doRandomize, - int defaultSymmetry + int defaultSymmetry, + bool newDotsGame ); ~NNEvaluator(); @@ -124,6 +125,7 @@ class NNEvaluator { int getNNXLen() const; int getNNYLen() const; int getModelVersion() const; + bool getDotsGame() const; double getTrunkSpatialConvDepth() const; enabled_t getUsingFP16Mode() const; enabled_t getUsingNHWCMode() const; @@ -269,6 +271,8 @@ class NNEvaluator { //Queued up requests ThreadSafeQueue queryQueue; + bool dotsGame; + public: //Helper, for internal use only void serve(NNServerBuf& buf, Rand& rand, int gpuIdxForThisThread, int serverThreadIdx); diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 5fad7ecc3..0313b5b8b 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -8,7 +8,7 @@ int NNPos::xyToPos(int x, int y, int nnXLen) { int NNPos::locToPos(Loc loc, int boardXSize, int nnXLen, int nnYLen) { if(loc == Board::PASS_LOC) return nnXLen * nnYLen; - else if(loc == Board::NULL_LOC) + if(loc == Board::NULL_LOC || loc == Board::RESIGN_LOC) // NN normally shouldn't return a resigning move return nnXLen * (nnYLen + 1); return Location::getY(loc,boardXSize) * nnXLen + Location::getX(loc,boardXSize); } @@ -237,8 +237,40 @@ void NNInputs::fillScoring( bool groupTax, float* scoring ) { - if(!groupTax) { - std::fill(scoring, scoring + Board::MAX_ARR_SIZE, 0.0f); + std::fill_n(scoring, Board::MAX_ARR_SIZE, 0.0f); + + if (board.isDots()) { + for(int y = 0; y sym_start_pos_moves; + + if (!board.start_pos_moves.empty()) { + sym_start_pos_moves.reserve(board.start_pos_moves.size()); + for (const auto start_pos_move : board.start_pos_moves) { + const Loc loc = start_pos_move.loc; + const int x = Location::getX(loc, board.x_size); + const int y = Location::getY(loc, board.x_size); + sym_start_pos_moves.emplace_back(getSymLoc(x, y), start_pos_move.pla); + } + vector startPosMoves; + bool randomized; + symRules.startPos = Rules::recognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, startPosMoves, randomized); + symRules.startPosIsRandom = randomized; + } + + Board symBoard(sym_x_size, sym_y_size, symRules); + symBoard.setStonesFailIfNoLibs(sym_start_pos_moves, true); + Loc symKoLoc = Board::NULL_LOC; for(int y = 0; y 0 + Reserved_13, // 13: encore phase > 1 + WinByGrounding_14, // 14: does a pass end the current phase given the ruleset and history? + PlayoutDoublingAdvantageIsEnabled_15, // 15: playoutDoublingAdvantage is enabled (handicap play) + PlayoutDoublingAdvantage_16, // 16: playoutDoublingAdvantage + CaptureEmpty_17, // 17: button + FieldSizeKomiParity_18, // 18: board size komi parity + + COUNT, // = 19 + }; + const int NUM_FEATURES_SPATIAL_V3 = 22; const int NUM_FEATURES_GLOBAL_V3 = 14; @@ -71,11 +129,26 @@ namespace NNInputs { const int NUM_FEATURES_SPATIAL_V7 = 22; const int NUM_FEATURES_GLOBAL_V7 = 19; + constexpr int NUM_FEATURES_SPATIAL_V7_DOTS = static_cast(DotsSpatialFeature::COUNT); + constexpr int NUM_FEATURES_GLOBAL_V7_DOTS = static_cast(DotsGlobalFeature::COUNT); + + static_assert(NUM_FEATURES_SPATIAL_V7_DOTS == NUM_FEATURES_SPATIAL_V7); // Might be changed later if needed + static_assert(NUM_FEATURES_GLOBAL_V7_DOTS == NUM_FEATURES_GLOBAL_V7); + Hash128 getHash( const Board& board, const BoardHistory& boardHistory, Player nextPlayer, const MiscNNInputParams& nnInputParams ); + int getNumberOfSpatialFeatures(int version, bool isDots); + int getNumberOfGlobalFeatures(int version, bool isDots); + + void fillRowVN( + int version, + const Board& board, const BoardHistory& hist, Player nextPlayer, + const MiscNNInputParams& nnInputParams, + int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal + ); void fillRowV3( const Board& board, const BoardHistory& boardHistory, Player nextPlayer, const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal @@ -96,6 +169,11 @@ namespace NNInputs { const Board& board, const BoardHistory& boardHistory, Player nextPlayer, const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal ); + void fillRowV7Dots( + const Board& board, const BoardHistory& hist, Player nextPlayer, + const MiscNNInputParams& nnInputParams, + int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal + ); //If groupTax is specified, for each color region of area, reduce weight on empty spaces equally to reduce the total sum by 2. //(but should handle seki correctly) @@ -162,6 +240,15 @@ struct NNOutput { namespace SymmetryHelpers { //A symmetry is 3 bits flipY(bit 0), flipX(bit 1), transpose(bit 2). They are applied in that order. //The first four symmetries only reflect, and do not transpose X and Y. + constexpr int SYMMETRY_NONE = 0; + constexpr int SYMMETRY_FLIP_Y = 1; + constexpr int SYMMETRY_FLIP_X = 2; + constexpr int SYMMETRY_FLIP_Y_X = 3; // Rotate 180 + constexpr int SYMMETRY_TRANSPOSE = 4; // Rotate 90 CW + Flip X; Rotate 90 CCW + FlipY + constexpr int SYMMETRY_TRANSPOSE_FLIP_X = 5; // Rotate 90 CW + constexpr int SYMMETRY_TRANSPOSE_FLIP_Y = 6; // Rotate 90 CCW + constexpr int SYMMETRY_TRANSPOSE_FLIP_Y_X = 7; // Rotate 90 CW + Flip Y; Rotate 90 CCW + FlipX + constexpr int NUM_SYMMETRIES = 8; constexpr int NUM_SYMMETRIES_WITHOUT_TRANSPOSE = 4; @@ -209,6 +296,8 @@ namespace SymmetryHelpers { void getSymmetryDifferences( const Board& board, const Board& other, double maxDifferenceToReport, double symmetryDifferences[NUM_SYMMETRIES] ); + + std::string symmetryToString(int symmetry); } //Utility functions for computing the "scoreValue", the unscaled utility of various numbers of points, prior to multiplication by diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp new file mode 100644 index 000000000..7b1d56b27 --- /dev/null +++ b/cpp/neuralnet/nninputsdots.cpp @@ -0,0 +1,129 @@ +#include "../neuralnet/nninputs.h" + +using namespace std; + +void NNInputs::fillRowV7Dots( + const Board& board, const BoardHistory& hist, Player nextPlayer, + const MiscNNInputParams& nnInputParams, + int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal +) { + assert(nnXLen <= NNPos::MAX_BOARD_LEN_X); + assert(nnYLen <= NNPos::MAX_BOARD_LEN_Y); + assert(board.x_size <= nnXLen); + assert(board.y_size <= nnYLen); + std::fill_n(rowBin, NUM_FEATURES_SPATIAL_V7_DOTS * nnXLen * nnYLen,false); + std::fill_n(rowGlobal, NUM_FEATURES_GLOBAL_V7_DOTS, 0.0f); + + const Player pla = nextPlayer; + const Player opp = getOpp(pla); + const int xSize = board.x_size; + const int ySize = board.y_size; + + int featureStride; + int posStride; + if(useNHWC) { + featureStride = 1; + posStride = NUM_FEATURES_SPATIAL_V7_DOTS; + } + else { + featureStride = nnXLen * nnYLen; + posStride = 1; + } + + const Rules& rules = hist.rules; + + vector captures; + vector bases; + board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + [[maybe_unused]] int deadDotsCount = 0; + + auto setSpatial = [&](const int pos, const DotsSpatialFeature spatialFeature) { + setRowBin(rowBin, pos, static_cast(spatialFeature), 1.0f, posStride, featureStride); + }; + + auto setGlobal = [&](const DotsGlobalFeature globalFeature, const float value = 1.0f) { + rowGlobal[static_cast(globalFeature)] = value; + }; + + for(int y = 0; y(xSize * ySize); + //Bound komi just in case + if(selfKomi > bArea+NNPos::KOMI_CLIP_RADIUS) + selfKomi = bArea+NNPos::KOMI_CLIP_RADIUS; + if(selfKomi < -bArea-NNPos::KOMI_CLIP_RADIUS) + selfKomi = -bArea-NNPos::KOMI_CLIP_RADIUS; + setGlobal(DotsGlobalFeature::Komi_5, selfKomi / NNPos::KOMI_CLIP_RADIUS); + + if (rules.multiStoneSuicideLegal) { + setGlobal(DotsGlobalFeature::Suicide_8); + } + + if (rules.dotsCaptureEmptyBases) { + setGlobal(DotsGlobalFeature::CaptureEmpty_17); + } + + if (hist.winOrEffectiveDrawByGrounding(board, pla)) { + // Train to better understand grounding + setGlobal(DotsGlobalFeature::WinByGrounding_14); + } + + setGlobal(DotsGlobalFeature::FieldSizeKomiParity_18, 0.0f); // TODO: implement later +} \ No newline at end of file diff --git a/cpp/neuralnet/nninterface.h b/cpp/neuralnet/nninterface.h index e5932a6d5..e44de9885 100644 --- a/cpp/neuralnet/nninterface.h +++ b/cpp/neuralnet/nninterface.h @@ -112,7 +112,8 @@ namespace NeuralNet { InputBuffers* buffers, int numBatchEltsFilled, NNResultBuf** inputBufs, - std::vector& outputs + const std::vector& outputs, + bool dotsGame ); diff --git a/cpp/neuralnet/openclbackend.cpp b/cpp/neuralnet/openclbackend.cpp index c4bcf2728..aee14e030 100644 --- a/cpp/neuralnet/openclbackend.cpp +++ b/cpp/neuralnet/openclbackend.cpp @@ -2482,13 +2482,13 @@ struct Model { nnXLen = nnX; nnYLen = nnY; - if(nnXLen > NNPos::MAX_BOARD_LEN) - throw StringError(Global::strprintf("nnXLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", - nnXLen, NNPos::MAX_BOARD_LEN + if(nnXLen > NNPos::MAX_BOARD_LEN_X) + throw StringError(Global::strprintf("nnXLen (%d) is greater than NNPos::MAX_BOARD_LEN_X (%d)", + nnXLen, NNPos::MAX_BOARD_LEN_X )); - if(nnYLen > NNPos::MAX_BOARD_LEN) - throw StringError(Global::strprintf("nnYLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", - nnYLen, NNPos::MAX_BOARD_LEN + if(nnYLen > NNPos::MAX_BOARD_LEN_Y) + throw StringError(Global::strprintf("nnYLen (%d) is greater than NNPos::MAX_BOARD_LEN_Y (%d)", + nnYLen, NNPos::MAX_BOARD_LEN_Y )); numInputChannels = desc->numInputChannels; @@ -2499,12 +2499,12 @@ struct Model { numScoreValueChannels = desc->numScoreValueChannels; numOwnershipChannels = desc->numOwnershipChannels; - int numFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + int numFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, desc->isDotsGame); if(numInputChannels != numFeatures) throw StringError(Global::strprintf("Neural net numInputChannels (%d) was not the expected number based on version (%d)", numInputChannels, numFeatures )); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, desc->isDotsGame); if(numInputGlobalChannels != numGlobalFeatures) throw StringError(Global::strprintf("Neural net numInputGlobalChannels (%d) was not the expected number based on version (%d)", numInputGlobalChannels, numGlobalFeatures @@ -2890,8 +2890,8 @@ struct InputBuffers { singleScoreValueResultElts = (size_t)m.numScoreValueChannels; singleOwnershipResultElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; - assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); - assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion, m.isDotsGame) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion, m.isDotsGame) == m.numInputGlobalChannels); if(m.numInputMetaChannels > 0) { assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); } @@ -2952,14 +2952,13 @@ void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) { delete inputBuffers; } - void NeuralNet::getOutput( ComputeHandle* gpuHandle, InputBuffers* inputBuffers, - int numBatchEltsFilled, + const int numBatchEltsFilled, NNResultBuf** inputBufs, - vector& outputs -) { + const vector& outputs, + const bool dotsGame) { assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); assert(numBatchEltsFilled > 0); const int batchSize = numBatchEltsFilled; @@ -2967,8 +2966,8 @@ void NeuralNet::getOutput( const int nnYLen = gpuHandle->nnYLen; const int modelVersion = gpuHandle->model->modelVersion; - const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); const int numMetaFeatures = inputBuffers->singleInputMetaElts; assert(numSpatialFeatures == gpuHandle->model->numInputChannels); assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); diff --git a/cpp/neuralnet/opencltuner.cpp b/cpp/neuralnet/opencltuner.cpp index c28b755b9..e629faadd 100644 --- a/cpp/neuralnet/opencltuner.cpp +++ b/cpp/neuralnet/opencltuner.cpp @@ -3244,8 +3244,8 @@ OpenCLTuneParams OpenCLTuner::loadOrAutoTune( //If not re-tuning per board size, then check if the tune config for the full size is there //And set the nnXLen and nnYLen we'll use for tuning to the full size if(!openCLReTunePerBoardSize) { - nnXLen = NNPos::MAX_BOARD_LEN; - nnYLen = NNPos::MAX_BOARD_LEN; + nnXLen = NNPos::MAX_BOARD_LEN_X; + nnYLen = NNPos::MAX_BOARD_LEN_Y; openCLTunerFile = dir + "/" + OpenCLTuner::defaultFileName(gpuName, nnXLen, nnYLen, modelInfo); try { OpenCLTuneParams loadedParams = loadFromTunerFile(openCLTunerFile,logger); @@ -3464,8 +3464,8 @@ void OpenCLTuner::autoTuneEverything( } for(ModelInfoForTuning modelInfo : modelInfos) { - int nnXLen = NNPos::MAX_BOARD_LEN; - int nnYLen = NNPos::MAX_BOARD_LEN; + int nnXLen = NNPos::MAX_BOARD_LEN_X; + int nnYLen = NNPos::MAX_BOARD_LEN_Y; string dir = OpenCLTuner::defaultDirectory(true,homeDataDirOverride); string openCLTunerFile = dir + "/" + OpenCLTuner::defaultFileName(gpuName, nnXLen, nnYLen, modelInfo); try { diff --git a/cpp/neuralnet/opencltuner.h b/cpp/neuralnet/opencltuner.h index 8f6071ff1..994b66871 100644 --- a/cpp/neuralnet/opencltuner.h +++ b/cpp/neuralnet/opencltuner.h @@ -180,8 +180,8 @@ struct OpenCLTuneParams { }; namespace OpenCLTuner { - constexpr int DEFAULT_X_SIZE = NNPos::MAX_BOARD_LEN; - constexpr int DEFAULT_Y_SIZE = NNPos::MAX_BOARD_LEN; + constexpr int DEFAULT_X_SIZE = NNPos::MAX_BOARD_LEN_X; + constexpr int DEFAULT_Y_SIZE = NNPos::MAX_BOARD_LEN_Y; constexpr int DEFAULT_BATCH_SIZE = 4; constexpr int DEFAULT_WINOGRAD_3X3_TILE_SIZE = 4; diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index c6df9f251..b8472a818 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -115,6 +115,7 @@ struct TRTModel { vector> extraWeights; int modelVersion; + bool dotsGame; uint8_t tuneHash[32]; IOptimizationProfile* profile; unique_ptr network; @@ -196,6 +197,7 @@ struct ModelParser { } model->modelVersion = modelDesc->modelVersion; + model->dotsGame = modelDesc->isDotsGame; network->setName(modelDesc->name.c_str()); initInputs(); @@ -247,13 +249,13 @@ struct ModelParser { int numInputGlobalChannels = modelDesc->numInputGlobalChannels; int numInputMetaChannels = modelDesc->numInputMetaChannels; - int numFeatures = NNModelVersion::getNumSpatialFeatures(model->modelVersion); + int numFeatures = NNModelVersion::getNumSpatialFeatures(model->modelVersion, model->dotsGame); if(numInputChannels != numFeatures) throw StringError(Global::strprintf( "Neural net numInputChannels (%d) was not the expected number based on version (%d)", numInputChannels, numFeatures)); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(model->modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(model->modelVersion, model->dotsGame); if(numInputGlobalChannels != numGlobalFeatures) throw StringError(Global::strprintf( "Neural net numInputGlobalChannels (%d) was not the expected number based on version (%d)", @@ -1586,8 +1588,8 @@ struct InputBuffers { singleOwnershipResultElts = m.numOwnershipChannels * nnXLen * nnYLen; singleOwnershipResultBytes = singleOwnershipResultElts * sizeof(float); - assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); - assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion, m.isDotsGame) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion, m.isDotsGame) == m.numInputGlobalChannels); if(m.numInputMetaChannels > 0) { assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); } @@ -1631,7 +1633,8 @@ void NeuralNet::getOutput( InputBuffers* inputBuffers, int numBatchEltsFilled, NNResultBuf** inputBufs, - vector& outputs) { + const vector& outputs, + const bool dotsGame) { assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); assert(numBatchEltsFilled > 0); @@ -1640,8 +1643,8 @@ void NeuralNet::getOutput( const int nnYLen = gpuHandle->ctx->nnYLen; const int modelVersion = gpuHandle->modelVersion; - const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); const int numMetaFeatures = inputBuffers->singleInputMetaElts; assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts); diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 631b3cf3d..db21f1ab0 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -1,13 +1,13 @@ #include "../program/play.h" -#include "../core/global.h" #include "../core/fileutils.h" +#include "../core/global.h" #include "../core/timer.h" +#include "../dataio/files.h" #include "../program/playutils.h" #include "../program/setup.h" #include "../search/asyncbot.h" #include "../search/searchnode.h" -#include "../dataio/files.h" #include "../core/test.h" @@ -15,8 +15,8 @@ using namespace std; //---------------------------------------------------------------------------------------------------------- -InitialPosition::InitialPosition() - :board(),hist(),pla(C_EMPTY) +InitialPosition::InitialPosition(const Rules& rules) + :board(rules),hist(rules),pla(C_EMPTY) {} InitialPosition::InitialPosition(const Board& b, const BoardHistory& h, Player p, bool plainFork, bool sekiFork, bool hintFork, double tw) :board(b),hist(h),pla(p),isPlainFork(plainFork),isSekiFork(sekiFork),isHintFork(hintFork),trainingWeight(tw) @@ -93,32 +93,38 @@ GameInitializer::GameInitializer(ConfigParser& cfg, Logger& logger, const string } void GameInitializer::initShared(ConfigParser& cfg, Logger& logger) { + dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); - allowedKoRuleStrs = cfg.getStrings("koRules", Rules::koRuleStrings()); - allowedScoringRuleStrs = cfg.getStrings("scoringRules", Rules::scoringRuleStrings()); - allowedTaxRuleStrs = cfg.getStrings("taxRules", Rules::taxRuleStrings()); - allowedMultiStoneSuicideLegals = cfg.getBools("multiStoneSuicideLegals"); - allowedButtons = cfg.getBools("hasButtons"); - - for(size_t i = 0; i < allowedKoRuleStrs.size(); i++) - allowedKoRules.push_back(Rules::parseKoRule(allowedKoRuleStrs[i])); - for(size_t i = 0; i < allowedScoringRuleStrs.size(); i++) - allowedScoringRules.push_back(Rules::parseScoringRule(allowedScoringRuleStrs[i])); - for(size_t i = 0; i < allowedTaxRuleStrs.size(); i++) - allowedTaxRules.push_back(Rules::parseTaxRule(allowedTaxRuleStrs[i])); - - if(allowedKoRules.size() <= 0) - throw IOError("koRules must have at least one value in " + cfg.getFileName()); - if(allowedScoringRules.size() <= 0) - throw IOError("scoringRules must have at least one value in " + cfg.getFileName()); - if(allowedTaxRules.size() <= 0) - throw IOError("taxRules must have at least one value in " + cfg.getFileName()); - if(allowedMultiStoneSuicideLegals.size() <= 0) - throw IOError("multiStoneSuicideLegals must have at least one value in " + cfg.getFileName()); - if(allowedButtons.size() <= 0) - throw IOError("hasButtons must have at least one value in " + cfg.getFileName()); + if (dotsGame) { + if (cfg.containsAny({"koRules", "scoringRules", "taxRules", "hasButtons"})) { + throw IOError("koRules, scoringRules, taxRules, hasButtons are not applicable for Dots game. Please remove them from " + cfg.getFileName()); + } + + allowedCaptureEmtpyBasesRules = cfg.getBools(DOTS_CAPTURE_EMPTY_BASES_KEY); + if (allowedCaptureEmtpyBasesRules.empty()) + throw IOError(DOTS_CAPTURE_EMPTY_BASES_KEY + " must have at least one value in " + cfg.getFileName()); + } else { + allowedKoRuleStrs = cfg.getStrings("koRules", Rules::koRuleStrings()); + allowedScoringRuleStrs = cfg.getStrings("scoringRules", Rules::scoringRuleStrings()); + allowedTaxRuleStrs = cfg.getStrings("taxRules", Rules::taxRuleStrings()); + allowedButtons = cfg.getBools("hasButtons"); + + for(size_t i = 0; i < allowedKoRuleStrs.size(); i++) + allowedKoRules.push_back(Rules::parseKoRule(allowedKoRuleStrs[i])); + for(size_t i = 0; i < allowedScoringRuleStrs.size(); i++) + allowedScoringRules.push_back(Rules::parseScoringRule(allowedScoringRuleStrs[i])); + for(size_t i = 0; i < allowedTaxRuleStrs.size(); i++) + allowedTaxRules.push_back(Rules::parseTaxRule(allowedTaxRuleStrs[i])); + + if(allowedKoRules.size() <= 0) + throw IOError("koRules must have at least one value in " + cfg.getFileName()); + if(allowedScoringRules.size() <= 0) + throw IOError("scoringRules must have at least one value in " + cfg.getFileName()); + if(allowedTaxRules.size() <= 0) + throw IOError("taxRules must have at least one value in " + cfg.getFileName()); + if(allowedButtons.size() <= 0) + throw IOError("hasButtons must have at least one value in " + cfg.getFileName()); - { bool hasAreaScoring = false; for(int i = 0; i or komiAuto=True in config"); - komiMean = cfg.contains("komiMean") ? cfg.getFloat("komiMean",Rules::MIN_USER_KOMI,Rules::MAX_USER_KOMI) : 7.5f; + komiMean = cfg.contains("komiMean") ? cfg.getFloat("komiMean",Rules::MIN_USER_KOMI,Rules::MAX_USER_KOMI) : + dotsGame ? 0.0f : 7.5f; komiStdev = cfg.contains("komiStdev") ? cfg.getFloat("komiStdev",0.0f,60.0f) : 0.0f; handicapProb = cfg.contains("handicapProb") ? cfg.getDouble("handicapProb",0.0,1.0) : 0.0; handicapCompensateKomiProb = cfg.contains("handicapCompensateKomiProb") ? cfg.getDouble("handicapCompensateKomiProb",0.0,1.0) : 0.0; @@ -243,6 +268,10 @@ void GameInitializer::initShared(ConfigParser& cfg, Logger& logger) { startPosesProb = 0.0; if(cfg.contains("startPosesFromSgfDir")) { + if (!allowedStartPosRules.empty()) { + throw StringError("startPosesFromSgfDir is not compatible with " + START_POSES_KEY + ". Please specify only one key."); + } + startPoses.clear(); startPosCumProbs.clear(); startPosesProb = cfg.getDouble("startPosesProb",0.0,1.0); @@ -459,6 +488,9 @@ int GameInitializer::getMaxBoardXSize() const { int GameInitializer::getMaxBoardYSize() const { return maxBoardYSize; } +bool GameInitializer::isDotsGame() const { + return dotsGame; +} Rules GameInitializer::createRules() { lock_guard lock(createGameMutex); @@ -466,16 +498,25 @@ Rules GameInitializer::createRules() { } Rules GameInitializer::createRulesUnsynchronized() { - Rules rules; - rules.koRule = allowedKoRules[rand.nextUInt((uint32_t)allowedKoRules.size())]; - rules.scoringRule = allowedScoringRules[rand.nextUInt((uint32_t)allowedScoringRules.size())]; - rules.taxRule = allowedTaxRules[rand.nextUInt((uint32_t)allowedTaxRules.size())]; - rules.multiStoneSuicideLegal = allowedMultiStoneSuicideLegals[rand.nextUInt((uint32_t)allowedMultiStoneSuicideLegals.size())]; + auto rules = Rules(dotsGame); + rules.multiStoneSuicideLegal = allowedMultiStoneSuicideLegals[rand.nextUInt(static_cast(allowedMultiStoneSuicideLegals.size()))]; + if (!allowedStartPosRules.empty()) { + rules.startPos = allowedStartPosRules[rand.nextUInt(static_cast(allowedStartPosRules.size()))]; + } + if (!allowedStartPosRandomRules.empty()) { + rules.startPosIsRandom = allowedStartPosRandomRules[rand.nextUInt(static_cast(allowedStartPosRandomRules.size()))]; + } - if(rules.scoringRule == Rules::SCORING_AREA) - rules.hasButton = allowedButtons[rand.nextUInt((uint32_t)allowedButtons.size())]; - else - rules.hasButton = false; + if (dotsGame) { + rules.dotsCaptureEmptyBases = allowedCaptureEmtpyBasesRules[rand.nextUInt(static_cast(allowedCaptureEmtpyBasesRules.size()))]; + } else { + rules.koRule = allowedKoRules[rand.nextUInt((uint32_t)allowedKoRules.size())]; + rules.scoringRule = allowedScoringRules[rand.nextUInt((uint32_t)allowedScoringRules.size())]; + rules.taxRule = allowedTaxRules[rand.nextUInt((uint32_t)allowedTaxRules.size())]; + rules.hasButton = rules.scoringRule == Rules::SCORING_AREA + ? allowedButtons[rand.nextUInt((uint32_t)allowedButtons.size())] + : false; + } return rules; } @@ -588,9 +629,12 @@ void GameInitializer::createGameSharedUnsynchronized( else { int xSize = allowedBSizes[bSizeIdx].first; int ySize = allowedBSizes[bSizeIdx].second; - board = Board(xSize,ySize); + board = Board(xSize,ySize,rules); + board.setStartPos(rand); + pla = P_BLACK; hist.clear(board,pla,rules,0); + hist.setInitialTurnNumber(rules.getNumOfStartPosStones()); extraBlackAndKomi = PlayUtils::chooseExtraBlackAndKomi( komiMean, komiStdev, komiAllowIntegerProb, @@ -751,12 +795,12 @@ pair MatchPairer::getMatchupPairUnsynchronized() { //---------------------------------------------------------------------------------------------------------- -static void failIllegalMove(Search* bot, Logger& logger, Board board, Loc loc) { +static void failIllegalMove(const Search* bot, Logger& logger, const Board& board, Loc loc) { ostringstream sout; sout << "Bot returned null location or illegal move!?!" << "\n"; sout << board << "\n"; sout << bot->getRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(bot->getRootPla()) << "\n"; + sout << "Pla: " << PlayerIO::playerToString(bot->getRootPla(),board.isDots()) << "\n"; sout << "Loc: " << Location::toString(loc,bot->getRootBoard()) << "\n"; logger.write(sout.str()); bot->getRootBoard().checkConsistency(); @@ -787,7 +831,7 @@ static void logSearch(Search* bot, Logger& logger, Loc loc, OtherGameProperties static Loc chooseRandomForkingMove(const NNOutput* nnOutput, const Board& board, const BoardHistory& hist, Player pla, Rand& gameRand, Loc banMove) { double r = gameRand.nextDouble(); - bool allowPass = true; + bool allowPass = !hist.rules.isDots || hist.winOrEffectiveDrawByGrounding(board, pla); //70% of the time, do a random temperature 1 policy move if(r < 0.70) return PlayUtils::chooseRandomPolicyMove(nnOutput, board, hist, pla, gameRand, 1.0, allowPass, banMove); @@ -1307,7 +1351,7 @@ FinishedGameData* Play::runGame( std::function checkForNewNNEval, std::function&, const std::vector&, const std::vector&, const Search*)> onEachMove ) { - FinishedGameData* gameData = new FinishedGameData(); + FinishedGameData* gameData = new FinishedGameData(startHist.rules); Board board(startBoard); BoardHistory hist(startHist); @@ -1316,7 +1360,7 @@ FinishedGameData* Play::runGame( assert(!(playSettings.forSelfPlay && !clearBotBeforeSearch)); if(extraBlackAndKomi.makeGameFairForEmptyBoard) { - Board b(startBoard.x_size,startBoard.y_size); + Board b(startBoard.x_size, startBoard.y_size, startBoard.rules); Player makeFairPla = P_BLACK; if(playSettings.flipKomiProbWhenNoCompensate != 0.0 && gameRand.nextBool(playSettings.flipKomiProbWhenNoCompensate)) makeFairPla = P_WHITE; @@ -1506,8 +1550,10 @@ FinishedGameData* Play::runGame( for(int i = 0; iwaitUntilFalse(); if(shouldStop != nullptr && shouldStop()) @@ -1723,9 +1769,11 @@ FinishedGameData* Play::runGame( assert(gameData->finalFullArea == NULL); assert(gameData->finalOwnership == NULL); assert(gameData->finalSekiAreas == NULL); - gameData->finalFullArea = new Color[Board::MAX_ARR_SIZE]; gameData->finalOwnership = new Color[Board::MAX_ARR_SIZE]; - gameData->finalSekiAreas = new bool[Board::MAX_ARR_SIZE]; + if (!hist.rules.isDots) { + gameData->finalFullArea = new Color[Board::MAX_ARR_SIZE]; + gameData->finalSekiAreas = new bool[Board::MAX_ARR_SIZE]; + } if(hist.isGameFinished && hist.isNoResult) { finalValueTargets.win = 0.0f; @@ -1735,9 +1783,11 @@ FinishedGameData* Play::runGame( //Fill with empty so that we use "nobody owns anything" as the training target. //Although in practice actually the training normally weights by having a result or not, so it doesn't matter what we fill. - std::fill(gameData->finalFullArea,gameData->finalFullArea+Board::MAX_ARR_SIZE,C_EMPTY); - std::fill(gameData->finalOwnership,gameData->finalOwnership+Board::MAX_ARR_SIZE,C_EMPTY); - std::fill(gameData->finalSekiAreas,gameData->finalSekiAreas+Board::MAX_ARR_SIZE,false); + std::fill_n(gameData->finalOwnership, Board::MAX_ARR_SIZE,C_EMPTY); + if (!hist.rules.isDots) { + std::fill_n(gameData->finalFullArea, Board::MAX_ARR_SIZE,C_EMPTY); + std::fill_n(gameData->finalSekiAreas, Board::MAX_ARR_SIZE,false); + } } else { //Relying on this to be idempotent, so that we can get the final territory map @@ -1752,7 +1802,7 @@ FinishedGameData* Play::runGame( finalValueTargets.lead = finalValueTargets.score; //Fill full and seki areas - { + if (!hist.rules.isDots) { board.calculateArea(gameData->finalFullArea, true, true, true, hist.rules.multiStoneSuicideLegal); Color* independentLifeArea = new Color[Board::MAX_ARR_SIZE]; @@ -2163,10 +2213,11 @@ void Play::maybeForkGame( ASSERT_UNREACHABLE; } - Board board; + const Rules& rules = finishedGameData->startHist.rules; + Board board(rules); Player pla; - BoardHistory hist; - replayGameUpToMove(finishedGameData, moveIdx, finishedGameData->startHist.rules, board, hist, pla); + BoardHistory hist(rules); + replayGameUpToMove(finishedGameData, moveIdx, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) return; @@ -2251,9 +2302,9 @@ void Play::maybeSekiForkGame( Rules rules = finishedGameData->startHist.rules; rules = gameInit->randomizeScoringAndTaxRules(rules,gameRand); - Board board; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); replayGameUpToMove(finishedGameData, moveIdx, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) @@ -2283,12 +2334,13 @@ void Play::maybeHintForkGame( if(!hintFork) return; - Board board; + const Rules& rules = finishedGameData->startHist.rules; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); testAssert(finishedGameData->startHist.moveHistory.size() < 0x1FFFffff); int moveIdxToReplayTo = (int)finishedGameData->startHist.moveHistory.size(); - replayGameUpToMove(finishedGameData, moveIdxToReplayTo, finishedGameData->startHist.rules, board, hist, pla); + replayGameUpToMove(finishedGameData, moveIdxToReplayTo, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) return; @@ -2372,9 +2424,10 @@ FinishedGameData* GameRunner::runGame( } } - Board board; + const Rules& rules = Rules::getDefault(gameInit->isDotsGame()); + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); ExtraBlackAndKomi extraBlackAndKomi; OtherGameProperties otherGameProps; if(playSettings.forSelfPlay) { @@ -2407,20 +2460,25 @@ FinishedGameData* GameRunner::runGame( clearBotBeforeSearchThisGame = true; } - //In 2% of games, don't autoterminate the game upon all pass alive, to just provide a tiny bit of training data on positions that occur - //as both players must wrap things up manually, because within the search we don't autoterminate games, meaning that the NN will get - //called on positions that occur after the game would have been autoterminated. - bool doEndGameIfAllPassAlive = playSettings.forSelfPlay ? gameRand.nextBool(0.98) : true; + bool doEndGameIfAllPassAlive; + if (rules.isDots) { + doEndGameIfAllPassAlive = false; + } else { + //In 2% of games, don't autoterminate the game upon all pass alive, to just provide a tiny bit of training data on positions that occur + //as both players must wrap things up manually, because within the search we don't autoterminate games, meaning that the NN will get + //called on positions that occur after the game would have been autoterminated. + doEndGameIfAllPassAlive = playSettings.forSelfPlay ? gameRand.nextBool(0.98) : true; + } Search* botB; Search* botW; if(botSpecB.botIdx == botSpecW.botIdx) { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, nullptr, hist.rules); botW = botB; } else { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B"); - botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W"); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", nullptr, hist.rules); + botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", nullptr, hist.rules); } if(afterInitialization != nullptr) { if(botSpecB.botIdx == botSpecW.botIdx) { @@ -2469,7 +2527,7 @@ FinishedGameData* GameRunner::runGame( assert(finishedGameData != NULL); Play::maybeForkGame(finishedGameData, forkData, playSettings, gameRand, botB); - if(!usedSekiForkHackPosition) { + if(!hist.rules.isDots && !usedSekiForkHackPosition) { Play::maybeSekiForkGame(finishedGameData, forkData, playSettings, gameInit, gameRand); } Play::maybeHintForkGame(finishedGameData, forkData, otherGameProps); diff --git a/cpp/program/play.h b/cpp/program/play.h index 8cadadf1f..0849f8b58 100644 --- a/cpp/program/play.h +++ b/cpp/program/play.h @@ -24,7 +24,7 @@ struct InitialPosition { bool isHintFork; double trainingWeight; - InitialPosition(); + explicit InitialPosition(const Rules& rules); InitialPosition(const Board& board, const BoardHistory& hist, Player pla, bool isPlainFork, bool isSekiFork, bool isHintFork, double trainingWeight); ~InitialPosition(); }; @@ -119,6 +119,8 @@ class GameInitializer { int getMaxBoardXSize() const; int getMaxBoardYSize() const; + bool isDotsGame() const; + private: void initShared(ConfigParser& cfg, Logger& logger); void createGameSharedUnsynchronized( @@ -134,6 +136,11 @@ class GameInitializer { std::mutex createGameMutex; Rand rand; + bool dotsGame; + std::vector allowedCaptureEmtpyBasesRules; + std::vector allowedStartPosRules; + std::vector allowedStartPosRandomRules; + std::vector allowedKoRuleStrs; std::vector allowedScoringRuleStrs; std::vector allowedTaxRuleStrs; diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index b8ef70e5c..3944d91f7 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -111,7 +111,7 @@ Loc PlayUtils::chooseRandomLegalMove(const Board& board, const BoardHistory& his Loc locs[Board::MAX_ARR_SIZE]; testAssert(pla == hist.presumedNextMovePla); for(Loc loc = 0; loc < Board::MAX_ARR_SIZE; loc++) { - if(hist.isLegal(board,loc,pla) && loc != banMove) { + if(loc != Board::RESIGN_LOC && hist.isLegal(board,loc,pla) && loc != banMove) { locs[numLegalMoves] = loc; numLegalMoves += 1; } @@ -128,7 +128,7 @@ int PlayUtils::chooseRandomLegalMoves(const Board& board, const BoardHistory& hi Loc locs[Board::MAX_ARR_SIZE]; testAssert(pla == hist.presumedNextMovePla); for(Loc loc = 0; loc < Board::MAX_ARR_SIZE; loc++) { - if(hist.isLegal(board,loc,pla)) { + if(loc != Board::RESIGN_LOC && hist.isLegal(board,loc,pla)) { locs[numLegalMoves] = loc; numLegalMoves += 1; } @@ -256,6 +256,7 @@ void PlayUtils::initializeGameUsingPolicy( //Rarely, playing the random moves out this way will end the game if(doEndGameIfAllPassAlive) hist.endGameIfAllPassAlive(board); + hist.endGameIfNoLegalMoves(board); if(hist.isGameFinished) break; } @@ -263,17 +264,17 @@ void PlayUtils::initializeGameUsingPolicy( //Place black handicap stones, free placement -//Does NOT switch the initial player of the board history to white -void PlayUtils::playExtraBlack( +// Does NOT switch the initial player of the board history to white +vector PlayUtils::playExtraBlack( Search* bot, int numExtraBlack, Board& board, BoardHistory& hist, double temperature, - Rand& gameRand -) { + Rand& gameRand) { Player pla = P_BLACK; + std::vector handicapLocs; if(!hist.isGameFinished) { NNResultBuf buf; for(int i = 0; isetPosition(pla,board,hist); + return handicapLocs; } -void PlayUtils::placeFixedHandicap(Board& board, int n) { - int xSize = board.x_size; - int ySize = board.y_size; +vector PlayUtils::generateFixedHandicap(const Board& board, const int n) { + const int xSize = board.x_size; + const int ySize = board.y_size; if(xSize < 7 || ySize < 7) throw StringError("Board is too small for fixed handicap"); if((xSize % 2 == 0 || ySize % 2 == 0) && n > 4) @@ -311,8 +314,6 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { if(n > 9) throw StringError("Fixed handicap > 9 is not allowed"); - board = Board(xSize,ySize); - int xCoords[3]; //Corner, corner, side int yCoords[3]; //Corner, corner, side if(xSize <= 12) { xCoords[0] = 2; xCoords[1] = xSize-3; xCoords[2] = xSize/2; } @@ -320,8 +321,10 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { if(ySize <= 12) { yCoords[0] = 2; yCoords[1] = ySize-3; yCoords[2] = ySize/2; } else { yCoords[0] = 3; yCoords[1] = ySize-4; yCoords[2] = ySize/2; } - auto s = [&](int xi, int yi) { - board.setStone(Location::getLoc(xCoords[xi],yCoords[yi],board.x_size),P_BLACK); + vector locs; + + auto s = [&](const int xi, const int yi) { + locs.push_back(Location::getLoc(xCoords[xi],yCoords[yi],board.x_size)); }; if(n == 2) { s(0,1); s(1,0); } else if(n == 3) { s(0,1); s(1,0); s(0,0); } @@ -332,6 +335,8 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { else if(n == 8) { s(0,1); s(1,0); s(0,0); s(1,1); s(0,2); s(1,2); s(2,0); s(2,1); } else if(n == 9) { s(0,1); s(1,0); s(0,0); s(1,1); s(0,2); s(1,2); s(2,0); s(2,1); s(2,2); } else { ASSERT_UNREACHABLE; } + + return locs; } double PlayUtils::getHackedLCBForWinrate(const Search* search, const AnalysisData& data, Player pla) { @@ -948,15 +953,14 @@ PlayUtils::BenchmarkResults PlayUtils::benchmarkSearchOnPositionsAndPrint( Rand seedRand; Search* bot = new Search(params,nnEval,nnEval->getLogger(),Global::uint64ToString(seedRand.nextUInt64())); - //Ignore the SGF rules, except for komi. Just use Tromp-taylor. - Rules initialRules = Rules::getTrompTaylorish(); + //Ignore the SGF rules, except for Dots and komi. Just use Tromp-taylor in case of Go. + Rules initialRules = Rules::getDefaultOrTrompTaylorish(sgf.isDots); //Take the komi from the sgf, otherwise ignore the rules in the sgf initialRules.komi = sgf.getRulesOrFailAllowUnspecified(initialRules).komi; Board board; Player nextPla; - BoardHistory hist; - sgf.setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf.setupInitialBoardAndHist(initialRules, nextPla); int moveNum = 0; diff --git a/cpp/program/playutils.h b/cpp/program/playutils.h index c4da1c558..48f853056 100644 --- a/cpp/program/playutils.h +++ b/cpp/program/playutils.h @@ -10,17 +10,10 @@ namespace PlayUtils { //Use the given bot to play free handicap stones, modifying the board and hist in the process and setting the bot's position to it. //Does NOT switch the initial player of the board history to white - void playExtraBlack( - Search* bot, - int numExtraBlack, - Board& board, - BoardHistory& hist, - double temperature, - Rand& gameRand - ); + std::vector playExtraBlack(Search* bot, int numExtraBlack, Board& board, BoardHistory& hist, double temperature, Rand& gameRand); - //Set board to empty and place fixed handicap stones, raising an exception if invalid - void placeFixedHandicap(Board& board, int n); + // Generate fixed handicap stones, raising an exception if invalid + std::vector generateFixedHandicap(const Board& board, int n); ExtraBlackAndKomi chooseExtraBlackAndKomi( float base, float stdev, double allowIntegerProb, diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index a31295145..9b9feae68 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -115,22 +115,22 @@ vector Setup::initializeNNEvaluators( int nnYLen = std::max(defaultNNYLen,2); if(setupFor != SETUP_FOR_DISTRIBUTED) { if(cfg.contains("maxBoardXSizeForNNBuffer" + idxStr)) - nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardXSizeForNNBuffer")) - nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardSizeForNNBuffer" + idxStr)) - nnXLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardSizeForNNBuffer")) - nnXLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_X); if(cfg.contains("maxBoardYSizeForNNBuffer" + idxStr)) - nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardYSizeForNNBuffer")) - nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardSizeForNNBuffer" + idxStr)) - nnYLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardSizeForNNBuffer")) - nnYLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_Y); } bool requireExactNNLen = defaultRequireExactNNLen; @@ -302,6 +302,7 @@ vector Setup::initializeNNEvaluators( if(disableFP16) useFP16Mode = enabled_t::False; + bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); NNEvaluator* nnEval = new NNEvaluator( nnModelName, nnModelFile, @@ -324,7 +325,8 @@ vector Setup::initializeNNEvaluators( gpuIdxByServerThread, nnRandSeed, (forcedSymmetry >= 0 ? false : nnRandomize), - defaultSymmetry + defaultSymmetry, + dotsGame ); nnEval->spawnServerThreads(); @@ -883,72 +885,82 @@ Player Setup::parseReportAnalysisWinrates( throw StringError("Could not parse config value for reportAnalysisWinratesAs: " + sOrig); } -Rules Setup::loadSingleRules( - ConfigParser& cfg, - bool loadKomi -) { - Rules rules; +Rules Setup::loadSingleRules(ConfigParser& cfg, const bool loadKomi) { + const bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); + Rules rules = Rules::getDefault(dotsGame); if(cfg.contains("rules")) { - if(cfg.contains("koRule")) throw StringError("Cannot both specify 'rules' and individual rules like koRule"); - if(cfg.contains("scoringRule")) throw StringError("Cannot both specify 'rules' and individual rules like scoringRule"); + if(cfg.contains(START_POS_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + START_POS_KEY); + if(cfg.contains(START_POS_RANDOM_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + START_POS_RANDOM_KEY); if(cfg.contains("multiStoneSuicideLegal")) throw StringError("Cannot both specify 'rules' and individual rules like multiStoneSuicideLegal"); - if(cfg.contains("hasButton")) throw StringError("Cannot both specify 'rules' and individual rules like hasButton"); - if(cfg.contains("taxRule")) throw StringError("Cannot both specify 'rules' and individual rules like taxRule"); - if(cfg.contains("whiteHandicapBonus")) throw StringError("Cannot both specify 'rules' and individual rules like whiteHandicapBonus"); - if(cfg.contains("friendlyPassOk")) throw StringError("Cannot both specify 'rules' and individual rules like friendlyPassOk"); - if(cfg.contains("whiteBonusPerHandicapStone")) throw StringError("Cannot both specify 'rules' and individual rules like whiteBonusPerHandicapStone"); - rules = Rules::parseRules(cfg.getString("rules")); + if (dotsGame) { + if (cfg.contains(DOTS_CAPTURE_EMPTY_BASE_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + DOTS_CAPTURE_EMPTY_BASE_KEY); + } else { + if(cfg.contains("koRule")) throw StringError("Cannot both specify 'rules' and individual rules like koRule"); + if(cfg.contains("scoringRule")) throw StringError("Cannot both specify 'rules' and individual rules like scoringRule"); + if(cfg.contains("hasButton")) throw StringError("Cannot both specify 'rules' and individual rules like hasButton"); + if(cfg.contains("taxRule")) throw StringError("Cannot both specify 'rules' and individual rules like taxRule"); + if(cfg.contains("whiteHandicapBonus")) throw StringError("Cannot both specify 'rules' and individual rules like whiteHandicapBonus"); + if(cfg.contains("friendlyPassOk")) throw StringError("Cannot both specify 'rules' and individual rules like friendlyPassOk"); + if(cfg.contains("whiteBonusPerHandicapStone")) throw StringError("Cannot both specify 'rules' and individual rules like whiteBonusPerHandicapStone"); + } + + rules = Rules::parseRules(cfg.getString("rules"), cfg.getBoolOrDefault(DOTS_KEY, false)); } else { - string koRule = cfg.getString("koRule", Rules::koRuleStrings()); - string scoringRule = cfg.getString("scoringRule", Rules::scoringRuleStrings()); - bool multiStoneSuicideLegal = cfg.getBool("multiStoneSuicideLegal"); - bool hasButton = cfg.contains("hasButton") ? cfg.getBool("hasButton") : false; - float komi = 7.5f; - - rules.koRule = Rules::parseKoRule(koRule); - rules.scoringRule = Rules::parseScoringRule(scoringRule); - rules.multiStoneSuicideLegal = multiStoneSuicideLegal; - rules.hasButton = hasButton; - rules.komi = komi; - - if(cfg.contains("taxRule")) { - string taxRule = cfg.getString("taxRule", Rules::taxRuleStrings()); - rules.taxRule = Rules::parseTaxRule(taxRule); + if (cfg.contains(START_POS_KEY)) { + rules.startPos = Rules::parseStartPos(cfg.getString(START_POS_KEY)); } - else { - rules.taxRule = (rules.scoringRule == Rules::SCORING_TERRITORY ? Rules::TAX_SEKI : Rules::TAX_NONE); + if (cfg.contains(START_POS_RANDOM_KEY)) { + rules.startPosIsRandom = cfg.getBool(START_POS_RANDOM_KEY); } + rules.multiStoneSuicideLegal = cfg.getBoolOrDefault("multiStoneSuicideLegal", rules.multiStoneSuicideLegal); + + if (dotsGame) { + rules.dotsCaptureEmptyBases = cfg.getBoolOrDefault(DOTS_CAPTURE_EMPTY_BASE_KEY, rules.dotsCaptureEmptyBases); + } else { + rules.koRule = Rules::parseKoRule(cfg.getString("koRule", Rules::koRuleStrings())); + rules.scoringRule = Rules::parseScoringRule(cfg.getString("scoringRule", Rules::scoringRuleStrings())); + rules.hasButton = cfg.getBoolOrDefault("hasButton", false); + rules.komi = 7.5f; + + if(cfg.contains("taxRule")) { + string taxRule = cfg.getString("taxRule", Rules::taxRuleStrings()); + rules.taxRule = Rules::parseTaxRule(taxRule); + } + else { + rules.taxRule = (rules.scoringRule == Rules::SCORING_TERRITORY ? Rules::TAX_SEKI : Rules::TAX_NONE); + } - if(rules.hasButton && rules.scoringRule != Rules::SCORING_AREA) - throw StringError("Config specifies hasButton=true on a scoring system other than AREA"); - - //Also handles parsing of legacy option whiteBonusPerHandicapStone - if(cfg.contains("whiteBonusPerHandicapStone") && cfg.contains("whiteHandicapBonus")) - throw StringError("May specify only one of whiteBonusPerHandicapStone and whiteHandicapBonus in config"); - else if(cfg.contains("whiteHandicapBonus")) - rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(cfg.getString("whiteHandicapBonus", Rules::whiteHandicapBonusRuleStrings())); - else if(cfg.contains("whiteBonusPerHandicapStone")) { - int whiteBonusPerHandicapStone = cfg.getInt("whiteBonusPerHandicapStone",0,1); - if(whiteBonusPerHandicapStone == 0) - rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + if(rules.hasButton && rules.scoringRule != Rules::SCORING_AREA) + throw StringError("Config specifies hasButton=true on a scoring system other than AREA"); + + //Also handles parsing of legacy option whiteBonusPerHandicapStone + if(cfg.contains("whiteBonusPerHandicapStone") && cfg.contains("whiteHandicapBonus")) + throw StringError("May specify only one of whiteBonusPerHandicapStone and whiteHandicapBonus in config"); + else if(cfg.contains("whiteHandicapBonus")) + rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(cfg.getString("whiteHandicapBonus", Rules::whiteHandicapBonusRuleStrings())); + else if(cfg.contains("whiteBonusPerHandicapStone")) { + int whiteBonusPerHandicapStone = cfg.getInt("whiteBonusPerHandicapStone",0,1); + if(whiteBonusPerHandicapStone == 0) + rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + else + rules.whiteHandicapBonusRule = Rules::WHB_N; + } else - rules.whiteHandicapBonusRule = Rules::WHB_N; - } - else - rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + rules.whiteHandicapBonusRule = Rules::WHB_ZERO; - if(cfg.contains("friendlyPassOk")) { - rules.friendlyPassOk = cfg.getBool("friendlyPassOk"); - } + if(cfg.contains("friendlyPassOk")) { + rules.friendlyPassOk = cfg.getBool("friendlyPassOk"); + } - //Drop default komi to 6.5 for territory rules, and to 7.0 for button - if(rules.scoringRule == Rules::SCORING_TERRITORY) - rules.komi = 6.5f; - else if(rules.hasButton) - rules.komi = 7.0f; + //Drop default komi to 6.5 for territory rules, and to 7.0 for button + if(rules.scoringRule == Rules::SCORING_TERRITORY) + rules.komi = 6.5f; + else if(rules.hasButton) + rules.komi = 7.0f; + } } if(loadKomi) { @@ -965,12 +977,12 @@ bool Setup::loadDefaultBoardXYSize( int& defaultBoardYSizeRet ) { const int defaultBoardXSize = - cfg.contains("defaultBoardXSize") ? cfg.getInt("defaultBoardXSize",2,Board::MAX_LEN) : - cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN) : + cfg.contains("defaultBoardXSize") ? cfg.getInt("defaultBoardXSize",2,Board::MAX_LEN_X) : + cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN_X) : -1; const int defaultBoardYSize = - cfg.contains("defaultBoardYSize") ? cfg.getInt("defaultBoardYSize",2,Board::MAX_LEN) : - cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN) : + cfg.contains("defaultBoardYSize") ? cfg.getInt("defaultBoardYSize",2,Board::MAX_LEN_Y) : + cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN_Y) : -1; if((defaultBoardXSize == -1) != (defaultBoardYSize == -1)) logger.write("Warning: Config specified only one of defaultBoardXSize or defaultBoardYSize and no other board size parameter, ignoring it"); @@ -1022,7 +1034,7 @@ std::vector> Setup::loadAvoidSgfPatternBonusT double lambda = contains("PatternLambda") ? cfg.getDouble(find("PatternLambda"),0.0,1.0) : 1.0; int minTurnNumber = contains("PatternMinTurnNumber") ? cfg.getInt(find("PatternMinTurnNumber"),0,1000000) : 0; size_t maxFiles = contains("PatternMaxFiles") ? (size_t)cfg.getInt(find("PatternMaxFiles"),1,1000000) : 1000000; - vector allowedPlayerNames = contains("PatternAllowedNames") ? cfg.getStringsNonEmptyTrim(find("PatternAllowedNames")) : vector(); + vector allowedPlayerNames = contains("PatternAllowedNames") ? cfg.getStrings(find("PatternAllowedNames"), {}, true) : vector(); vector sgfDirs = cfg.getStrings(find("PatternDirs")); if(patternBonusTable == nullptr) patternBonusTable = std::make_unique(); @@ -1119,7 +1131,7 @@ std::unique_ptr Setup::loadAndPruneAutoPatternBonusTables(Con bool suc = Global::tryStringToInt(pieces[0],boardXSize) && Global::tryStringToInt(pieces[1],boardYSize); if(!suc) continue; - if(boardXSize < 2 || boardXSize > Board::MAX_LEN || boardYSize < 2 || boardYSize > Board::MAX_LEN) + if(boardXSize < 2 || boardXSize > Board::MAX_LEN_X || boardYSize < 2 || boardYSize > Board::MAX_LEN_Y) continue; string dirPath = baseDir + "/" + dirName; diff --git a/cpp/search/asyncbot.cpp b/cpp/search/asyncbot.cpp index 0c18a5ef5..393a2cf3f 100644 --- a/cpp/search/asyncbot.cpp +++ b/cpp/search/asyncbot.cpp @@ -43,7 +43,8 @@ AsyncBot::AsyncBot( NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* l, - const string& randSeed + const string& randSeed, + const Rules& rules ) :search(NULL), controlMutex(),threadWaitingToSearch(),userWaitingForStop(),searchThread(), @@ -54,7 +55,7 @@ AsyncBot::AsyncBot( analyzeCallback(), searchBegunCallback() { - search = new Search(params,nnEval,humanEval,l,randSeed); + search = new Search(params,nnEval,l,randSeed,humanEval,rules); searchThread = std::thread(searchThreadLoop,this,l); } diff --git a/cpp/search/asyncbot.h b/cpp/search/asyncbot.h index ae94873ce..706574811 100644 --- a/cpp/search/asyncbot.h +++ b/cpp/search/asyncbot.h @@ -16,7 +16,8 @@ class AsyncBot { NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* logger, - const std::string& randSeed + const std::string& randSeed, + const Rules& rules ); ~AsyncBot(); diff --git a/cpp/search/localpattern.cpp b/cpp/search/localpattern.cpp index 8d075a06c..9081fdf37 100644 --- a/cpp/search/localpattern.cpp +++ b/cpp/search/localpattern.cpp @@ -48,15 +48,19 @@ void LocalPatternHasher::init(int x, int y, Rand& rand) { LocalPatternHasher::~LocalPatternHasher() { } - Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) const { Hash128 hash = zobristPla[pla]; - if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - const int dxi = board.adj_offsets[2]; - const int dyi = board.adj_offsets[3]; - assert(dxi == 1); - assert(dyi == board.x_size+1); + if(loc != Board::PASS_LOC && loc != Board::NULL_LOC && loc != Board::RESIGN_LOC) { + vector bases; + if (board.isDots()) { + vector captures; + // TODO: implement fast version for Dots + // board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + } + + const int dxi = 1; + const int dyi = board.x_size+1; int xRadius = xSize/2; int yRadius = ySize/2; @@ -74,9 +78,8 @@ Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) con int y2 = dy + yCenter; int x2 = dx + xCenter; int xy2 = y2 * xSize + x2; - hash ^= zobristLocalPattern[(int)board.colors[loc2] * xSize * ySize + xy2]; - if((board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) && board.getNumLiberties(loc2) == 1) - hash ^= zobristAtari[xy2]; + + updateHash(hash, board, bases, loc2, xy2, false); } } } @@ -88,11 +91,16 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p Player symPla = flipColors ? getOpp(pla) : pla; Hash128 hash = zobristPla[symPla]; - if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - const int dxi = board.adj_offsets[2]; - const int dyi = board.adj_offsets[3]; - assert(dxi == 1); - assert(dyi == board.x_size+1); + if(loc != Board::PASS_LOC && loc != Board::NULL_LOC && loc != Board::RESIGN_LOC) { + vector bases; + if (board.isDots()) { + vector captures; + // TODO: implement fast version for Dots + // board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + } + + const int dxi = 1; + const int dyi = board.x_size+1; int xRadius = xSize/2; int yRadius = ySize/2; @@ -125,18 +133,37 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p symXY2 = symY2 * xSize + symX2; } - int symColor; - if(board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) - symColor = (int)(flipColors ? getOpp(board.colors[loc2]) : board.colors[loc2]); - else - symColor = (int)board.colors[loc2]; - - hash ^= zobristLocalPattern[symColor * xSize * ySize + symXY2]; - if((board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) && board.getNumLiberties(loc2) == 1) - hash ^= zobristAtari[symXY2]; + updateHash(hash, board, bases, loc2, symXY2, flipColors); } } } return hash; } + +void LocalPatternHasher::updateHash( + Hash128& hash, + const Board& board, + const vector& bases, + const Loc loc, + const int patternXY, + const bool flipColors) const { + const Color colorAtLoc = board.getColor(loc); + Color newColor; + if (flipColors && (colorAtLoc == P_BLACK || colorAtLoc == P_WHITE)) { + newColor = getOpp(colorAtLoc); + } else { + newColor = colorAtLoc; + } + + hash ^= zobristLocalPattern[static_cast(newColor) * xSize * ySize + patternXY]; + + bool addAtariHash = false; + if(board.isDots()) { + addAtariHash = !bases.empty() && bases[loc] != C_EMPTY; + } else { + addAtariHash = (colorAtLoc == P_BLACK || colorAtLoc == P_WHITE) && board.getNumLiberties(loc) == 1; + } + if(addAtariHash) + hash ^= zobristAtari[patternXY]; +} \ No newline at end of file diff --git a/cpp/search/localpattern.h b/cpp/search/localpattern.h index 36ed0544b..67abc3b3a 100644 --- a/cpp/search/localpattern.h +++ b/cpp/search/localpattern.h @@ -23,6 +23,15 @@ struct LocalPatternHasher { //Returns the hash that would occur if symmetry were applied to both board and loc. //So basically, the only thing that changes is the zobrist indexing. Hash128 getHashWithSym(const Board& board, Loc loc, Player pla, int symmetry, bool flipColors) const; + +private: + void updateHash( + Hash128& hash, + const Board& board, + const std::vector& bases, + Loc loc, + int patternXY, + bool flipColors) const; }; #endif //SEARCH_LOCALPATTERN_H diff --git a/cpp/search/patternbonustable.cpp b/cpp/search/patternbonustable.cpp index 6367ecce3..6fa18d668 100644 --- a/cpp/search/patternbonustable.cpp +++ b/cpp/search/patternbonustable.cpp @@ -52,7 +52,7 @@ Hash128 PatternBonusTable::getHash(Player pla, Loc moveLoc, const Board& board) //We don't want to over-trigger this on a ko that repeats the same pattern over and over //So we just disallow this on ko fight //Also no bonuses for passing. - if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || board.wouldBeKoCapture(moveLoc,pla)) + if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || moveLoc == Board::RESIGN_LOC || board.wouldBeKoCapture(moveLoc,pla)) return Hash128(); Hash128 hash = patternHasher.getHash(board,moveLoc,pla); @@ -87,7 +87,7 @@ void PatternBonusTable::addBonus(Player pla, Loc moveLoc, const Board& board, do //We don't want to over-trigger this on a ko that repeats the same pattern over and over //So we just disallow this on ko fight //Also no bonuses for passing. - if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || board.wouldBeKoCapture(moveLoc,pla)) + if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || moveLoc == Board::RESIGN_LOC || board.wouldBeKoCapture(moveLoc,pla)) return; Hash128 hash = patternHasher.getHashWithSym(board,moveLoc,pla,symmetry,flipColors); @@ -258,7 +258,7 @@ void PatternBonusTable::avoidRepeatedPosMovesAndDeleteExcessFiles( turnNumber < minTurnNumber || turnNumber > maxTurnNumber || posSample.moves.size() != 0 || // Right now auto pattern avoid expects moveless records - !posSample.board.isLegal(posSample.hintLoc, posSample.nextPla, isMultiStoneSuicideLegal) + !posSample.board.isLegal(posSample.hintLoc, posSample.nextPla, isMultiStoneSuicideLegal, false) ) { numPosesInvalid += 1; continue; diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index b0bf240b4..d5bcfef77 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -65,71 +65,73 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; -Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed) - :Search(params,nnEval,NULL,lg,rSeed) -{} -Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed) - :rootPla(P_BLACK), - rootBoard(), - rootHistory(), - rootGraphHash(), - rootHintLoc(Board::NULL_LOC), - avoidMoveUntilByLocBlack(),avoidMoveUntilByLocWhite(),avoidMoveUntilRescaleRoot(false), - rootSymmetries(), - rootPruneOnlySymmetries(), - rootSafeArea(NULL), - recentScoreCenter(0.0), - mirroringPla(C_EMPTY), - mirrorAdvantage(0.0), - mirrorCenterSymmetryError(1e10), - alwaysIncludeOwnerMap(false), - searchParams(params),numSearchesBegun(0),searchNodeAge(0), - plaThatSearchIsFor(C_EMPTY),plaThatSearchIsForLastSearch(C_EMPTY), - lastSearchNumPlayouts(0), - effectiveSearchTimeCarriedOver(0.0), - randSeed(rSeed), - rootKoHashTable(NULL), - valueWeightDistribution(NULL), - patternBonusTable(NULL), - externalPatternBonusTable(nullptr), - evalCache(nullptr), - nonSearchRand(rSeed + string("$nonSearchRand")), - logger(lg), - nnEvaluator(nnEval), - humanEvaluator(humanEval), - nnXLen(), - nnYLen(), - policySize(), - rootNode(NULL), - nodeTable(NULL), - mutexPool(NULL), - subtreeValueBiasTable(NULL), - numThreadsSpawned(0), - threads(NULL), - threadTasks(NULL), - threadTasksRemaining(NULL), - oldNNOutputsToCleanUpMutex(), - oldNNOutputsToCleanUp() -{ +Search::Search(const SearchParams ¶ms, NNEvaluator* nnEval, Logger* lg, const string& newRandSeed, NNEvaluator* humanEval, const Rules& rules) + : rootPla(P_BLACK), + rootBoard(rules), + rootHistory(rules), + rootGraphHash(), + rootHintLoc(Board::NULL_LOC), + avoidMoveUntilByLocBlack(), avoidMoveUntilByLocWhite(), avoidMoveUntilRescaleRoot(false), + rootSymmetries(), + rootPruneOnlySymmetries(), + rootSafeArea(NULL), + recentScoreCenter(0.0), + mirroringPla(C_EMPTY), + mirrorAdvantage(0.0), + mirrorCenterSymmetryError(1e10), + alwaysIncludeOwnerMap(false), + searchParams(params), numSearchesBegun(0), searchNodeAge(0), + plaThatSearchIsFor(C_EMPTY), plaThatSearchIsForLastSearch(C_EMPTY), + lastSearchNumPlayouts(0), + effectiveSearchTimeCarriedOver(0.0), + randSeed(newRandSeed), + rootKoHashTable(NULL), + valueWeightDistribution(NULL), + normToTApproxZ(0), + patternBonusTable(NULL), + externalPatternBonusTable(nullptr), + evalCache(nullptr), + nonSearchRand(newRandSeed + string("$nonSearchRand")), + logger(lg), + nnEvaluator(nnEval), + humanEvaluator(humanEval), + nnXLen(), + nnYLen(), + policySize(), + rootNode(NULL), + nodeTable(NULL), + mutexPool(NULL), + subtreeValueBiasTable(NULL), + numThreadsSpawned(0), + threads(NULL), + threadTasks(NULL), + threadTasksRemaining(NULL), + oldNNOutputsToCleanUpMutex(), + oldNNOutputsToCleanUp() { assert(logger != NULL); nnXLen = nnEval->getNNXLen(); nnYLen = nnEval->getNNYLen(); - assert(nnXLen > 0 && nnXLen <= NNPos::MAX_BOARD_LEN); - assert(nnYLen > 0 && nnYLen <= NNPos::MAX_BOARD_LEN); - policySize = NNPos::getPolicySize(nnXLen,nnYLen); + assert(nnXLen > 0 && nnXLen <= NNPos::MAX_BOARD_LEN_X); + assert(nnYLen > 0 && nnYLen <= NNPos::MAX_BOARD_LEN_Y); + policySize = NNPos::getPolicySize(nnXLen, nnYLen); - if(humanEvaluator != NULL) { - if(humanEvaluator->getNNXLen() != nnXLen || humanEvaluator->getNNYLen() != nnYLen) + if (humanEvaluator != NULL) { + if (humanEvaluator->getNNXLen() != nnXLen || humanEvaluator->getNNYLen() != nnYLen) throw StringError("Search::init - humanEval has different nnXLen or nnYLen"); } - rootKoHashTable = new KoHashTable(); + assert(rootHistory.rules.isDots == rootBoard.isDots()); + rootHistory.clear(rootBoard, rootPla, rules, 0); + if (!rules.isDots) { + rootKoHashTable = new KoHashTable(); + rootKoHashTable->recompute(rootHistory); + } rootSafeArea = new Color[Board::MAX_ARR_SIZE]; valueWeightDistribution = new DistributionTable( - [](double z) { return FancyMath::tdistpdf(z,VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, - [](double z) { return FancyMath::tdistcdf(z,VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, + [](double z) { return FancyMath::tdistpdf(z, VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, + [](double z) { return FancyMath::tdistcdf(z, VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, -50.0, 50.0, 2000 @@ -138,9 +140,6 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, rootNode = NULL; nodeTable = new SearchNodeTable(params.nodeTableShardsPowerOfTwo); mutexPool = new MutexPool(nodeTable->mutexPool->getNumMutexes()); - - rootHistory.clear(rootBoard,rootPla,Rules(),0); - rootKoHashTable->recompute(rootHistory); } Search::~Search() { @@ -181,7 +180,10 @@ void Search::setPosition(Player pla, const Board& board, const BoardHistory& his plaThatSearchIsFor = C_EMPTY; rootBoard = board; rootHistory = history; - rootKoHashTable->recompute(rootHistory); + assert(rootHistory.rules.isDots == rootBoard.isDots()); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } avoidMoveUntilByLocBlack.clear(); avoidMoveUntilByLocWhite.clear(); } @@ -197,7 +199,9 @@ void Search::setPlayerAndClearHistory(Player pla) { rootHistory.clear(rootBoard,rootPla,rules,rootHistory.encorePhase); rootHistory.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); - rootKoHashTable->recompute(rootHistory); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } //If changing the player alone, don't clear these, leave the user's setting - the user may have tried //to adjust the player or will be calling runWholeSearchAndGetMove with a different player and will @@ -336,7 +340,9 @@ bool Search::makeMove(Loc moveLoc, Player movePla, bool preventEncore) { //Compute these first so we can know if we need to set forceNonTerminal below. rootHistory.makeBoardMoveAssumeLegal(rootBoard,moveLoc,rootPla,rootKoHashTable,preventEncore); rootPla = getOpp(rootPla); - rootKoHashTable->recompute(rootHistory); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } if(rootNode != NULL) { SearchNode* child = NULL; diff --git a/cpp/search/search.h b/cpp/search/search.h index 562af59fd..431dfadfc 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -182,18 +182,12 @@ struct Search { //Note - randSeed controls a few things in the search, but a lot of the randomness actually comes from //random symmetries of the neural net evaluations, see nneval.h Search( - SearchParams params, + const SearchParams ¶ms, NNEvaluator* nnEval, Logger* logger, - const std::string& randSeed - ); - Search( - SearchParams params, - NNEvaluator* nnEval, - NNEvaluator* humanEval, - Logger* logger, - const std::string& randSeed - ); + const std::string& newRandSeed, + NNEvaluator* humanEval = nullptr, + const Rules& rules = Rules::DEFAULT_GO); ~Search(); Search(const Search&) = delete; diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 59e9b7483..6b1ced47d 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -639,5 +639,4 @@ void Search::selectBestChildToDescend( thread.shouldCountPlayout = false; } } - } diff --git a/cpp/search/searchhelpers.cpp b/cpp/search/searchhelpers.cpp index 060e9875a..97751e444 100644 --- a/cpp/search/searchhelpers.cpp +++ b/cpp/search/searchhelpers.cpp @@ -308,6 +308,11 @@ double Search::getUtilityFromNN(const NNOutput& nnOutput) const { bool Search::isAllowedRootMove(Loc moveLoc) const { + if (rootHistory.rules.isDots) { + // Not actual for Dots because Pass (aka Grounding) in the latest possible move in the game + // Maybe it makes sense to allow Grounding if only it wins the game? + return true; + } assert(moveLoc == Board::PASS_LOC || rootBoard.isOnBoard(moveLoc)); //A bad situation that can happen that unnecessarily prolongs training games is where one player diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index 4096e1373..26206a70c 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -456,7 +456,7 @@ json SearchParams::changeableParametersToJson() const { // Special handling in GTP ret["playoutDoublingAdvantage"] = playoutDoublingAdvantage; - ret["playoutDoublingAdvantagePla"] = PlayerIO::playerToStringShort(playoutDoublingAdvantagePla); + ret["playoutDoublingAdvantagePla"] = PlayerIO::playerToStringShort(playoutDoublingAdvantagePla, false); // TODO: Fix for Dots // Special handling in GTP // ret["avoidRepeatedPatternUtility"] = avoidRepeatedPatternUtility; diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 468a018fc..9fc5e0b8a 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -1335,7 +1335,7 @@ void Search::printTreeHelper( return; } if((options.alsoBranch_ && depth == 0) || (!options.alsoBranch_ && depth == options.branch_.size())) { - out << "---" << PlayerIO::playerToString(node.nextPla) << "(" << (node.nextPla == perspectiveToUse ? "^" : "v") << ")---" << endl; + out << "---" << PlayerIO::playerToString(node.nextPla,rootBoard.isDots()) << "(" << (node.nextPla == perspectiveToUse ? "^" : "v") << ")---" << endl; } vector analysisData; @@ -2155,7 +2155,7 @@ bool Search::getAnalysisJson( } rootInfo["thisHash"] = Global::uint64ToHexString(thisHash.hash1) + Global::uint64ToHexString(thisHash.hash0); rootInfo["symHash"] = Global::uint64ToHexString(symHash.hash1) + Global::uint64ToHexString(symHash.hash0); - rootInfo["currentPlayer"] = PlayerIO::playerToStringShort(rootPla); + rootInfo["currentPlayer"] = PlayerIO::playerToStringShort(rootPla, board.isDots()); ret["rootInfo"] = rootInfo; } diff --git a/cpp/tests/testboardarea.cpp b/cpp/tests/testboardarea.cpp index 7248b3f29..ad32835b0 100644 --- a/cpp/tests/testboardarea.cpp +++ b/cpp/tests/testboardarea.cpp @@ -19,7 +19,7 @@ void Tests::runBoardAreaTests() { bool safeBigTerritories = safeBigTerritoriesBuf[mode/2]; bool unsafeBigTerritories = unsafeBigTerritoriesBuf[mode/2]; bool nonPassAliveStones = nonPassAliveStonesBuf[mode/2]; - Board copy(board); + const Board& copy(board); copy.calculateArea(result,nonPassAliveStones,safeBigTerritories,unsafeBigTerritories,multiStoneSuicideLegal); out << "Safe big territories " << safeBigTerritories << " " << "Unsafe big territories " << unsafeBigTerritories << " " diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index fa62b94e5..06535e8ee 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -13,10 +13,10 @@ void Tests::runBoardIOTests() { //============================================================================ { const char* name = "Location parse test"; - auto testLoc = [&out](const char* s, int xSize, int ySize) { + auto testLoc = [&out](const char* s, int xSize, int ySize, bool isDots) { try { Loc loc = Location::ofString(s,xSize,ySize); - out << s << " " << Location::toString(loc,xSize,ySize) << " x " << Location::getX(loc,xSize) << " y " << Location::getY(loc,xSize) << endl; + out << s << " " << Location::toString(loc, xSize, ySize, isDots) << " x " << Location::getX(loc,xSize) << " y " << Location::getY(loc,xSize) << endl; } catch(const StringError& e) { out << e.what() << endl; @@ -34,27 +34,28 @@ void Tests::runBoardIOTests() { out << "----------------------------------" << endl; out << xSize << " " << ySize << endl; - testLoc("A1",xSize,ySize); - testLoc("A0",xSize,ySize); - testLoc("B2",xSize,ySize); - testLoc("b2",xSize,ySize); - testLoc("A",xSize,ySize); - testLoc("B",xSize,ySize); - testLoc("1",xSize,ySize); - testLoc("pass",xSize,ySize); - testLoc("H9",xSize,ySize); - testLoc("I9",xSize,ySize); - testLoc("J9",xSize,ySize); - testLoc("J10",xSize,ySize); - testLoc("K8",xSize,ySize); - testLoc("k19",xSize,ySize); - testLoc("a22",xSize,ySize); - testLoc("y1",xSize,ySize); - testLoc("z1",xSize,ySize); - testLoc("aa1",xSize,ySize); - testLoc("AA26",xSize,ySize); - testLoc("AZ26",xSize,ySize); - testLoc("BC50",xSize,ySize); + testLoc("A1",xSize,ySize,false); + testLoc("A0",xSize,ySize,false); + testLoc("B2",xSize,ySize,false); + testLoc("b2",xSize,ySize,false); + testLoc("A",xSize,ySize,false); + testLoc("B",xSize,ySize,false); + testLoc("1",xSize,ySize,false); + testLoc("pass",xSize,ySize,false); + testLoc("ground",xSize,ySize,true); + testLoc("H9",xSize,ySize,false); + testLoc("I9",xSize,ySize,false); + testLoc("J9",xSize,ySize,false); + testLoc("J10",xSize,ySize,false); + testLoc("K8",xSize,ySize,false); + testLoc("k19",xSize,ySize,false); + testLoc("a22",xSize,ySize,false); + testLoc("y1",xSize,ySize,false); + testLoc("z1",xSize,ySize,false); + testLoc("aa1",xSize,ySize,false); + testLoc("AA26",xSize,ySize,false); + testLoc("AZ26",xSize,ySize,false); + testLoc("BC50",xSize,ySize,false); } } @@ -69,6 +70,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 0 Could not parse board location: I9 J9 J9 x 8 y 0 @@ -92,6 +94,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -115,6 +118,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 0 Could not parse board location: I9 J9 J9 x 8 y 0 @@ -138,6 +142,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -161,6 +166,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -184,6 +190,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -207,6 +214,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -230,6 +238,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 61 Could not parse board location: I9 J9 J9 x 8 y 61 @@ -253,6 +262,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -276,6 +286,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 61 Could not parse board location: I9 J9 J9 x 8 y 61 @@ -728,7 +739,7 @@ After white //============================================================================ { const char* name = "Distance"; - Board board(17,12); + Board board(17,12,Rules::DEFAULT_GO); auto testDistance = [&](int x0, int y0, int x1, int y1) { out << "distance (" << x0 << "," << y0 << ") (" << x1 << "," << y1 << ") = " << @@ -1901,15 +1912,16 @@ void Tests::runBoardUndoTest() { int suicideCount = 0; int koCaptureCount = 0; int passCount = 0; + int resignCount = 0; int regularMoveCount = 0; auto run = [&](const Board& startBoard, bool multiStoneSuicideLegal) { static const int steps = 1000; - Board* boards = new Board[steps+1]; + vector boards; Board::MoveRecord records[steps]; - boards[0] = startBoard; + boards.push_back(startBoard); for(int n = 1; n <= steps; n++) { - boards[n] = boards[n-1]; + boards.push_back(boards[n-1]); Loc loc; Player pla; while(true) { @@ -1917,7 +1929,7 @@ void Tests::runBoardUndoTest() { //Maximum range of board location values when 19x19: int numLocs = (19+1)*(19+2)+1; loc = (Loc)rand.nextUInt(numLocs); - if(boards[n].isLegal(loc,pla,multiStoneSuicideLegal)) + if(boards[n].isLegal(loc, pla, multiStoneSuicideLegal, false)) break; } @@ -1925,6 +1937,8 @@ void Tests::runBoardUndoTest() { if(loc == Board::PASS_LOC) passCount++; + else if(loc == Board::RESIGN_LOC) + resignCount++; else if(boards[n-1].isSuicide(loc,pla)) suicideCount++; else { @@ -1940,26 +1954,27 @@ void Tests::runBoardUndoTest() { testAssert(boardsSeemEqual(boards[n],board)); board.checkConsistency(); } - delete[] boards; }; - run(Board(19,19),true); - run(Board(4,4),true); - run(Board(4,4),false); + run(Board(19,19,Rules::DEFAULT_GO),true); + run(Board(4,4,Rules::DEFAULT_GO),true); + run(Board(4,4,Rules::DEFAULT_GO),false); ostringstream out; out << endl; out << "regularMoveCount " << regularMoveCount << endl; out << "passCount " << passCount << endl; + out << "resignCount " << resignCount << endl; out << "koCaptureCount " << koCaptureCount << endl; out << "suicideCount " << suicideCount << endl; string expected = R"%%( -regularMoveCount 2446 -passCount 475 -koCaptureCount 24 -suicideCount 79 +regularMoveCount 2116 +passCount 376 +resignCount 443 +koCaptureCount 19 +suicideCount 65 )%%"; expect("Board undo test move counts",out,expected); @@ -1968,7 +1983,7 @@ suicideCount 79 void Tests::runBoardHandicapTest() { cout << "Running board handicap test" << endl; { - Board board = Board(19,19); + Board board = Board(19,19,Rules::DEFAULT_GO); Player nextPla = P_BLACK; Rules rules = Rules::parseRules("chinese"); BoardHistory hist(board,nextPla,rules,0); @@ -1990,9 +2005,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("chinese"); + const Rules rules = Rules::parseRules("chinese"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2009,9 +2024,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("aga"); + const Rules rules = Rules::parseRules("aga"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2028,9 +2043,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("aga"); + const Rules rules = Rules::parseRules("aga"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2059,9 +2074,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("chinese"); + const Rules rules = Rules::parseRules("chinese"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2094,18 +2109,20 @@ void Tests::runBoardStressTest() { { Rand rand("runBoardStressTests"); static const int numBoards = 4; - Board boards[numBoards]; - boards[0] = Board(); - boards[1] = Board(9,16); - boards[2] = Board(13,7); - boards[3] = Board(4,4); + vector boards; + Rules rules = Rules::DEFAULT_GO; + boards.emplace_back(Board::DEFAULT_LEN_X, Board::DEFAULT_LEN_Y, rules); + boards.emplace_back(9,16,rules); + boards.emplace_back(13,7,rules); + boards.emplace_back(4,4,rules); bool multiStoneSuicideLegal[4] = {false,false,true,false}; - Board copies[numBoards]; + vector copies; Player pla = C_BLACK; int suicideCount = 0; int koBanCount = 0; int koCaptureCount = 0; int passCount = 0; + int resignCount = 0; int regularMoveCount = 0; for(int n = 0; n < 20000; n++) { Loc locs[numBoards]; @@ -2132,13 +2149,14 @@ void Tests::runBoardStressTest() { } } + copies.clear(); for(int i = 0; i placements; for(int i = 0; i<1000; i++) { Loc loc = Location::getLoc(rand.nextUInt(board.x_size),rand.nextUInt(board.y_size),board.x_size); Player pla = rand.nextBool(0.5) ? P_BLACK : P_WHITE; - if(board.isLegal(loc,pla,true)) { + if(board.isLegal(loc, pla, true, false)) { placements.push_back(Move(loc,pla)); bool anyCaps = board.wouldBeCapture(loc,pla) || board.isSuicide(loc,pla); board.playMoveAssumeLegal(loc,pla); - Board copy(board.x_size,board.y_size); + Board copy(board.x_size,board.y_size,board.rules); bool suc = copy.setStonesFailIfNoLibs(placements); testAssert(suc == !anyCaps); copy.checkConsistency(); @@ -2368,7 +2391,7 @@ Caps 4420 4335 { Rand rand("runBoardSetStoneTests3"); for(int rep = 0; rep<1000; rep++) { - Board board(1 + rand.nextUInt(18), 1 + rand.nextUInt(18)); + Board board(1 + rand.nextUInt(18), 1 + rand.nextUInt(18),Rules::DEFAULT_GO); for(int i = 0; i<300; i++) { Loc loc = Location::getLoc(rand.nextUInt(board.x_size),rand.nextUInt(board.y_size),board.x_size); Color color = rand.nextBool(0.25) ? C_EMPTY : rand.nextBool(0.5) ? P_BLACK : P_WHITE; @@ -2384,7 +2407,7 @@ Caps 4420 4335 placements.push_back(Move(loc,color)); if(prevPlacedLocs.find(loc) != prevPlacedLocs.end()) { - Board copy(board.x_size,board.y_size); + Board copy(board.x_size,board.y_size,board.rules); bool suc = copy.setStonesFailIfNoLibs(placements); testAssert(!suc); placements.pop_back(); @@ -2507,9 +2530,9 @@ oxxxxx.xo //if(rep < 100) // hist.printDebugInfo(cout,board); - testAssert(boardCopy.isEqualForTesting(board, true, true)); - testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0), true, true)); - testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0), true, true)); + testAssert(boardCopy.isEqualForTesting(board)); + testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0))); + testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0))); testAssert(BoardHistory::getSituationRulesAndKoHash(boardCopy,histCopy,pla,drawEquivalentWinsForWhite) == hist.getSituationRulesAndKoHash(board,hist,pla,drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite)); diff --git a/cpp/tests/testbook.cpp b/cpp/tests/testbook.cpp index d3b2bf6b4..deaa60b09 100644 --- a/cpp/tests/testbook.cpp +++ b/cpp/tests/testbook.cpp @@ -27,8 +27,8 @@ static void corruptNodes(Book* book, std::vector& newAndChangedNode void Tests::runBookTests() { cout << "Running book tests" << endl; - Board initialBoard(4,4); Rules rules = Rules::parseRules("japanese"); + Board initialBoard(4,4,rules); Player initialPla = P_BLACK; int repBound = 9; diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp new file mode 100644 index 000000000..50553c4dd --- /dev/null +++ b/cpp/tests/testdotsbasic.cpp @@ -0,0 +1,861 @@ +#include "../tests/tests.h" +#include "../tests/testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace TestCommon; + +static void checkDotsField(const string& description, const string& input, + const std::function& check, + const bool suicide = Rules::DEFAULT_DOTS.multiStoneSuicideLegal, + const bool captureEmptyBases = Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, + const bool freeCapturedDots = Rules::DEFAULT_DOTS.dotsFreeCapturedDots) { + cout << " " << description << endl; + + auto moveRecords = vector(); + + Board initialBoard = parseDotsField(input, false, suicide, captureEmptyBases, freeCapturedDots, {}); + + Board board = Board(initialBoard); + + BoardWithMoveRecords boardWithMoveRecords = BoardWithMoveRecords(board, moveRecords); + check(boardWithMoveRecords); + + while (!moveRecords.empty()) { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } + testAssert(initialBoard.isEqualForTesting(board)); +} + +void Tests::runDotsFieldTests() { + cout << "Running dots basic tests: " << endl; + + checkDotsField("Simple capturing", + R"( +.x. +xox +... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsField("Capturing with empty loc inside", + R"( +.oo. +ox.. +.oo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(boardWithMoveRecords.isLegal(2, 1, P_BLACK)); + testAssert(boardWithMoveRecords.isLegal(2, 1, P_WHITE)); + + boardWithMoveRecords.playMove(3, 1, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(!boardWithMoveRecords.isLegal(2, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isLegal(2, 1, P_WHITE)); +}); + + checkDotsField("Triple capture", + R"( +.x.x. +xo.ox +.xox. +..x.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 1, P_BLACK); + testAssert(3 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsField("Base inside base inside base", + R"( +.xxxxxxx. +x..ooo..x +x.o.x.o.x +x.oxoxo.x +x.o...o.x +x..o.o..x +.xxx.xxx. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(4, 4, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 5, P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(4 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 6, P_BLACK); + testAssert(13 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +}); + + /*checkDotsField("Base inside base inside base don't free captured dots", + R"( +.xxxxxxxxx.. +x..oooooo.x. +x.o.xx...o.x +x.oxo.xo.o.x +x.o.x.o..o.x +x..o....ox.x +x...o.oo...x +.xxxx.xxxxx. +)", true, false, [](const BoardWithMoveRecords& boardWithMoveRecords) { +boardWithMoveRecords.playMove(5, 4, P_BLACK); +testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + +boardWithMoveRecords.playMove(5, 6, P_WHITE); +testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); // Don't free the captured dot +testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); // Ignore owned color dots + +boardWithMoveRecords.playMove(5, 7, P_BLACK); +testAssert(21 == boardWithMoveRecords.board.numWhiteCaptures); // Don't count already counted dots +testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); // Don't free the captured dot +});*/ + + checkDotsField("Empty bases and suicide", + R"( +.x..o. +x.xo.o +.x..o. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + // Suicide move is not capture + testAssert(!boardWithMoveRecords.wouldBeCapture(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.wouldBeCapture(1, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.wouldBeCapture(4, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.wouldBeCapture(4, 1, P_BLACK)); + + testAssert(boardWithMoveRecords.isSuicide(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_BLACK)); + boardWithMoveRecords.playMove(1, 1, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(boardWithMoveRecords.isSuicide(4, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_WHITE)); + boardWithMoveRecords.playMove(4, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); +}); + + checkDotsField("Empty bases when they are allowed", + R"( +.x..o. +x.xo.o +...... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + boardWithMoveRecords.playMove(4, 2, P_WHITE); + + // Suicide is not possible in this mode + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_WHITE)); + + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +}, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, true, Rules::DEFAULT_DOTS.dotsFreeCapturedDots); + + checkDotsField("Capture wins suicide", + R"( +.xo. +xo.o +.xo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(!boardWithMoveRecords.isSuicide(2, 1, P_BLACK)); + boardWithMoveRecords.playMove(2, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsField("Single dot doesn't break searching inside empty base", + R"( +.oooo. +o....o +o.o..o +o....o +.oooo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(4, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + }); + + checkDotsField("Ignored already surrounded territory", + R"( +..xxx... +.x...x.. +x..x..x. +x.x.x..x +x..x..x. +.x...x.. +..xxx... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playMove(6, 3, P_WHITE); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsField("Invalidation of empty base locations", + R"( +.oox. +o..ox +.oox. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 1, P_BLACK); + boardWithMoveRecords.playMove(1, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsField("Invalidation of empty base locations ignoring borders", + R"( +..xxx.... +.x...x... +x..x..xo. +x.x.x..xo +x..x..xo. +.x...x... +..xxx.... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(6, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(1, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(3, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsField("Dangling dots removing", + R"( +.xx.xx. +x..xo.x +x.x.x.x +x..x..x +.x...x. +..x.x.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 5, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(!boardWithMoveRecords.isLegal(3, 2, P_BLACK)); + testAssert(!boardWithMoveRecords.isLegal(3, 2, P_WHITE)); + }); + + checkDotsField("Recalculate square during dangling dots removing", + R"( +.ooo.. +o...o. +o.o..o +..xo.o +o.o..o +o...o. +.ooo.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 3, P_BLACK); + testAssert(2 == boardWithMoveRecords.board.numBlackCaptures); + }); + + checkDotsField("Base sorting by size", + R"( +..xxx.. +.x...x. +x..x..x +x.xox.x +x.....x +.xx.xx. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 4, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playMove(4, 1, P_WHITE); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsField("Number of legal moves", + R"( +.... +.... +.... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { +testAssert(12 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); +}); + + checkDotsField("No legal moves", + R"( +xxxx +xo.x +xx.x +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); + }, true); + + checkDotsField("Don't rely on legal moves number if suicide is enabled", + R"( +xxxx +xo.x +xx.x +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + // However, if suicide is disallowed, the value can be easily calculated iteratively + // So, it should be initialized by max number of moves + testAssert(12 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); +}, false); + + checkDotsField("Suicidal move is also legal move", +R"( +xxx +x.x +xxx +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(1 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); + boardWithMoveRecords.playMove(1, 1, P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); +}, true); +} + +void Tests::runDotsGroundingTests() { + cout << "Running dots grounding tests:" << endl; + + checkDotsField("Grounding propagation", +R"( +.x.. +o.o. +.x.. +.xo. +..x. +.... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(2 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(3 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + // Dot adjacent to WALL is already grounded + testAssert(isGrounded(boardWithMoveRecords.getState(1, 0))); + + // Ignore enemy's dots + testAssert(isGrounded(boardWithMoveRecords.getState(0, 1))); + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 1))); + + // Not yet grounded + testAssert(!isGrounded(boardWithMoveRecords.getState(1, 2))); + testAssert(!isGrounded(boardWithMoveRecords.getState(1, 3))); + + boardWithMoveRecords.playMove(1, 1, P_BLACK); + + testAssert(2 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(1 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(isGrounded(boardWithMoveRecords.getState(1, 1))); + + // Check grounding propagation + testAssert(isGrounded(boardWithMoveRecords.getState(1, 2))); + testAssert(isGrounded(boardWithMoveRecords.getState(1, 3))); + // Diagonal connection is not actual + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 4))); + + // Ignore enemy's dots + testAssert(isGrounded(boardWithMoveRecords.getState(0, 1))); + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 1))); + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 3))); +} + ); + + checkDotsField("Grounding propagation with empty base", + R"( +..x.. +.x.x. +.x.x. +..x.. +..... +)", + [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(0 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(5 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(!isGrounded(boardWithMoveRecords.getState(1, 2))); + testAssert(!isGrounded(boardWithMoveRecords.getState(3, 2))); + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 3))); + + boardWithMoveRecords.playMove(2, 2, P_WHITE); + + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(-1 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(isGrounded(boardWithMoveRecords.getState(2, 2))); + + testAssert(isGrounded(boardWithMoveRecords.getState(1, 2))); + testAssert(isGrounded(boardWithMoveRecords.getState(3, 2))); + testAssert(isGrounded(boardWithMoveRecords.getState(2, 3))); + }); + + checkDotsField("Grounding score with grounded base", +R"( +.x. +xox +... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(-1 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +} +); + + checkDotsField("Grounding score with ungrounded base", +R"( +..... +..o.. +.oxo. +..... +..... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 3, P_WHITE); + + testAssert(4 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(1 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +} +); + + checkDotsField("Grounding score with grounded and ungrounded bases", +R"( +.x..... +xox.o.. +...oxo. +....... +....... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + boardWithMoveRecords.playMove(4, 3, P_WHITE); + + testAssert(5 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(0 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +} +); + + checkDotsField("Grounding draw with ungrounded bases", +R"( +......... +..x...o.. +.xox.oxo. +......... +......... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 3, P_BLACK); + boardWithMoveRecords.playMove(6, 3, P_WHITE); + + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(5 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(5 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +} +); + + + checkDotsField("Grounding of real and empty adjacent bases", +R"( +..x.. +..x.. +.xox. +..... +.x.x. +..x.. +..... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(5 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 2))); + + boardWithMoveRecords.playMove(2, 3, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(2 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + // Real base becomes grounded + testAssert(isGrounded(boardWithMoveRecords.getState(2, 2))); + testAssert(isGrounded(boardWithMoveRecords.getState(2, 3))); + + // Grounding does not affect an empty location + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 4))); + // Grounding does not affect empty surrounding + testAssert(!isGrounded(boardWithMoveRecords.getState(3, 4))); +} +); + + checkDotsField("Grounding of real base when it touches grounded", +R"( +..x.. +..x.. +..... +.xox. +..x.. +..... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(3 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 3))); + testAssert(!isGrounded(boardWithMoveRecords.getState(2, 4))); + + boardWithMoveRecords.playMove(2, 2, P_BLACK); + + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(-1 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + testAssert(isGrounded(boardWithMoveRecords.getState(2, 3))); + testAssert(isGrounded(boardWithMoveRecords.getState(2, 4))); +} +); + + checkDotsField("Base inside base inside base and grounding score", +R"( +....... +..ooo.. +.o.x.o. +.oxoxo. +.o...o. +..o.o.. +....... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(12 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(3 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + boardWithMoveRecords.playMove(3, 4, P_BLACK); + + testAssert(12 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(4 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + boardWithMoveRecords.playMove(3, 5, P_WHITE); + + testAssert(13 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(4 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + boardWithMoveRecords.playMove(3, 6, P_WHITE); + + testAssert(-4 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(4 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +}); + + checkDotsField("Ground empty territory in case of dangling dots removing", +R"( +......... +..xxx.... +.x....x.. +.x.xx..x. +.x.x.x.x. +.x.xxx.x. +.x..xo.x. +..xxxxx.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(!isGrounded(boardWithMoveRecords.getState(4, 4))); + + boardWithMoveRecords.playMove(5, 1, P_BLACK); + boardWithMoveRecords.playGroundingMove(P_BLACK); + + // TODO: it should be grounded, however currently it's not possible to set the state correctly due to limitation of grounding algorithm. + //testAssert(isGrounded(boardWithMoveRecords.getState(4, 4))); +}); + + checkDotsField("Simple", + R"( +..... +.xxo. +..... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + + testAssert(2 == boardWithMoveRecords.board.numBlackCaptures); + + testAssert(1 == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + testAssert(boardWithMoveRecords.getWhiteScore() == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(2 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + testAssert(boardWithMoveRecords.getBlackScore() == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + + boardWithMoveRecords.undo(); + } +); + + checkDotsField("Draw", +R"( +.x... +.xxo. +...o. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(boardWithMoveRecords.getWhiteScore() == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getBlackScore() == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + boardWithMoveRecords.undo(); +} +); + + checkDotsField("Bases", +R"( +......... +..xx...x. +.xo.x.xox +..x...... +......... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 3, P_BLACK); + boardWithMoveRecords.playMove(7, 3, P_BLACK); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getWhiteScore() == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); +} +); + + checkDotsField("Multiple groups", +R"( +...... +xxo..o +.ox... +x...oo +...o.. +...... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getWhiteScore() == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(3 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getBlackScore() == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + boardWithMoveRecords.undo(); +} +); + + checkDotsField("Invalidate empty territory", +R"( +...... +..oo.. +.o..o. +..oo.. +...... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + Board board = boardWithMoveRecords.board; + + State state = boardWithMoveRecords.board.getState(Location::getLoc(2, 2, board.x_size)); + testAssert(C_WHITE == getEmptyTerritoryColor(state)); + + state = boardWithMoveRecords.board.getState(Location::getLoc(3, 2, board.x_size)); + testAssert(C_WHITE == getEmptyTerritoryColor(state)); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(6 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getBlackScore() == boardWithMoveRecords.board.blackScoreIfWhiteGrounds); + + state = boardWithMoveRecords.board.getState(Location::getLoc(2, 2, board.x_size)); + testAssert(C_EMPTY == getEmptyTerritoryColor(state)); + + state = boardWithMoveRecords.board.getState(Location::getLoc(3, 2, board.x_size)); + testAssert(C_EMPTY == getEmptyTerritoryColor(state)); +} +); + + checkDotsField("Don't invalidate empty territory for strong connection", +R"( +.x. +x.x +.x. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + const Board board = boardWithMoveRecords.board; + + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(boardWithMoveRecords.getWhiteScore() == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); + + State state = boardWithMoveRecords.board.getState(Location::getLoc(1, 1, board.x_size)); + testAssert(C_BLACK == getEmptyTerritoryColor(state)); + + state = boardWithMoveRecords.board.getState(Location::getLoc(0, 0, board.x_size)); + testAssert(C_EMPTY == getEmptyTerritoryColor(state)); +} +); +} + +void Tests::runDotsBoardHistoryGroundingTests() { + { + const Board board = parseDotsFieldDefault(R"( +.... +.xo. +.ox. +.... +)"); + const auto boardHistory = BoardHistory(board); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, false)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, false)); + + // No draw because there are some ungrounded dots + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.xo. +.xo. +.ox. +.ox. +)"); + const auto boardHistory = BoardHistory(board); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, false)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, false)); + + // Effective draw because all dots are grounded + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.x.... +xox... +....o. +...oxo +...... +)", {XYMove(1, 2, P_BLACK), XYMove(4, 4, P_WHITE)}); + const auto boardHistory = BoardHistory(board); + + // Also effective draw because all bases are grounded + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.x.... +xox.x. +...... +....o. +.o.oxo +...... +)", {XYMove(1, 2, P_BLACK), XYMove(4, 5, P_WHITE)}); + const auto boardHistory = BoardHistory(board); + + // No effective draw because there are ungrounded dots + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + Board board = parseDotsFieldDefault(R"( +..... +..o.. +.oxo. +..... +)"); + board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_WHITE); + testAssert(1 == board.numBlackCaptures); + const auto boardHistory = BoardHistory(board); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); + testAssert(1.0f == boardHistory.whiteScoreIfGroundingAlive(board)); + } + + { + Board board = parseDotsFieldDefault(R"( +..... +..x.. +.xox. +..... +)"); + board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_BLACK); + testAssert(1 == board.numWhiteCaptures); + const auto boardHistory = BoardHistory(board); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); + testAssert(-1.0f == boardHistory.whiteScoreIfGroundingAlive(board)); + } + + { + Board board = parseDotsFieldDefault(R"( +..... +..x.. +.xox. +..... +..... +)"); + board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_BLACK); + testAssert(1 == board.numWhiteCaptures); + const auto boardHistory = BoardHistory(board); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); + testAssert(std::isnan(boardHistory.whiteScoreIfGroundingAlive(board))); + } +} + +void Tests::runDotsPosHashTests() { + cout << "Running dots pos hash tests" << endl; + + { + Board dotsFieldWithEmptyBase = parseDotsFieldDefault(R"( +.x. +x.x +.x. +)", { XYMove(1, 1, P_WHITE) }); + + Board dotsFieldWithRealBase = parseDotsFieldDefault(R"( +.x. +xox +... +)", { XYMove(1, 2, P_BLACK) }); + + testAssert(dotsFieldWithEmptyBase.pos_hash == dotsFieldWithRealBase.pos_hash); + } + + { + Board dotsFieldWithSurrounding = parseDotsFieldDefault(R"( +..xxxxxx.. +.x......x. +x..x..o..x +x.xoxoxo.x +x........x +.x......x. +..xxx.xx.. +)", { XYMove(3, 4, P_BLACK), XYMove(6, 4, P_WHITE), XYMove(5, 6, P_BLACK) }); + testAssert(5 == dotsFieldWithSurrounding.numWhiteCaptures); + testAssert(0 == dotsFieldWithSurrounding.numBlackCaptures); + + Board dotsFieldWithErasedTerritory = parseDotsFieldDefault(R"( +..xxxxxx.. +.xxxxxxxx. +xxxxxxxxxx +xxxxxxxxxx +xxxxxxxxxx +.xxxxxxxx. +..xxxxxx.. +)"); + + testAssert(dotsFieldWithErasedTerritory.pos_hash == dotsFieldWithSurrounding.pos_hash); + } +} \ No newline at end of file diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp new file mode 100644 index 000000000..2be862836 --- /dev/null +++ b/cpp/tests/testdotsextra.cpp @@ -0,0 +1,464 @@ +#include "../tests/tests.h" +#include "testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace TestCommon; + +static void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardInput, const vector& extraMoves, const int symmetry) { + const Board transformedBoard = SymmetryHelpers::getSymBoard(initBoard, symmetry); + Board expectedBoard = parseDotsFieldDefault(expectedSymmetryBoardInput); + for (const XYMove& extraMove : extraMoves) { + expectedBoard.playMoveAssumeLegal(SymmetryHelpers::getSymLoc(extraMove.x, extraMove.y, initBoard, symmetry), extraMove.player); + } + expect(SymmetryHelpers::symmetryToString(symmetry).c_str(), Board::toStringSimple(transformedBoard), Board::toStringSimple(expectedBoard)); + testAssert(transformedBoard.isEqualForTesting(expectedBoard)); +} + +void Tests::runDotsSymmetryTests() { + cout << "Running dots symmetry tests" << endl; + + Board initialBoard = parseDotsFieldDefault(R"( +...ox +..ox. +.o.ox +.xo.. +)"); + initialBoard.playMoveAssumeLegal(Location::getLoc(4, 1, initialBoard.x_size), P_WHITE); + testAssert(1 == initialBoard.numBlackCaptures); + + checkSymmetry(initialBoard, R"( +...ox +..ox. +.o.ox +.xo.. +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_NONE); + + checkSymmetry(initialBoard, R"( +.xo.. +.o.ox +..ox. +...ox +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_FLIP_Y); + + checkSymmetry(initialBoard, R"( +xo... +.xo.. +xo.o. +..ox. +)", +{ XYMove(4, 1, P_WHITE)}, + SymmetryHelpers::SYMMETRY_FLIP_X); + + checkSymmetry(initialBoard, R"( +..ox. +xo.o. +.xo.. +xo... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_FLIP_Y_X); + + checkSymmetry(initialBoard, R"( +.... +..ox +.o.o +oxo. +x.x. +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE); + + checkSymmetry(initialBoard, R"( +.... +xo.. +o.o. +.oxo +.x.x +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_X); + + checkSymmetry(initialBoard, R"( +x.x. +oxo. +.o.o +..ox +.... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y); + + checkSymmetry(initialBoard, R"( +.x.x +.oxo +o.o. +xo.. +.... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); + + cout << "Check dots symmetry with start pos" << endl; + const auto originalRules = Rules(Rules::DEFAULT_DOTS.startPos, false, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots); + auto board = Board(5, 4, originalRules); + board.setStartPos(DOTS_RANDOM); + board.playMoveAssumeLegal(Location::getLoc(1, 2, board.x_size), P_BLACK); + + const auto rotatedBoard = SymmetryHelpers::getSymBoard(board, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_X); + + auto rulesAfterTransformation = originalRules; + rulesAfterTransformation.startPosIsRandom = true; + auto expectedBoard = Board(4, 5, rulesAfterTransformation); + expectedBoard.setStonesFailIfNoLibs({ + Move(Location::getLoc(2, 2, expectedBoard.x_size), P_BLACK), + Move(Location::getLoc(2, 3, expectedBoard.x_size), P_WHITE), + Move(Location::getLoc(1, 3, expectedBoard.x_size), P_BLACK), + Move(Location::getLoc(1, 2, expectedBoard.x_size), P_WHITE), + }, true); + expectedBoard.playMoveAssumeLegal(Location::getLoc(1, 1, expectedBoard.x_size), P_BLACK); + + expect("Dots symmetry with start pos", Board::toStringSimple(rotatedBoard), Board::toStringSimple(expectedBoard)); + testAssert(rotatedBoard.isEqualForTesting(expectedBoard)); + + const auto unrotatedBoard = SymmetryHelpers::getSymBoard(rotatedBoard, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y); + testAssert(board.isEqualForTesting(unrotatedBoard)); +} + +static string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { + const Board board = parseDotsFieldDefault(boardData, extraMoves); + + Color result[Board::MAX_ARR_SIZE]; + const int whiteScore = board.calculateOwnershipAndWhiteScore(result, groundingPlayer); + testAssert(expectedWhiteScore == whiteScore); + + std::ostringstream oss; + + for (int y = 0; y < board.y_size; y++) { + for (int x = 0; x < board.x_size; x++) { + const Loc loc = Location::getLoc(x, y, board.x_size); + oss << PlayerIO::colorToChar(result[loc]); + } + oss << endl; + } + + return oss.str(); +} + +static void expect( + const char* name, + const Color groundingPlayer, + const std::string& actualField, + const std::string& expectedOwnership, + const int expectedWhiteScore, + const vector& extraMoves = {} +) { + cout << " " << name << ", Grounding Player: " << PlayerIO::colorToChar(groundingPlayer) << endl; + expect(name, getOwnership(actualField, groundingPlayer, expectedWhiteScore, extraMoves), expectedOwnership); +} + +void Tests::runDotsOwnershipTests() { + expect("Start Cross", C_EMPTY, R"( +...... +...... +..ox.. +..xo.. +...... +...... +)", + R"( +...... +...... +...... +...... +...... +...... +)", 0); + + expect("Wins by a base", C_EMPTY, R"( +...... +...... +..ox.. +.oxo.. +...... +...... +)", +R"( +...... +...... +...... +..O... +...... +...... +)", 1, {XYMove(2, 4, P_WHITE)}); + + expect("Loss by grounding", C_BLACK, R"( +..o... +..o... +..ox.. +..xo.. +...o.. +...o.. +)", +R"( +...... +...... +...O.. +..O... +...... +...... +)", 2); + + expect("Loss by grounding", C_WHITE, R"( +...x.. +...x.. +..ox.. +..xo.. +..x... +..x... +)", +R"( +...... +...... +..X... +...X.. +...... +...... +)", -2); + + expect("Wins by grounding with an ungrounded dot", C_WHITE, R"( +...... +.oox.. +.xxo.. +.oo... +....o. +...... +)", +R"( +...... +...... +.OO... +...... +....X. +...... +)", 1, {XYMove(0, 2, P_WHITE)}); +} + +static std::pair getCapturingAndBases( + const string& boardData, + const bool suicide, + const bool captureEmptyBases, + const vector& extraMoves +) { + const Board board = parseDotsField(boardData, false, suicide, captureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); + + const Board& copy(board); + + vector captures; + vector bases; + copy.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + + std::ostringstream capturesStringStream; + std::ostringstream basesStringStream; + + for (int y = 0; y < copy.y_size; y++) { + for (int x = 0; x < copy.x_size; x++) { + const Loc loc = Location::getLoc(x, y, copy.x_size); + const Color captureColor = captures[loc]; + if (captureColor == C_WALL) { + capturesStringStream << PlayerIO::colorToChar(P_BLACK) << PlayerIO::colorToChar(P_WHITE); + } else { + capturesStringStream << PlayerIO::colorToChar(captureColor) << " "; + } + + Color baseColor = bases[loc]; + if (baseColor == C_WALL) { + basesStringStream << PlayerIO::colorToChar(P_BLACK) << PlayerIO::colorToChar(P_WHITE); + } else { + basesStringStream << PlayerIO::colorToChar(baseColor) << " "; + } + + if (x < copy.x_size - 1) { + capturesStringStream << " "; + basesStringStream << " "; + } + } + capturesStringStream << endl; + basesStringStream << endl; + } + + // Make sure we didn't change an internal state during calculating + testAssert(board.isEqualForTesting(copy)); + + return {capturesStringStream.str(), basesStringStream.str()}; +} + +static void checkCapturingAndBase( + const string& title, + const string& boardData, + const string& expectedCaptures, + const string& expectedBases, + const bool suicide = Rules::DEFAULT_DOTS.multiStoneSuicideLegal, + const bool captureEmptyBases = Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, + const vector& extraMoves = {} +) { + auto [capturing, bases] = getCapturingAndBases(boardData, suicide, captureEmptyBases, extraMoves); + cout << (" " + title + ": capturing").c_str() << endl; + expect("", capturing, expectedCaptures); + cout << (" " + title + ": bases").c_str() << endl; + expect("", bases, expectedBases); +} + +void Tests::runDotsCapturingTests() { + cout << "Running dots capturing tests" << endl; + + checkCapturingAndBase( + "Two bases", + R"( +.x...o. +xox.oxo +....... +)", R"( +. . . . . . . +. . . . . . . +. X . . . O . +)", + R"( +. . . . . . . +. X . . . O . +. . . . . . . +)" +); + + checkCapturingAndBase( + "Overlapping capturing location", + R"( +.x. +xox +... +oxo +.o. +)", R"( +. . . +. . . +. XO . +. . . +. . . +)", + R"( +. . . +. X . +. . . +. O . +. . . +)" +); + + checkCapturingAndBase( + "Empty base", + R"( +.x. +x.x +.x. +)", R"( +. . . +. . . +. . . +)", +R"( +. . . +. X . +. . . +)" +); + + checkCapturingAndBase( +"Empty base can be broken", +R"( +.xx. +x..x +x.x. +oxo. +.o.. +)", R"( +. . . . +. . . . +. O . . +. . . . +. . . . +)", +R"( +. . . . +. X X . +. X . . +. O . . +. . . . +)" +); + + checkCapturingAndBase( +"No empty base capturing", +R"( +.x. +x.x +... +)", R"( +. . . +. . . +. . . +)", +R"( +. . . +. . . +. . . +)", Rules::DEFAULT_DOTS.multiStoneSuicideLegal, false +); + + checkCapturingAndBase( +"Empty base capturing", +R"( +.x. +x.x +... +)", R"( +. . . +. . . +. X . +)", +R"( +. . . +. X . +. . . +)", Rules::DEFAULT_DOTS.multiStoneSuicideLegal, true +); + + checkCapturingAndBase( + "Complex example with overlapping of capturing and bases", + R"( +.ooxx. +o.xo.x +ox.ox. +ox.ox. +.o.x.. +)", R"( +. . . . . . +. . . . . . +. . . . . . +. . XO . . . +. . XO . . . +)", + R"( +. . . . . . +. O O X X . +. O XO X . . +. O XO X . . +. . . . . . +)" +); +} \ No newline at end of file diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp new file mode 100644 index 000000000..84c0553c4 --- /dev/null +++ b/cpp/tests/testdotsstartposes.cpp @@ -0,0 +1,408 @@ +#include + +#include "../tests/tests.h" +#include "testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace std::chrono; +using namespace TestCommon; + +static void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startPosIsRandom, const Board& board) { + std::ostringstream sgfStringStream; + const BoardHistory boardHistory(board, P_BLACK, board.rules, 0); + WriteSgf::writeSgf(sgfStringStream, "black", "white", boardHistory, {}); + const string sgfString = sgfStringStream.str(); + cout << "; Sgf: " << sgfString << endl; + + const auto deserializedSgf = Sgf::parse(sgfString); + const Rules newRules = deserializedSgf->getRulesOrFail(); + testAssert(startPos == newRules.startPos); + testAssert(startPosIsRandom == newRules.startPosIsRandom); +} + +static void checkStartPos(const string& description, const int startPos, const bool startPosIsRandom, const int x_size, const int y_size, const string& expectedBoard = "", const vector& extraMoves = {}) { + cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; + + auto board = Board(x_size, y_size, Rules(startPos, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + board.setStartPos(DOTS_RANDOM); + playXYMovesAssumeLegal(board, extraMoves); + + std::ostringstream oss; + Board::printBoard(oss, board, Board::NULL_LOC, nullptr, false); + + if (!expectedBoard.empty()) { + expect(description.c_str(), oss, expectedBoard); + } + + writeToSgfAndCheckStartPosFromSgfProp(startPos, startPosIsRandom, board); +} + +static void checkRecognition(const vector& xyMoves, const int x_size, const int y_size, + const int expectedStartPos, + const vector& expectedStartMoves, + const bool expectedRandomized, + const vector& expectedRemainingMoves) { + + auto moves = vector(); + moves.reserve(xyMoves.size()); + for (const auto xyMove : xyMoves) { + moves.push_back(xyMove.toMove(x_size)); + } + + vector actualStartPosMoves; + bool actualRandomized; + vector actualRemainingMoves; + + testAssert(expectedStartPos == Rules::recognizeStartPos(moves, x_size, y_size, actualStartPosMoves, actualRandomized, &actualRemainingMoves)); + + testAssert(expectedStartMoves.size() == actualStartPosMoves.size()); + for (size_t i = 0; i < expectedStartMoves.size(); ++i) { + testAssert(movesEqual(expectedStartMoves[i].toMove(x_size), actualStartPosMoves[i])); + } + + testAssert(expectedRandomized == actualRandomized); + + testAssert(expectedRemainingMoves.size() == actualRemainingMoves.size()); + for (size_t i = 0; i < expectedRemainingMoves.size(); ++i) { + testAssert(movesEqual(expectedRemainingMoves[i].toMove(x_size), actualRemainingMoves[i])); + } +} + +static void checkStartPosRecognition(const string& description, const int expectedStartPos, const bool startPosIsRandom, const string& inputBoard) { + const Board board = parseDotsField(inputBoard, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, {}); + + cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; + + writeToSgfAndCheckStartPosFromSgfProp(expectedStartPos, startPosIsRandom, board); +} + +static void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { + const auto generatedMoves = Rules::generateStartPos(startPos, startPosIsRandom ? &DOTS_RANDOM : nullptr, 39, 32); + vector actualStartPosMoves; + bool actualRandomized; + testAssert(startPos == Rules::recognizeStartPos(generatedMoves, 39, 32, actualStartPosMoves, actualRandomized)); + // We can't reliably check in case of randomization is not detected because random generator can + // generate static poses in rare cases. + if (actualRandomized) { + testAssert(startPosIsRandom); + } +} + +void Tests::runDotsStartPosTests() { + cout << "Running dots start pos tests" << endl; + + Rand rand("runDotsStartPosTests"); + + checkStartPos("Cross on minimal size", Rules::START_POS_CROSS, false, 2, 2, R"( + 1 2 + 2 X O + 1 O X +)"); + + checkStartPos("Extra dots with cross (for instance, a handicap game)", Rules::START_POS_CROSS, false, 4, 4, R"( + 1 2 3 4 + 4 . . . . + 3 . X O . + 2 . O X . + 1 . . X . +)", {XYMove(2, 3, P_BLACK)}); + + checkStartPosRecognition("Empty start pos with three extra moves", Rules::START_POS_EMPTY, false, R"( +.... +.xo. +.o.. +.... +)"); + + checkStartPosRecognition("Reversed cross should be recognized as random", Rules::START_POS_CROSS, true, R"( +.... +.ox. +.xo. +.... +)"); + + checkStartPos("Cross on odd size", Rules::START_POS_CROSS, false, 3, 3, R"( + 1 2 3 + 3 . X O + 2 . O X + 1 . . . +)"); + + checkStartPos("Cross on standard size", Rules::START_POS_CROSS, false, 39, 32, R"( + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +21 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . . X O . . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . . O X . . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +11 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); + + checkStartPos("Double cross on minimal size", Rules::START_POS_CROSS_2, false, 4, 2, R"( + 1 2 3 4 + 2 X O O X + 1 O X X O +)"); + + checkStartPos("Double cross on odd size", Rules::START_POS_CROSS_2, false, 5, 3, R"( + 1 2 3 4 5 + 3 . X O O X + 2 . O X X O + 1 . . . . . +)"); + + checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 6, 4, R"( + 1 2 3 4 5 6 + 4 . . . . . . + 3 . X O O X . + 2 . O X X O . + 1 . . . . . . +)"); + + checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 7, 4, R"( + 1 2 3 4 5 6 7 + 4 . . . . . . . + 3 . . X O O X . + 2 . . O X X O . + 1 . . . . . . . +)"); + + checkStartPos("Double rand cross on triple cross", Rules::START_POS_CROSS_2, false, 10, 4, R"( + 1 2 3 4 5 6 7 8 9 10 + 4 . . . . . . . . . . + 3 . . . X O O X . X O + 2 . . . O X X O . O X + 1 . . . . . . . . . . +)", {XYMove(8, 1, P_BLACK), XYMove(9, 1, P_WHITE), XYMove(9, 2, P_BLACK), XYMove(8, 2, P_WHITE)}); + + const vector expectedRemainingMovesForDoubleCross = { + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }; + + // Double cross exactly matches static double cross (not randomized) + checkRecognition({ + XYMove(3, 1, P_BLACK), + XYMove(4, 1, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(3, 2, P_WHITE), + + XYMove(5, 1, P_WHITE), + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }, 10, 4, Rules::START_POS_CROSS_2, { + XYMove(3, 1, P_BLACK), + XYMove(4, 1, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(3, 2, P_WHITE), + + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + XYMove(5, 1, P_WHITE), + }, false, expectedRemainingMovesForDoubleCross +); + + // Double cross partially matches static double cross (randomized) + checkRecognition({ + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(5, 1, P_WHITE), + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) +}, 10, 4, Rules::START_POS_CROSS_2, +{ + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + XYMove(5, 1, P_WHITE), + + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), +}, true, expectedRemainingMovesForDoubleCross); + + // Double cross do not match static cross completely (randomized) + checkRecognition({ + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(4, 1, P_WHITE), + XYMove(5, 1, P_BLACK), + XYMove(5, 2, P_WHITE), + XYMove(4, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }, 10, 4, Rules::START_POS_CROSS_2, { + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(5, 1, P_BLACK), + XYMove(5, 2, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(4, 1, P_WHITE) + }, true, expectedRemainingMovesForDoubleCross); + + checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, false, 39, 32, R"( + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +21 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . X O O X . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . O X X O . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +11 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 5, 5, R"( + 1 2 3 4 5 + 5 X O . X O + 4 O X . O X + 3 . . . . . + 2 X O . X O + 1 O X . O X +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 7, 7, R"( + 1 2 3 4 5 6 7 + 7 . . . . . . . + 6 . X O . X O . + 5 . O X . O X . + 4 . . . . . . . + 3 . X O . X O . + 2 . O X . O X . + 1 . . . . . . . +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 8, 8, R"( + 1 2 3 4 5 6 7 8 + 8 . . . . . . . . + 7 . X O . . X O . + 6 . O X . . O X . + 5 . . . . . . . . + 4 . . . . . . . . + 3 . X O . . X O . + 2 . O X . . O X . + 1 . . . . . . . . +)"); + + checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, false, 39, 32, R"( + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . X O . . . . . . . . . . . . . X O . . . . . . . . . . . +21 . . . . . . . . . . . O X . . . . . . . . . . . . . O X . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . X O . . . . . . . . . . . . . X O . . . . . . . . . . . +11 . . . . . . . . . . . O X . . . . . . . . . . . . . O X . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); + + checkStartPos("Random quadruple cross on standard size", Rules::START_POS_CROSS_4, true, 39, 32); + + checkGenerationAndRecognition(Rules::START_POS_CROSS_4, false); + checkGenerationAndRecognition(Rules::START_POS_CROSS_4, true); +} \ No newline at end of file diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp new file mode 100644 index 000000000..c6da9a6df --- /dev/null +++ b/cpp/tests/testdotsstress.cpp @@ -0,0 +1,336 @@ +#include + +#include "../tests/tests.h" +#include "../tests/testdotsutils.h" + +using namespace std; +using namespace std::chrono; +using namespace TestCommon; + +static string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { + Board boardCopy(initialBoard); + BoardHistory boardHistory(boardCopy, P_BLACK, boardCopy.rules, 0); + for (const Board::MoveRecord& moveRecord : moveRecords) { + boardHistory.makeBoardMoveAssumeLegal(boardCopy, moveRecord.loc, moveRecord.pla, nullptr); + } + std::ostringstream sgfStringStream; + WriteSgf::writeSgf(sgfStringStream, "blue", "red", boardHistory, {}); + return sgfStringStream.str(); +} + +/** + * Calculates the grounding and result captures without using the grounding flag and incremental calculations. + * It's used for testing to verify incremental grounding algorithms. + */ +static void validateGrounding( + const Board& boardBeforeGrounding, + const Board& boardAfterGrounding, + const Player pla, + const vector& moveRecords) { + unordered_set visited_locs; + assert(pla == P_BLACK || pla == P_WHITE); + + int expectedNumBlackCaptures = 0; + int expectedNumWhiteCaptures = 0; + const Player opp = getOpp(pla); + for (int y = 0; y < boardBeforeGrounding.y_size; y++) { + for (int x = 0; x < boardBeforeGrounding.x_size; x++) { + Loc loc = Location::getLoc(x, y, boardBeforeGrounding.x_size); + const State state = boardBeforeGrounding.getState(Location::getLoc(x, y, boardBeforeGrounding.x_size)); + + if (const Color activeColor = getActiveColor(state); activeColor == pla) { + if (visited_locs.count(loc) > 0) + continue; + + bool grounded = false; + + vector walkStack; + vector baseLocs; + walkStack.push_back(loc); + + // Find active territory and calculate its grounding state. + while (!walkStack.empty()) { + Loc curLoc = walkStack.back(); + walkStack.pop_back(); + + if (const Color curActiveColor = getActiveColor(boardBeforeGrounding.getState(curLoc)); curActiveColor == pla) { + if (visited_locs.count(curLoc) == 0) { + visited_locs.insert(curLoc); + baseLocs.push_back(curLoc); + boardBeforeGrounding.forEachAdjacent(curLoc, [&](const Loc& adjLoc) { + walkStack.push_back(adjLoc); + }); + } + } else if (curActiveColor == C_WALL) { + grounded = true; + } + } + + for (const Loc& baseLoc : baseLocs) { + const Color placedDotColor = getPlacedDotColor(boardBeforeGrounding.getState(baseLoc)); + + if (!grounded) { + // If the territory is not grounded, it becomes dead. + // Freed dots don't count because they don't add a value to the opp score (assume they just become be placed). + if (placedDotColor == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } else { + State baseLocState = boardAfterGrounding.getState(baseLoc); + // This check on placed dot color is redundant. + // However, currently it's not possible to always ground empty locs in some rare cases due to limitations of incremental grounding algorithm. + // Fortunately, they don't affect the resulting score. + if (!isGrounded(baseLocState) && getPlacedDotColor(baseLocState) != C_EMPTY) { + Global::fatalError("Loc (" + to_string(Location::getX(baseLoc, boardBeforeGrounding.x_size)) + "; " + + to_string(Location::getY(baseLoc, boardBeforeGrounding.x_size)) + ") " + + " should be grounded. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } + + // If the territory is grounded, count dead dots of the opp player. + if (placedDotColor == opp) { + if (pla == P_BLACK) { + expectedNumWhiteCaptures++; + } else { + expectedNumBlackCaptures++; + } + } + } + } + } else if (activeColor == opp) { // In the case of opp active color, counts only captured dots + if (getPlacedDotColor(state) == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } + } + } + + if (expectedNumBlackCaptures != boardAfterGrounding.numBlackCaptures || expectedNumWhiteCaptures != boardAfterGrounding.numWhiteCaptures) { + Global::fatalError("expectedNumBlackCaptures (" + to_string(expectedNumBlackCaptures) + ")" + + " == board.numBlackCaptures (" + to_string(boardAfterGrounding.numBlackCaptures) + ")" + + " && expectedNumWhiteCaptures (" + to_string(expectedNumWhiteCaptures) + ")" + + " == board.numWhiteCaptures (" + to_string(boardAfterGrounding.numWhiteCaptures) + ")" + + " check is failed. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } +} + +static void validateStatesAndCaptures(const Board& board, const vector& moveRecords) { + int expectedNumBlackCaptures = 0; + int expectedNumWhiteCaptures = 0; + int expectedPlacedDotsCount = -board.rules.getNumOfStartPosStones(); + + for (int y = 0; y < board.y_size; y++) { + for (int x = 0; x < board.x_size; x++) { + const State state = board.getState(Location::getLoc(x, y, board.x_size)); + const Color activeColor = getActiveColor(state); + const Color placedDotColor = getPlacedDotColor(state); + [[maybe_unused]] const Color emptyTerritoryColor = getEmptyTerritoryColor(state); + + if (placedDotColor != C_EMPTY) { + expectedPlacedDotsCount++; + } + + if (activeColor == C_BLACK) { + assert(C_EMPTY == emptyTerritoryColor); + if (placedDotColor == C_WHITE) { + expectedNumWhiteCaptures++; + } + } else if (activeColor == C_WHITE) { + assert(C_EMPTY == emptyTerritoryColor); + if (placedDotColor == C_BLACK) { + expectedNumBlackCaptures++; + } + } else { + assert(placedDotColor == C_EMPTY); + //assert(!isTerritory(state)); + } + } + } + + const int actualPlacedDotsCount = moveRecords.size() - (moveRecords.back().loc == Board::PASS_LOC ? 1 : 0); + testAssert(expectedPlacedDotsCount == actualPlacedDotsCount); + testAssert(expectedNumBlackCaptures == board.numBlackCaptures); + testAssert(expectedNumWhiteCaptures == board.numWhiteCaptures); +} + +static void runDotsStressTestsInternal( + int x_size, + int y_size, + int gamesCount, + bool dotsGame, + int startPos, + bool startPosIsRandom, + bool dotsCaptureEmptyBase, + float komi, + bool suicideAllowed, + float groundingStartCoef, + float groundingEndCoef, + bool performExtraChecks) { + assert(groundingStartCoef >= 0 && groundingStartCoef <= 1); + assert(groundingEndCoef >= 0 && groundingEndCoef <= 1); + assert(groundingEndCoef >= groundingStartCoef); + + cout << " Random games" << endl; + cout << " Game type: " << (dotsGame ? "Dots" : "Go") << endl; + cout << " Start position: " << Rules::writeStartPosRule(startPos) << endl; + cout << " Start position is random: " << boolalpha << startPosIsRandom << endl; + if (dotsGame) { + cout << " Capture empty bases: " << boolalpha << dotsCaptureEmptyBase << endl; + } + cout << " Extra checks: " << boolalpha << performExtraChecks << endl; +#ifdef NDEBUG + cout << " Build: Release" << endl; +#else + cout << " Build: Debug" << endl; +#endif + cout << " Size: " << x_size << ":" << y_size << endl; + cout << " Komi: " << komi << endl; + cout << " Suicide: " << boolalpha << suicideAllowed << endl; + cout << " Games count: " << gamesCount << endl; + + const auto start = high_resolution_clock::now(); + + Rand rand("runDotsStressTests"); + + Rules rules = dotsGame + ? Rules(startPos, startPosIsRandom, suicideAllowed, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) + : Rules(); + int numLegalMoves = x_size * y_size - rules.getNumOfStartPosStones(); + + vector randomMoves = vector(); + randomMoves.reserve(numLegalMoves); + + for(int y = 0; y < y_size; y++) { + for(int x = 0; x < x_size; x++) { + randomMoves.push_back(Location::getLoc(x, y, x_size)); + } + } + + int movesCount = 0; + int blackWinsCount = 0; + int whiteWinsCount = 0; + int drawsCount = 0; + int groundingCount = 0; + + auto moveRecords = vector(); + + for (int n = 0; n < gamesCount; n++) { + rand.shuffle(randomMoves); + moveRecords.clear(); + + auto initialBoard = Board(x_size, y_size, rules); + initialBoard.setStartPos(DOTS_RANDOM); + auto board = initialBoard; + + Loc lastLoc = Board::NULL_LOC; + + int tryGroundingAfterMove = static_cast((groundingStartCoef + static_cast(rand.nextDouble()) * (groundingEndCoef - groundingStartCoef)) * static_cast(numLegalMoves)); + Player pla = P_BLACK; + int currentGameMovesCount = 0; + for(short randomMove : randomMoves) { + lastLoc = currentGameMovesCount >= tryGroundingAfterMove ? Board::PASS_LOC : randomMove; + + if (board.isLegal(lastLoc, pla, suicideAllowed, false)) { + if (performExtraChecks) { + Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); + moveRecords.push_back(moveRecord); + } else { + board.playMoveAssumeLegal(lastLoc, pla); + } + currentGameMovesCount++; + pla = getOpp(pla); + } + + if (lastLoc == Board::PASS_LOC) { + groundingCount++; + int scoreDiff; + int oppScoreIfGrounding; + if (Player lastPla = getOpp(pla); lastPla == P_BLACK) { + scoreDiff = board.numBlackCaptures - board.numWhiteCaptures; + oppScoreIfGrounding = board.whiteScoreIfBlackGrounds; + } else { + scoreDiff = board.numWhiteCaptures - board.numBlackCaptures; + oppScoreIfGrounding = board.blackScoreIfWhiteGrounds; + } + if (scoreDiff != oppScoreIfGrounding) { + Global::fatalError("scoreDiff (" + to_string(scoreDiff) + ") == oppScoreIfGrounding (" + to_string(oppScoreIfGrounding) + ") check is failed. " + + "Sgf: " + moveRecordsToSgf(initialBoard, moveRecords)); + } + break; + } + } + + if (performExtraChecks) { + if (lastLoc == Board::PASS_LOC) { + Board boardBeforeGrounding(board); + boardBeforeGrounding.undo(moveRecords.back()); + validateGrounding(boardBeforeGrounding, board, moveRecords.back().pla, moveRecords); + } + validateStatesAndCaptures(board, moveRecords); + } + + if (dotsGame && suicideAllowed && lastLoc != Board::PASS_LOC) { + testAssert(0 == board.numLegalMovesIfSuiAllowed); + } + + movesCount += currentGameMovesCount; + if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > Global::FLOAT_EPS) { + whiteWinsCount++; + } else if (whiteScore < -Global::FLOAT_EPS) { + blackWinsCount++; + } else { + drawsCount++; + } + + if (performExtraChecks) { + while (!moveRecords.empty()) { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } + + testAssert(initialBoard.isEqualForTesting(board)); + } + } + + const auto end = high_resolution_clock::now(); + auto durationNs = duration_cast(end - start); + + cout.precision(4); + cout << " Elapsed time: " << duration_cast(durationNs).count() << " ms" << endl; + cout << " Number of games per second: " << static_cast(static_cast(gamesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per second: " << static_cast(static_cast(movesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per game: " << static_cast(static_cast(movesCount) / gamesCount) << endl; + cout << " Time per game: " << static_cast(durationNs.count()) / gamesCount / 1000000 << " ms" << endl; + cout << " Black wins: " << blackWinsCount << " (" << static_cast(blackWinsCount) / gamesCount << ")" << endl; + cout << " White wins: " << whiteWinsCount << " (" << static_cast(whiteWinsCount) / gamesCount << ")" << endl; + cout << " Draws: " << drawsCount << " (" << static_cast(drawsCount) / gamesCount << ")" << endl; + cout << " Groundings: " << groundingCount << " (" << static_cast(groundingCount) / gamesCount << ")" << endl; +} + +void Tests::runDotsStressTests() { + cout << "Running dots stress tests" << endl; + + cout << " Max territory" << endl; + auto board = Board(39, 32, Rules(Rules::START_POS_EMPTY, false, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + for(int y = 0; y < board.y_size; y++) { + for(int x = 0; x < board.x_size; x++) { + const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; + board.playMoveAssumeLegal(Location::getLoc(x, y, board.x_size), pla); + } + } + testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); + testAssert(0 == board.numLegalMovesIfSuiAllowed); + + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, false, 0.0f, true, 0.8f, 1.0f, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, true, 0.5f, false, 0.8f, 1.0f, true); + + runDotsStressTestsInternal(39, 32, 50000, true, Rules::START_POS_CROSS, false, false, 0.0f, true, 0.8f, 1.0f, false); + runDotsStressTestsInternal(39, 32, 50000, true, Rules::START_POS_CROSS_4, true, false, 0.0f, true, 0.8f, 1.0f, false); +} \ No newline at end of file diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp new file mode 100644 index 000000000..f92686338 --- /dev/null +++ b/cpp/tests/testdotsutils.cpp @@ -0,0 +1,43 @@ +#include "testdotsutils.h" + +using namespace std; + +Move XYMove::toMove(const int x_size) const { + return Move(Location::getLoc(x, y, x_size), player); +} + +Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { + return parseDotsField(input, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); +} + +Board parseDotsField(const string& input, const bool startPosIsRandom, const bool suicide, const bool captureEmptyBases, + const bool freeCapturedDots, const vector& extraMoves) { + int currentXSize = 0; + int xSize = -1; + int ySize = 0; + for(int i = 0; i <= input.length(); i++) { + if(i == input.length() - 1 || input[i] == '\n') { + if(i > 0) { + if(xSize != -1) { + assert(xSize == currentXSize); + } else { + xSize = currentXSize; + } + currentXSize = 0; + ySize++; + } + } else { + currentXSize++; + } + } + + Board result = Board::parseBoard(xSize, ySize, input, Rules(Rules::START_POS_EMPTY, startPosIsRandom, suicide, captureEmptyBases, freeCapturedDots)); + playXYMovesAssumeLegal(result, extraMoves); + return result; +} + +void playXYMovesAssumeLegal(Board& board, const vector& moves) { + for(const XYMove& move : moves) { + board.playMoveAssumeLegal(Location::getLoc(move.x, move.y, board.x_size), move.player); + } +} \ No newline at end of file diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h new file mode 100644 index 000000000..169e1b3b5 --- /dev/null +++ b/cpp/tests/testdotsutils.h @@ -0,0 +1,72 @@ +#pragma once + +#include "../program/playutils.h" + +using namespace std; + +inline Rand DOTS_RANDOM("DOTS_RANDOM"); + +struct XYMove { + int x; + int y; + Player player; + + XYMove(const int newX, const int newY, const Player newPlayer) : x(newX), y(newY), player(newPlayer) {} + + [[nodiscard]] std::string toString() const { + return "(" + to_string(x) + "," + to_string(y) + "," + PlayerIO::colorToChar(player) + ")"; + } + + [[nodiscard]] Move toMove(int x_size) const; +}; + +struct BoardWithMoveRecords { + Board& board; + vector& moveRecords; + + BoardWithMoveRecords(Board& initBoard, vector& initMoveRecords) : board(initBoard), moveRecords(initMoveRecords) {} + + void playMove(const int x, const int y, const Player player) const { + moveRecords.push_back(board.playMoveRecorded(Location::getLoc(x, y, board.x_size), player)); + } + + void playGroundingMove(const Player player) const { + moveRecords.push_back(board.playMoveRecorded(Board::PASS_LOC, player)); + } + + [[nodiscard]] State getState(const int x, const int y) const { + return board.getState(Location::getLoc(x, y, board.x_size)); + } + + [[nodiscard]] bool isLegal(const int x, const int y, const Player player) const { + return board.isLegal(Location::getLoc(x, y, board.x_size), player, true, false); + } + + [[nodiscard]] bool isSuicide(const int x, const int y, const Player player) const { + return board.isSuicide(Location::getLoc(x, y, board.x_size), player); + } + + [[nodiscard]] bool wouldBeCapture(const int x, const int y, const Player player) const { + return board.wouldBeCapture(Location::getLoc(x, y, board.x_size), player); + } + + int getWhiteScore() const { + return board.numBlackCaptures - board.numWhiteCaptures; + } + + int getBlackScore() const { + return -getWhiteScore(); + } + + void undo() const { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } +}; + +Board parseDotsFieldDefault(const string& input, const vector& extraMoves = {}); + +Board parseDotsField(const string& input, bool startPosIsRandom, bool suicide, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); + +void playXYMovesAssumeLegal(Board& board, const vector& moves); + diff --git a/cpp/tests/testnnevalcanary.cpp b/cpp/tests/testnnevalcanary.cpp index f6c07a0bc..d207f6876 100644 --- a/cpp/tests/testnnevalcanary.cpp +++ b/cpp/tests/testnnevalcanary.cpp @@ -44,12 +44,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[pd];W[pp];B[dd];W[dp];B[qn];W[nq];B[cq];W[dq];B[cp];W[do];B[bn];W[cc];B[cd];W[dc];B[ec];W[eb];B[fb];W[fc];B[ed];W[gb];B[db];W[fa];B[cb];W[qo];B[pn];W[nc];B[qj];W[qc];B[qd];W[pc];B[od];W[nd];B[ne];W[me];B[mf];W[nf])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 18; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -79,12 +77,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[pd];W[pp];B[dd];W[dp];B[qn];W[nq];B[cq];W[dq];B[cp];W[do];B[bn];W[cc];B[cd];W[dc];B[ec];W[eb];B[fb];W[fc];B[ed];W[gb];B[db];W[fa];B[cb];W[qo];B[pn];W[nc];B[qj];W[qc];B[qd];W[pc];B[od];W[nd];B[ne];W[me];B[mf];W[nf])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 36; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -113,12 +109,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -148,12 +142,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); hist.setKomi(-7); MiscNNInputParams nnInputParams; @@ -180,12 +172,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); hist.setKomi(21); MiscNNInputParams nnInputParams; @@ -212,12 +202,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;FF[4]GM[1]CA[UTF-8]RU[Japanese]KM[6]SZ[16:11];B[md];W[nh];B[dh];W[cd];B[lh];W[li];B[ki])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 7; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -470,8 +458,8 @@ static std::shared_ptr nnOutputOfJson(const std::string& s) { nnOutput->policyOptimismUsed = input["policyOptimismUsed"].get(); nnOutput->nnXLen = input["nnXLen"].get(); nnOutput->nnYLen = input["nnYLen"].get(); - testAssert(nnOutput->nnXLen >= 2 && nnOutput->nnXLen <= NNPos::MAX_BOARD_LEN); - testAssert(nnOutput->nnYLen >= 2 && nnOutput->nnYLen <= NNPos::MAX_BOARD_LEN); + testAssert(nnOutput->nnXLen >= 2 && nnOutput->nnXLen <= NNPos::MAX_BOARD_LEN_X); + testAssert(nnOutput->nnYLen >= 2 && nnOutput->nnYLen <= NNPos::MAX_BOARD_LEN_Y); std::vector whiteOwnerMap = input["whiteOwnerMap"].get>(); testAssert(whiteOwnerMap.size() == nnOutput->nnXLen*nnOutput->nnYLen); nnOutput->whiteOwnerMap = new float[nnOutput->nnXLen*nnOutput->nnYLen]; diff --git a/cpp/tests/testnninputs.cpp b/cpp/tests/testnninputs.cpp index 1855d8889..a77e99246 100644 --- a/cpp/tests/testnninputs.cpp +++ b/cpp/tests/testnninputs.cpp @@ -14,20 +14,7 @@ static void printNNInputHWAndBoard( ostream& out, int inputsVersion, const Board& board, const BoardHistory& hist, int nnXLen, int nnYLen, bool inputsUseNHWC, T* row, int c ) { - int numFeatures; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V3; - else if(inputsVersion == 4) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V4; - else if(inputsVersion == 5) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V5; - else if(inputsVersion == 6) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V6; - else if(inputsVersion == 7) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V7; - else - testAssert(false); + const int numFeatures = NNInputs::getNumberOfSpatialFeatures(inputsVersion, false); out << "Channel: " << c << endl; @@ -67,20 +54,7 @@ static void printNNInputHWAndBoard( template static void printNNInputGlobal(ostream& out, int inputsVersion, T* row, int c) { - int numFeatures; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V3; - else if(inputsVersion == 4) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V4; - else if(inputsVersion == 5) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V5; - else if(inputsVersion == 6) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V6; - else if(inputsVersion == 7) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V7; - else - testAssert(false); + const int numFeatures = NNInputs::getNumberOfGlobalFeatures(inputsVersion, false); (void)numFeatures; out << "Channel: " << c; @@ -117,6 +91,21 @@ static double finalScoreIfGameEndedNow(const BoardHistory& baseHist, const Board //================================================================================================================== //================================================================================================================== +static Hash128 fillRowAndGetHash( + const int version, + const Board& board, const BoardHistory& hist, + const Player nextPla, + const MiscNNInputParams& nnInputParams, + const int nnXLen, + const int nnYLen, + const bool inputsUseNHWC, + float* rowBin, + float* rowGlobal + ) { + const Hash128 hash = NNInputs::getHash(board,hist,nextPla,nnInputParams); + NNInputs::fillRowVN(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + return hash; +} void Tests::runNNInputsV3V4Tests() { cout << "Running NN inputs V3V4V5V6 tests" << endl; @@ -124,59 +113,10 @@ void Tests::runNNInputsV3V4Tests() { out << std::setprecision(5); auto allocateRows = [](int version, int nnXLen, int nnYLen, int& numFeaturesBin, int& numFeaturesGlobal, float*& rowBin, float*& rowGlobal) { - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(version == 3) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V3; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V3; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V3 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V3]; - } - else if(version == 4) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V4; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V4; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V4 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V4]; - } - else if(version == 5) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V5; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V5; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V5 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V5]; - } - else if(version == 6) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V6; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V6; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V6 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V6]; - } - else if(version == 7) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V7; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V7; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V7 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V7]; - } - else - testAssert(false); - }; - - auto fillRows = [](int version, Hash128& hash, - Board& board, const BoardHistory& hist, Player nextPla, MiscNNInputParams nnInputParams, int nnXLen, int nnYLen, bool inputsUseNHWC, - float* rowBin, float* rowGlobal) { - hash = NNInputs::getHash(board,hist,nextPla,nnInputParams); - - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(version == 3) - NNInputs::fillRowV3(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 4) - NNInputs::fillRowV4(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 5) - NNInputs::fillRowV5(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 6) - NNInputs::fillRowV6(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 7) - NNInputs::fillRowV7(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else - testAssert(false); + numFeaturesBin = NNInputs::getNumberOfSpatialFeatures(version, false); + numFeaturesGlobal = NNInputs::getNumberOfGlobalFeatures(version, false); + rowBin = new float[numFeaturesBin * nnXLen * nnYLen]; + rowGlobal = new float[numFeaturesGlobal]; }; static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); @@ -195,12 +135,11 @@ void Tests::runNNInputsV3V4Tests() { for(int version = minVersion; version <= maxVersion; version++) { cout << "VERSION " << version << endl; - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = Rules::getTrompTaylorish(); initialRules = sgf->getRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; i= 6; size--) { - Board board = Board(size,size); Player nextPla = P_BLACK; vector rules = { @@ -595,12 +527,13 @@ xxx..xx for(int c = 0; cgetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; int nnXLen = 6; int nnYLen = 6; @@ -664,10 +596,9 @@ xxx..xx } out << endl; bool inputsUseNHWC = true; - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); out << "Pass Hist Channels: "; for(int c = 0; c<5; c++) out << rowGlobal[c] << " "; @@ -709,12 +640,11 @@ xxx..xx if(version == 5) continue; cout << "VERSION " << version << endl; - Board board; Player nextPla; - BoardHistory hist; Rules initialRules; initialRules = sgf->getRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; int nnXLen = 13; @@ -735,10 +665,9 @@ xxx..xx if(i == 163) { out << "Move " << i << endl; bool inputsUseNHWC = true; - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,18); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,19); } @@ -793,10 +722,9 @@ o.xoo.x auto run = [&](bool inputsUseNHWC) { Player nextPla = hist.moveHistory.size() > 0 ? getOpp(hist.moveHistory[hist.moveHistory.size()-1].pla) : hist.initialPla; - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); out << hash << endl; printNNInputGlobal(out,version,rowGlobal,5); int c = 18; @@ -887,16 +815,15 @@ o.xoo.x float* rowGlobal; allocateRows(version,nnXLen,nnYLen,numFeaturesBin,numFeaturesGlobal,rowBin,rowGlobal); - Board board = Board(9,1); Player nextPla = P_BLACK; Rules initialRules = Rules::getSimpleTerritory(); + Board board = Board(9,1,initialRules); BoardHistory hist(board,nextPla,initialRules,0); auto run = [&](bool inputsUseNHWC) { - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,9); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,10); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,11); @@ -960,20 +887,19 @@ o.xoo.x auto test = [&](const Board& board, const BoardHistory& hist, Player nextPla) { bool inputsUseNHWC = true; - Hash128 hash; Board b = board; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); for(int c = 0; cgetRulesOrFailAllowUnspecified(rulesToUse); - sgf->setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; int nnXLen = 9; int nnYLen = 9; @@ -1426,12 +1351,11 @@ ooxooxo for(size_t i = 0; i sgf = CompactSgf::parse(sgfStr); - Board board; - BoardHistory hist; Player nextPla = P_BLACK; int turnIdxToSetup = (int)sgf->moves.size(); Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdxToSetup); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdxToSetup); string expected = R"%%( HASH: EB867913318513FD9DE98EDE86AE8CE0 A B C D E F G H J K L M @@ -5272,20 +5270,49 @@ Last moves pass pass pass pass H7 G9 F9 H7 Rand baseRand(name); constexpr int numBoards = 6; - Board boards[numBoards] = { - Board(2,2), - Board(5,1), - Board(6,1), - Board(2,3), - Board(4,2), - }; for(int i = 0; i getMultiGameSize9Data(); diff --git a/cpp/tests/testscore.cpp b/cpp/tests/testscore.cpp index d5432a186..fd2ce0aeb 100644 --- a/cpp/tests/testscore.cpp +++ b/cpp/tests/testscore.cpp @@ -150,7 +150,7 @@ xxxxx for(int b = 0; b<5; b++) { int xSizes[5] = {9, 13, 13, 13, 19}; int ySizes[5] = {9, 9, 13, 19, 19}; - Board board(xSizes[b],ySizes[b]); + Board board(xSizes[b],ySizes[b],Rules::DEFAULT_GO); cout << "center " << center << " scale " << scale << " x " << xSizes[b] << " y " << ySizes[b] << endl; for(int stdev = 0; stdev <= 5; stdev++) { for(double d = -8.0; d<=8.0; d += 0.5) { diff --git a/cpp/tests/testsearch.cpp b/cpp/tests/testsearch.cpp index 1c86d2351..4778dd207 100644 --- a/cpp/tests/testsearch.cpp +++ b/cpp/tests/testsearch.cpp @@ -184,7 +184,7 @@ void Tests::runSearchTests(const string& modelFile, bool inputsNHWC, bool useNHW const bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runBasicPositions(nnEval, logger); delete nnEval; diff --git a/cpp/tests/testsearchcommon.cpp b/cpp/tests/testsearchcommon.cpp index 8a9dbc551..a32013063 100644 --- a/cpp/tests/testsearchcommon.cpp +++ b/cpp/tests/testsearchcommon.cpp @@ -180,11 +180,9 @@ void TestSearchCommon::runBotOnPosition(AsyncBot* bot, Board board, Player nextP void TestSearchCommon::runBotOnSgf(AsyncBot* bot, const string& sgfStr, const Rules& defaultRules, int turnIdx, float overrideKomi, TestSearchOptions opts) { std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(defaultRules); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx);; hist.setKomi(overrideKomi); runBotOnPosition(bot,board,nextPla,hist,opts); } @@ -235,7 +233,8 @@ NNEvaluator* TestSearchCommon::startNNEval( gpuIdxByServerThread, nnRandSeed, nnRandomize, - defaultSymmetry + defaultSymmetry, + false // TODO: Fix for Dots game ); nnEval->spawnServerThreads(); diff --git a/cpp/tests/testsearchmisc.cpp b/cpp/tests/testsearchmisc.cpp index 584cf26f2..fbcd98106 100644 --- a/cpp/tests/testsearchmisc.cpp +++ b/cpp/tests/testsearchmisc.cpp @@ -128,11 +128,9 @@ void Tests::runNNOnManyPoses(const string& modelFile, bool inputsNHWC, bool useN vector policyProbs; for(int turnIdx = 0; turnIdxmoves.size(); turnIdx++) { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); nnEval->evaluate(board,hist,nextPla,nnInputParams,buf,skipCache,includeOwnerMap); winProbs.push_back(buf.result->whiteWinProb); @@ -204,9 +202,7 @@ void Tests::runNNBatchingTest(const string& modelFile, bool inputsNHWC, bool use Rand rand("runNNBatchingTest"); std::unique_ptr sgf = CompactSgf::parse(sgfStr); for(int turnIdx = 0; turnIdxmoves.size(); turnIdx++) { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules; initialRules.koRule = rand.nextBool(0.5) ? Rules::KO_SIMPLE : rand.nextBool(0.5) ? Rules::KO_POSITIONAL : Rules::KO_SITUATIONAL; initialRules.scoringRule = rand.nextBool(0.5) ? Rules::SCORING_AREA : Rules::SCORING_TERRITORY; @@ -215,7 +211,7 @@ void Tests::runNNBatchingTest(const string& modelFile, bool inputsNHWC, bool use initialRules.hasButton = initialRules.scoringRule == Rules::SCORING_AREA && rand.nextBool(0.5); initialRules.whiteHandicapBonusRule = rand.nextBool(0.5) ? Rules::WHB_ZERO : rand.nextBool(0.5) ? Rules::WHB_N : Rules::WHB_N_MINUS_ONE; initialRules.komi = 7.5f + rand.nextInt(-10,10) * 0.5f; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); items.push_back(NNBatchingTestItem(board,hist,nextPla)); } }; diff --git a/cpp/tests/testsearchnonn.cpp b/cpp/tests/testsearchnonn.cpp index 862bce25c..1aee45db5 100644 --- a/cpp/tests/testsearchnonn.cpp +++ b/cpp/tests/testsearchnonn.cpp @@ -36,7 +36,7 @@ void Tests::runNNLessSearchTests() { cout << "Basic search with debugSkipNeuralNet and chosen move randomization" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 100; Search* search = new Search(params, nnEval, &logger, "autoSearchRandSeed"); @@ -119,7 +119,7 @@ void Tests::runNNLessSearchTests() { cout << "Testing preservation of search tree across moves" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 100; params.cpuctExploration *= 2; @@ -251,7 +251,7 @@ o..oo.x { cout << "First with no pruning" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; Search* search = new Search(params, nnEval, &logger, "autoSearchRandSeed3"); @@ -274,7 +274,7 @@ o..oo.x { cout << "Next, with rootPruneUselessMoves" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.rootPruneUselessMoves = true; @@ -310,7 +310,7 @@ o..oo.x { cout << "Searching on the opponent, the move before" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1b",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1b",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.rootPruneUselessMoves = true; @@ -382,7 +382,7 @@ o..o.oo { cout << "First with no pruning" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.dynamicScoreUtilityFactor = 0.5; @@ -898,7 +898,7 @@ xx......x Rand rand("noiseVisualize"); auto run = [&](int xSize, int ySize) { - Board board(xSize,ySize); + const Board board(xSize,ySize,Rules::DEFAULT_GO); int nnXLen = 19; int nnYLen = 19; float sum = 0.0; @@ -1601,7 +1601,7 @@ xxxxxxxxx cout << "Testing coherence of search tree recursive walking" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 1000; params.dynamicScoreUtilityFactor = 3.0; @@ -2534,9 +2534,10 @@ x.x.x std::map,int> boardSizeDistribution; for(int i = 0; i<100000; i++) { - Board board; + Rules rules = Rules::DEFAULT_GO; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); ExtraBlackAndKomi extraBlackAndKomi; OtherGameProperties otherGameProps; gameInit.createGame(board,pla,hist,extraBlackAndKomi,NULL,PlaySettings(),otherGameProps,NULL); diff --git a/cpp/tests/testsearchv3.cpp b/cpp/tests/testsearchv3.cpp index 97f933196..5817e8aa6 100644 --- a/cpp/tests/testsearchv3.cpp +++ b/cpp/tests/testsearchv3.cpp @@ -24,11 +24,9 @@ static void runOwnershipAndMisc(NNEvaluator* nnEval, NNEvaluator* nnEval11, NNEv string sgfStr = "(;FF[4]CA[UTF-8]KM[7.5];B[pp];W[pc];B[cd];W[dq];B[ed];W[pe];B[co];W[cp];B[do];W[fq];B[ck];W[qn];B[qo];W[pn];B[np];W[qj];B[jc];W[lc];B[je];W[lq];B[mq];W[lp];B[ek];W[qq];B[pq];W[ro];B[rp];W[qp];B[po];W[rq];B[rn];W[sp];B[rm];W[ql];B[on];W[om];B[nn];W[nm];B[mn];W[ip];B[mm])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 40); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 40); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -68,11 +66,9 @@ static void runOwnershipAndMisc(NNEvaluator* nnEval, NNEvaluator* nnEval11, NNEv string sgfStr = "(;FF[4]CA[UTF-8]SZ[11]KM[7.5];B[ci];W[ic];B[ih];W[hi];B[ii];W[ij];B[jj];W[gj];B[ik];W[di];B[hh];W[ch];B[dc];W[cc];B[cb];W[cd];B[eb];W[dd];B[ed];W[ee];B[fd];W[bb];B[ba];W[ab];B[gb];W[je];B[ib];W[jb];B[jc];W[jd];B[hc];W[id];B[dh];W[cg];B[dj];W[ei];B[bi];W[ia];B[hb];W[fg];B[hj];W[eh];B[ej])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 43); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 43); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -524,9 +520,9 @@ void Tests::runSearchTestsV3(const string& modelFile, bool inputsNHWC, bool useN const bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); NNEvaluator* nnEval11 = startNNEval(modelFile,logger,"",11,11,symmetry,inputsNHWC,useNHWC,useFP16,false,false); - NNEvaluator* nnEvalPTemp = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEvalPTemp = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runOwnershipAndMisc(nnEval,nnEval11,nnEvalPTemp,logger); delete nnEval; delete nnEval11; diff --git a/cpp/tests/testsearchv8.cpp b/cpp/tests/testsearchv8.cpp index 431ba103c..3f51ce991 100644 --- a/cpp/tests/testsearchv8.cpp +++ b/cpp/tests/testsearchv8.cpp @@ -22,11 +22,9 @@ static void runV8TestsSize9(NNEvaluator* nnEval, NNEvaluator* nnEval9, NNEvaluat string sgfStr = "(;FF[4]GM[1]SZ[9]HA[0]KM[7]RU[stonescoring];B[ef];W[ed];B[ge])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 3); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 3); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; @@ -58,11 +56,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[19]KM[6.5];B[dd];W[qd];B[pq];W[dp];B[oc];W[pe];B[fq];W[jp];B[ph];W[cf];B[ck])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 11); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 11); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; @@ -92,11 +88,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[19]KM[6.5];B[dd];W[qd];B[od];W[pq];B[dq];W[do];B[eo];W[oe])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams params = SearchParams::forTestsV1(); params.rootNumSymmetriesToSample = 8; @@ -127,11 +121,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[AGA]SZ[19]KM[7.0];B[dd];W[pd];B[dp];W[pp];B[qc];W[qd];B[pc];W[nc];B[nb])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams paramsA = SearchParams::forTestsV1(); SearchParams paramsB = SearchParams::forTestsV1(); @@ -528,11 +520,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) std::unique_ptr sgf = CompactSgf::parse(sgfStr); { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 24); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 24); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -543,11 +533,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 32); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 32); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -558,11 +546,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 124); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 124); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -583,11 +569,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) std::unique_ptr sgf = CompactSgf::parse(sgfStr); { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 29); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 29); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -598,11 +582,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 83); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 83); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -626,11 +608,9 @@ static void runMoreV8Tests(NNEvaluator* nnEval, Logger& logger) string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[9]KM[0];B[dc];W[ef];B[df];W[de];B[dg];W[eg];B[eh];W[fh];B[ee])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 20; diff --git a/cpp/tests/testsearchv9.cpp b/cpp/tests/testsearchv9.cpp index 2bc4533c4..57bb7cc12 100644 --- a/cpp/tests/testsearchv9.cpp +++ b/cpp/tests/testsearchv9.cpp @@ -554,9 +554,9 @@ oo...ooo { cout << "Single symmetry and full symmetry raw nets" << endl; - Board board(19,19); Player nextPla = P_BLACK; Rules rules = Rules::parseRules("Japanese"); + Board board(19,19,rules); rules.komi = -4; BoardHistory hist(board,nextPla,rules,0); hist.makeBoardMoveAssumeLegal(board,Location::ofString("A1",board),P_BLACK,NULL); @@ -762,7 +762,7 @@ void Tests::runSearchTestsV9(const string& modelFile, bool inputsNHWC, bool useN Logger logger(nullptr, logToStdout, logToStderr, logTime); int symmetry = 4; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runV9Positions(nnEval, logger); delete nnEval; diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index 4492fb00e..a5a658674 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -1,7 +1,6 @@ #include "../tests/tests.h" #include -#include #include "../dataio/sgf.h" #include "../search/asyncbot.h" @@ -16,23 +15,42 @@ void Tests::runSgfTests() { auto parseAndPrintSgfLinear = [&out](const string& sgfStr) { std::unique_ptr sgf = CompactSgf::parse(sgfStr); + if (sgf->isDots) { + out << "Dots game" << endl; + } out << "xSize " << sgf->xSize << endl; out << "ySize " << sgf->ySize << endl; out << "depth " << sgf->depth << endl; - out << "komi " << sgf->getRulesOrFailAllowUnspecified(Rules()).komi << endl; + const Rules rules = sgf->getRulesOrFailAllowUnspecified(Rules::getDefaultOrTrompTaylorish(sgf->isDots)); + out << "komi " << rules.komi << endl; - Board board; - BoardHistory hist; - Rules rules; Player pla; - rules = sgf->getRulesOrFailAllowUnspecified(rules); - sgf->setupInitialBoardAndHist(rules,board,pla,hist); - out << "placements" << endl; - for(int i = 0; i < sgf->placements.size(); i++) { - Move move = sgf->placements[i]; - out << PlayerIO::colorToChar(move.pla) << " " << Location::toString(move.loc,board) << endl; + const BoardHistory hist = sgf->setupInitialBoardAndHist(rules, pla); + const Board& board = hist.initialBoard; + + bool randomized; + vector startPosMoves; + vector remainingPlacementMoves; + const int recognizedStartPos = Rules::recognizeStartPos(sgf->placements, board.x_size, board.y_size, startPosMoves, randomized, &remainingPlacementMoves); + testAssert(recognizedStartPos == rules.startPos); + testAssert(randomized == rules.startPosIsRandom); + + if (recognizedStartPos != Rules::START_POS_EMPTY) { + out << "startPos " << Rules::writeStartPosRule(recognizedStartPos); + if (randomized) { + out << " (randomized)"; + } + out << endl; } + + if (!remainingPlacementMoves.empty()) { + out << "placements" << endl; + for (const auto placementMove : remainingPlacementMoves) { + out << PlayerIO::colorToChar(placementMove.pla) << " " << Location::toString(placementMove.loc, board) << endl; + } + } + out << "moves" << endl; for(int i = 0; i < sgf->moves.size(); i++) { Move move = sgf->moves[i]; @@ -40,29 +58,26 @@ void Tests::runSgfTests() { } out << "Initial board hist " << endl; - out << "pla " << PlayerIO::playerToString(pla) << endl; + out << "pla " << PlayerIO::playerToString(pla,rules.isDots) << endl; hist.printDebugInfo(out,board); - sgf->setupBoardAndHistAssumeLegal(rules,board,pla,hist,sgf->moves.size()); + auto [finalHist, finalBoard] = sgf->setupBoardAndHistAssumeLegal(rules, pla, sgf->moves.size()); out << "Final board hist " << endl; - out << "pla " << PlayerIO::playerToString(pla) << endl; - hist.printDebugInfo(out,board); + out << "pla " << PlayerIO::playerToString(pla,rules.isDots) << endl; + finalHist.printDebugInfo(out,finalBoard); { //Test SGF writing roundtrip. //This is not exactly holding if there is pass for ko, but should be good in all other cases ostringstream out2; - WriteSgf::writeSgf(out2,"foo","bar",hist,NULL,false,false); + WriteSgf::writeSgf(out2,"foo","bar",finalHist,NULL,false,false); std::unique_ptr sgf2 = CompactSgf::parse(out2.str()); - Board board2; - BoardHistory hist2; - Rules rules2; Player pla2; - rules2 = sgf2->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(rules2,board2,pla2,hist2,sgf2->moves.size()); + const Rules rules2 = sgf2->getRulesOrFail(); + auto [hist2, board2] = sgf->setupBoardAndHistAssumeLegal(rules2, pla2, sgf2->moves.size()); testAssert(rules2 == rules); - testAssert(board2.pos_hash == board.pos_hash); - testAssert(hist2.moveHistory.size() == hist.moveHistory.size()); + testAssert(board2.pos_hash == finalBoard.pos_hash); + testAssert(hist2.moveHistory.size() == finalHist.moveHistory.size()); } }; @@ -85,7 +100,7 @@ void Tests::runSgfTests() { } out << "handicapValue " << sgf->getHandicapValue() << endl; - out << "sgfWinner " << PlayerIO::playerToString(sgf->getSgfWinner()) << endl; + out << "sgfWinner " << PlayerIO::playerToString(sgf->getSgfWinner(), sgf->isDotsGame()) << endl; out << "firstPlayerColor " << PlayerIO::colorToChar(sgf->getFirstPlayerColor()) << endl; out << "black rank " << sgf->getRank(P_BLACK) << endl; @@ -119,7 +134,7 @@ void Tests::runSgfTests() { sgf->getPlacements(placements, xySize.x, xySize.y); out << "placements " << placements.size() << endl; for(const Move& move: placements) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y) << " "; + out << PlayerIO::playerToString(move.pla, sgf->isDotsGame()) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -127,7 +142,7 @@ void Tests::runSgfTests() { sgf->getMoves(moves, xySize.x, xySize.y); out << "moves " << moves.size() << endl; for(const Move& move: moves) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y) << " "; + out << PlayerIO::playerToString(move.pla, sgf->isDotsGame()) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -151,6 +166,79 @@ void Tests::runSgfTests() { ); }; + { + const char* name = "Basic Dots Sgf parse test"; + string sgfStr = "(;FF[4]GM[40]CA[UTF-8]AP[katago]SZ[10:8]AB[ed][fe][ef]AW[ee][fd];W[de];B[df];W[hd];B[ce];W[hf];B[cd];W[ff];B[dc];W[cf];B[hb];W[ic];B[db];W[gg];B[da];W[bg];B[])"; + + parseAndPrintSgfLinear(sgfStr); + string expected = R"( +Dots game +xSize 10 +ySize 8 +depth 17 +komi 0 +startPos CROSS +placements +X 5-3 +moves +O 4-4 +X 4-3 +O 8-5 +X 3-4 +O 8-3 +X 3-5 +O 6-3 +X 4-6 +O 3-3 +X 8-7 +O 9-6 +X 4-7 +O 7-2 +X 4-8 +O 2-2 +X ground +Initial board hist +pla Player2 +HASH: 42AC4303D65557034CC3593CB26EA615 + 1 2 3 4 5 6 7 8 9 10 + 8 . . . . . . . . . . + 7 . . . . . . . . . . + 6 . . . . . . . . . . + 5 . . . . X O . . . . + 4 . . . . O X . . . . + 3 . . . . X . . . . . + 2 . . . . . . . . . . + 1 . . . . . . . . . . + + +Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 +White bonus score 0 +Presumed next pla Player2 +Game result 0 Empty 0 0 0 0 +Last moves +Final board hist +pla Player2 +HASH: AB87C4395AA2D7E5D7B069ACBFA701D5 + 1 2 3 4 5 6 7 8 9 10 + 8 . . . X . . . . . . + 7 . . . X . . . x . . + 6 . . . X . . . . O . + 5 . . X ' X O . O . . + 4 . . X o o X . . . . + 3 . . O X X O . O . . + 2 . O . . . . O . . . + 1 . . . . . . . . . . + + +Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 +White bonus score 0 +Presumed next pla Player2 +Game result 1 Player1 -1 1 0 0 +Last moves 4-4 4-3 8-5 3-4 8-3 3-5 6-3 4-6 3-3 8-7 9-6 4-7 7-2 4-8 2-2 ground +)"; + expect(name,out,expected); + } + //============================================================================ { const char* name = "Basic Sgf parse test"; @@ -582,7 +670,6 @@ xSize 17 ySize 3 depth 5 komi -6.5 -placements moves X F1 O C1 @@ -753,7 +840,6 @@ xSize 5 ySize 5 depth 13 komi 24 -placements moves X C3 O C4 @@ -848,7 +934,6 @@ xSize 5 ySize 5 depth 7 komi 24 -placements moves X C3 X B4 @@ -920,7 +1005,7 @@ Last moves C3 B4 C4 D4 D3 C2 B3 //============================================================================ - if(Board::MAX_LEN >= 37) + if constexpr(std::min(Board::MAX_LEN_X, Board::MAX_LEN_Y) >= 37) { const char* name = "Giant Sgf parse test"; string sgfStr = "(;GM[1]FF[4]CA[UTF-8]ST[2]RU[Chinese]SZ[37]KM[0.00];B[dd];W[Hd];B[HH];W[dH];B[dG];W[eG];B[eF];W[Gd];B[Ge];W[He];B[ee];W[GG];B[ss])"; @@ -931,7 +1016,6 @@ xSize 37 ySize 37 depth 14 komi 0 -placements moves X D34 O AJ34 @@ -994,7 +1078,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koSIMPLEscoreAREAtaxNONEsui0whbNkomi0 +Rules koSIMPLEscoreAREAtaxNONEsui0whbNfpok1komi0 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -1051,7 +1135,7 @@ Encore phase 0 Turns this phase 13 Approx valid turns this phase 13 Approx consec valid turns this game 13 -Rules koSIMPLEscoreAREAtaxNONEsui0whbNkomi0 +Rules koSIMPLEscoreAREAtaxNONEsui0whbNfpok1komi0 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -1421,7 +1505,7 @@ void Tests::runSgfFileTests() { testAssert(sgf->getXYSize().y == 19); testAssert(sgf->getKomiOrFail() == 6.5f); testAssert(sgf->hasRules() == true); - testAssert(sgf->getRulesOrFail().equalsIgnoringKomi(Rules::parseRules("chinese"))); + testAssert(sgf->getRulesOrFail().equalsIgnoringSgfDefinedProps(Rules::parseRules("chinese"))); testAssert(sgf->getHandicapValue() == 2); testAssert(sgf->getSgfWinner() == C_EMPTY); testAssert(sgf->getPlayerName(P_BLACK) == "testname1"); diff --git a/cpp/tests/testsymmetries.cpp b/cpp/tests/testsymmetries.cpp index 5703ccf35..f686f9269 100644 --- a/cpp/tests/testsymmetries.cpp +++ b/cpp/tests/testsymmetries.cpp @@ -413,7 +413,7 @@ x.xxo.... Loc symLocComb = SymmetryHelpers::getSymLoc(loc,board,symmetryComposed); Loc symLocCombManual = SymmetryHelpers::getSymLoc(SymmetryHelpers::getSymLoc(loc,board,symmetry1),SymmetryHelpers::getSymBoard(board,symmetry1),symmetry2); out << "Symmetry " << symmetry1 << " + " << symmetry2 << " = " << symmetryComposed << endl; - testAssert(symBoardCombManual.isEqualForTesting(symBoardComb,true,true)); + testAssert(symBoardCombManual.isEqualForTesting(symBoardComb)); testAssert(symLocComb == symLocCombManual); } } @@ -588,7 +588,7 @@ x.xxo.... out << "SYMMETRY " << symmetry << endl; out << boardA << endl; out << boardB << endl; - testAssert(boardA.isEqualForTesting(boardB,true,true)); + testAssert(boardA.isEqualForTesting(boardB)); } string expected = R"%%( SYMMETRY 0 diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index 7c9e3292f..a2131fbe2 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -9,6 +9,15 @@ using namespace std; using namespace TestCommon; +static TrainingDataWriter createTestTrainingDataWriter( + const int inputVersion, + const int nnXLen, + const int nnYLen, + const string& seed, + const int onlyWriteEvery) { + return TrainingDataWriter(string(), &cout, inputVersion, 256, 1.0f, nnXLen, nnYLen, seed, onlyWriteEvery); +} + static NNEvaluator* startNNEval( const string& modelFile, const string& seed, Logger& logger, int defaultSymmetry, bool inputsUseNHWC, bool useNHWC, bool useFP16 @@ -16,8 +25,8 @@ static NNEvaluator* startNNEval( const string& modelName = modelFile; vector gpuIdxByServerThread = {0}; int maxBatchSize = 16; - int nnXLen = NNPos::MAX_BOARD_LEN; - int nnYLen = NNPos::MAX_BOARD_LEN; + int nnXLen = NNPos::MAX_BOARD_LEN_X; + int nnYLen = NNPos::MAX_BOARD_LEN_Y; bool requireExactNNLen = false; int nnCacheSizePowerOfTwo = 16; int nnMutexPoolSizePowerOfTwo = 12; @@ -51,7 +60,8 @@ static NNEvaluator* startNNEval( gpuIdxByServerThread, seed, nnRandomize, - defaultSymmetry + defaultSymmetry, + false // TODO: Fix for Dots Game ); nnEval->spawnServerThreads(); @@ -66,13 +76,9 @@ void Tests::runTrainingWriteTests() { cout << "Running training write tests" << endl; NeuralNet::globalInitialize(); - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 5; - - const bool logToStdout = true; - const bool logToStderr = false; - const bool logTime = false; + constexpr bool logToStdout = true; + constexpr bool logToStderr = false; + constexpr bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); auto run = [&]( @@ -82,7 +88,7 @@ void Tests::runTrainingWriteTests() { int boardXLen, int boardYLen, bool cheapLongSgf ) { - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, nnXLen, nnYLen, seedBase + "dwriter", 5); NNEvaluator* nnEval = startNNEval("/dev/null",seedBase+"nneval",logger,0,inputsNHWC,useNHWC,false); @@ -96,7 +102,7 @@ void Tests::runTrainingWriteTests() { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(boardXLen,boardYLen); + Board initialBoard(boardXLen,boardYLen,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -225,7 +231,7 @@ void Tests::runSelfplayInitTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -406,7 +412,7 @@ void Tests::runMoreSelfplayTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; if(testHint) { @@ -586,7 +592,7 @@ void Tests::runMoreSelfplayTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -932,9 +938,9 @@ xxxxxxxx. FinishedGameData* data = gameRunner->runGame(seed, botSpec, botSpec, forkData, NULL, logger, shouldStop, shouldPause, nullptr, nullptr, nullptr); cout << data->startHist.rules << endl; cout << "Start moves size " << data->startHist.moveHistory.size() - << " Start pla " << PlayerIO::playerToString(data->startPla) + << " Start pla " << PlayerIO::playerToString(data->startPla,data->startHist.rules.isDots) << " XY " << data->startBoard.x_size << " " << data->startBoard.y_size - << " Extra black " << data->numExtraBlack + << " Extra " << PlayerIO::playerToString(P_BLACK, data->startHist.rules.isDots) << " " << data->numExtraBlack << " Draw equiv " << data->drawEquivalentWinsForWhite << " Mode " << data->mode << " BeganInEncorePhase " << data->beganInEncorePhase @@ -1008,10 +1014,9 @@ xxxxxxxx. vector moves = sgf->moves; Rules initialRules = Rules::parseRules("chinese"); - Board board; Player nextPla; - BoardHistory hist; - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; for(size_t i = 0; i(); startPosSample.initialTurnNumber = 0; @@ -1173,10 +1176,10 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; - startPosSample.board = Board(9,9); + startPosSample.board = Board(9,9,rules); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 40; @@ -1203,9 +1206,6 @@ xxxxxxxx. cout << "====================================================================================================" << endl; cout << "Testing no result" << endl; - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 1; int inputsVersion = 7; SearchParams params = SearchParams::forTestsV2(); @@ -1272,9 +1272,9 @@ xxxxxxxx. botSpec.nnEval = nnEval; botSpec.baseParams = params; - string seed = "seed-testing-temperature"; - { + string seed = "seed-testing-temperature"; + int debugOnlyWriteEvery = 1; cout << "Turn number initial 0 selfplay with high temperatures" << endl; nnEval->clearCache(); nnEval->clearStats(); @@ -1284,7 +1284,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board::parseBoard(9,9,R"%%( @@ -2248,7 +2248,7 @@ void Tests::runSelfplayStatTestsWithNN(const string& modelFile) { playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2300,7 +2300,7 @@ void Tests::runSelfplayStatTestsWithNN(const string& modelFile) { playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2454,7 +2454,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2509,7 +2509,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2563,7 +2563,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2618,7 +2618,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2676,7 +2676,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2723,7 +2723,6 @@ oox.x.... NeuralNet::globalCleanup(); } - void Tests::runSekiTrainWriteTests(const string& modelFile) { bool inputsNHWC = true; bool useNHWC = false; @@ -2732,9 +2731,6 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { cout << "Running test for how a seki gets recorded" << endl; NeuralNet::globalInitialize(); - int nnXLen = 13; - int nnYLen = 13; - const bool logToStdout = true; const bool logToStderr = false; const bool logTime = false; @@ -2744,10 +2740,8 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { auto run = [&](const string& sgfStr, const string& seedBase, const Rules& rules) { int inputsVersion = 6; - int maxRows = 256; - double firstFileMinRandProp = 1.0; int debugOnlyWriteEvery = 1000; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 13, 13, seedBase + "dwriter", debugOnlyWriteEvery); nnEval->clearCache(); nnEval->clearStats(); @@ -2763,16 +2757,14 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { botSpec.baseParams = params; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board initialBoard; Player initialPla; - BoardHistory initialHist; ExtraBlackAndKomi extraBlackAndKomi; extraBlackAndKomi.extraBlack = 0; extraBlackAndKomi.komiMean = rules.komi; extraBlackAndKomi.komiStdev = 0; int turnIdx = (int)sgf->moves.size(); - sgf->setupBoardAndHistAssumeLegal(rules,initialBoard,initialPla,initialHist,turnIdx); + auto [initialHist, initialBoard] = sgf->setupBoardAndHistAssumeLegal(rules, initialPla, turnIdx); bool doEndGameIfAllPassAlive = true; bool clearBotAfterSearch = true; @@ -2854,7 +2846,7 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { vector buf; vector isAlive = PlayUtils::computeAnticipatedStatusesWithOwnership(bot,board,hist,pla,numVisits,buf); testAssert(bot->alwaysIncludeOwnerMap == false); - cout << "Search assumes " << PlayerIO::playerToString(pla) << " first" << endl; + cout << "Search assumes " << PlayerIO::playerToString(pla,hist.rules.isDots) << " first" << endl; cout << "Rules " << hist.rules << endl; cout << board << endl; for(int y = 0; ysetDoRandomize(false); @@ -247,7 +247,7 @@ NNEvaluator* TinyModelTest::runTinyModelTest(const string& baseDir, Logger& logg const string expectedSha256 = ""; NNEvaluator* nnEval = Setup::initializeNNEvaluator( "tinyModel",tmpModelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,maxBatchSize,requireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,maxBatchSize,requireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); nnEval->setDoRandomize(false); @@ -407,7 +407,7 @@ NNEvaluator* TinyModelTest::runTinyModelTest(const string& baseDir, Logger& logg const string expectedSha256 = ""; NNEvaluator* nnEval = Setup::initializeNNEvaluator( "tinyModel",tmpModelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,maxBatchSize,requireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,maxBatchSize,requireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); nnEval->setDoRandomize(false); diff --git a/python/export_model_pytorch.py b/python/export_model_pytorch.py index c006085f2..7b642f4ef 100644 --- a/python/export_model_pytorch.py +++ b/python/export_model_pytorch.py @@ -83,7 +83,11 @@ def writestr(s): version = 15 writeln(model_name) - writeln(version) + writestr(str(version)) + writestr("," + str(model.pos_len_x)) + writestr("," + str(model.pos_len_y)) + writestr("," + ";".join(c.name for c in model.games)) + writeln("") writeln(modelconfigs.get_num_bin_input_features(model_config)) writeln(modelconfigs.get_num_global_input_features(model_config)) diff --git a/python/forward_model.py b/python/forward_model.py index f919b0a90..f3b6d8b57 100644 --- a/python/forward_model.py +++ b/python/forward_model.py @@ -33,6 +33,8 @@ parser.add_argument('-npz', help='NPZ file to evaluate', required=True) parser.add_argument('-checkpoint', help='Checkpoint to test', required=False) parser.add_argument('-pos-len', help='Spatial length of expected training data', type=int, required=True) + parser.add_argument('-pos-len-x', help='Spatial width of expected training data (-pos-len if undefined)', type=int, required=False) + parser.add_argument('-pos-len-y', help='Spatial height of expected training data (-pos-len if undefined)', type=int, required=False) parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False) parser.add_argument('-gpu-idx', help='GPU idx', type=int, required=False) @@ -42,6 +44,8 @@ def main(args): npz_file = args["npz"] checkpoint_file = args["checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len use_swa = args["use_swa"] gpu_idx = args["gpu_idx"] @@ -74,7 +78,7 @@ def main(args): # LOAD MODEL --------------------------------------------------------------------- - model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=False) + model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len_x, pos_len_y=pos_len_y, verbose=False) model_config = model.config batch = np.load(npz_file) diff --git a/python/humanslnet_server.py b/python/humanslnet_server.py index ca7b3e33f..93638b1c7 100644 --- a/python/humanslnet_server.py +++ b/python/humanslnet_server.py @@ -21,7 +21,7 @@ def main(): parser.add_argument('-device', help='Device to use, such as cpu or cuda:0', required=True) args = parser.parse_args() - model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device, pos_len=19, verbose=False) + model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device,verbose=False) if swa_model is not None: model = swa_model game_state = None diff --git a/python/katago/game/features.py b/python/katago/game/features.py index 246eefb30..d9451d4d5 100644 --- a/python/katago/game/features.py +++ b/python/katago/game/features.py @@ -6,29 +6,30 @@ from ..train import modelconfigs class Features: - def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): + def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int): self.config = config - self.pos_len = pos_len + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y self.version = modelconfigs.get_version(config) - self.pass_pos = self.pos_len * self.pos_len - self.bin_input_shape = [modelconfigs.get_num_bin_input_features(config), pos_len, pos_len] + self.pass_pos = self.pos_len_x * self.pos_len_y + self.bin_input_shape = [modelconfigs.get_num_bin_input_features(config), pos_len_x, pos_len_y] self.global_input_shape = [modelconfigs.get_num_global_input_features(config)] def xy_to_tensor_pos(self,x,y): - return y * self.pos_len + x + return y * self.pos_len_x + x def loc_to_tensor_pos(self,loc,board): if loc == Board.PASS_LOC: return self.pass_pos - return board.loc_y(loc) * self.pos_len + board.loc_x(loc) + return board.loc_y(loc) * self.pos_len_x + board.loc_x(loc) def tensor_pos_to_loc(self,pos,board): if pos == self.pass_pos: return None - pos_len = self.pos_len + pos_len = self.pos_len_x x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) x = pos % pos_len y = pos // pos_len if x < 0 or x >= x_size or y < 0 or y >= y_size: @@ -38,7 +39,7 @@ def tensor_pos_to_loc(self,pos,board): def sym_tensor_pos(self,pos,symmetry): if pos == self.pass_pos: return pos - pos_len = self.pos_len + pos_len = self.pos_len_x x = pos % pos_len y = pos // pos_len if symmetry >= 4: @@ -61,8 +62,8 @@ def iterLadders(self, board, f): x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) for y in range(y_size): for x in range(x_size): @@ -98,8 +99,8 @@ def fill_row_features(self, board, pla, opp, boards, moves, move_idx, rules, bin x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) assert(len(boards) > 0) assert(board.zobrist == boards[move_idx].zobrist) diff --git a/python/katago/game/gamestate.py b/python/katago/game/gamestate.py index 896fbf079..b81300654 100644 --- a/python/katago/game/gamestate.py +++ b/python/katago/game/gamestate.py @@ -89,15 +89,16 @@ def redo(self): def get_input_features(self, features: Features): bin_input_data = np.zeros(shape=[1]+features.bin_input_shape, dtype=np.float32) global_input_data = np.zeros(shape=[1]+features.global_input_shape, dtype=np.float32) - pos_len = features.pos_len + pos_len_x = features.pos_len_x + pos_len_y = features.pos_len_y pla = self.board.pla opp = Board.get_opp(pla) move_idx = len(self.moves) # fill_row_features assumes N(HW)C order but we actually use NCHW order in the model, so work with it and revert bin_input_data = np.transpose(bin_input_data,axes=(0,2,3,1)) - bin_input_data = bin_input_data.reshape([1,pos_len*pos_len,-1]) + bin_input_data = bin_input_data.reshape([1,pos_len_x*pos_len_y,-1]) features.fill_row_features(self.board,pla,opp,self.boards,self.moves,move_idx,self.rules,bin_input_data,global_input_data,idx=0) - bin_input_data = bin_input_data.reshape([1,pos_len,pos_len,-1]) + bin_input_data = bin_input_data.reshape([1,pos_len_x,pos_len_y,-1]) bin_input_data = np.transpose(bin_input_data,axes=(0,3,1,2)) return bin_input_data, global_input_data @@ -106,7 +107,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non from ..train.model_pytorch import Model, ExtraOutputs with torch.no_grad(): model.eval() - features = Features(model.config, model.pos_len) + features = Features(model.config, model.pos_len_x, model.pos_len_y) bin_input_data, global_input_data = self.get_input_features(features) # Currently we don't actually do any symmetries @@ -195,7 +196,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non elif board.would_be_legal(board.pla,move): moves_and_probs1.append((move,policy1[i])) - ownership_flat = ownership.reshape([features.pos_len * features.pos_len]) + ownership_flat = ownership.reshape([features.pos_len_x * features.pos_len_y]) ownership_by_loc = [] board = self.board for y in range(board.y_size): @@ -207,7 +208,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: ownership_by_loc.append((loc,-ownership_flat[pos])) - scoring_flat = scoring.reshape([features.pos_len * features.pos_len]) + scoring_flat = scoring.reshape([features.pos_len_x * features.pos_len_y]) scoring_by_loc = [] board = self.board for y in range(board.y_size): @@ -219,7 +220,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: scoring_by_loc.append((loc,-scoring_flat[pos])) - futurepos0_flat = futurepos[0,:,:].reshape([features.pos_len * features.pos_len]) + futurepos0_flat = futurepos[0,:,:].reshape([features.pos_len_x * features.pos_len_y]) futurepos0_by_loc = [] board = self.board for y in range(board.y_size): @@ -231,7 +232,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: futurepos0_by_loc.append((loc,-futurepos0_flat[pos])) - futurepos1_flat = futurepos[1,:,:].reshape([features.pos_len * features.pos_len]) + futurepos1_flat = futurepos[1,:,:].reshape([features.pos_len_x * features.pos_len_y]) futurepos1_by_loc = [] board = self.board for y in range(board.y_size): @@ -243,7 +244,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: futurepos1_by_loc.append((loc,-futurepos1_flat[pos])) - seki_flat = seki.reshape([features.pos_len * features.pos_len]) + seki_flat = seki.reshape([features.pos_len_x * features.pos_len_y]) seki_by_loc = [] board = self.board for y in range(board.y_size): @@ -255,7 +256,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: seki_by_loc.append((loc,-seki_flat[pos])) - seki_flat2 = seki2.reshape([features.pos_len * features.pos_len]) + seki_flat2 = seki2.reshape([features.pos_len_x * features.pos_len_y]) seki_by_loc2 = [] board = self.board for y in range(board.y_size): diff --git a/python/katago/train/data_processing_pytorch.py b/python/katago/train/data_processing_pytorch.py index 1ad18c9fd..78a50daeb 100644 --- a/python/katago/train/data_processing_pytorch.py +++ b/python/katago/train/data_processing_pytorch.py @@ -9,17 +9,8 @@ from ..train import modelconfigs -def read_npz_training_data( - npz_files, - batch_size: int, - world_size: int, - rank: int, - pos_len: int, - device, - randomize_symmetries: bool, - include_meta: bool, - model_config: modelconfigs.ModelConfig, -): +def read_npz_training_data(npz_files, batch_size: int, world_size: int, rank: int, pos_len_x: int, pos_len_y: int, device, + randomize_symmetries: bool, include_meta: bool, model_config: modelconfigs.ModelConfig): rand = np.random.default_rng(seed=list(os.urandom(12))) num_bin_features = modelconfigs.get_num_bin_input_features(model_config) num_global_features = modelconfigs.get_num_global_input_features(model_config) @@ -47,10 +38,10 @@ def load_npz_file(npz_file): binaryInputNCHW = np.unpackbits(binaryInputNCHWPacked,axis=2) assert len(binaryInputNCHW.shape) == 3 - assert binaryInputNCHW.shape[2] == ((pos_len * pos_len + 7) // 8) * 8 - binaryInputNCHW = binaryInputNCHW[:,:,:pos_len*pos_len] + assert binaryInputNCHW.shape[2] == ((pos_len_x * pos_len_y + 7) // 8) * 8 + binaryInputNCHW = binaryInputNCHW[:,:, :pos_len_x * pos_len_y] binaryInputNCHW = np.reshape(binaryInputNCHW, ( - binaryInputNCHW.shape[0], binaryInputNCHW.shape[1], pos_len, pos_len + binaryInputNCHW.shape[0], binaryInputNCHW.shape[1], pos_len_x, pos_len_y )).astype(np.float32) assert binaryInputNCHW.shape[1] == num_bin_features @@ -98,10 +89,10 @@ def load_npz_file(npz_file): if randomize_symmetries: symm = int(rand.integers(0,8)) batch_binaryInputNCHW = apply_symmetry(batch_binaryInputNCHW, symm) - batch_policyTargetsNCMove = apply_symmetry_policy(batch_policyTargetsNCMove, symm, pos_len) + batch_policyTargetsNCMove = apply_symmetry_policy(batch_policyTargetsNCMove, symm, pos_len_x, pos_len_y) batch_valueTargetsNCHW = apply_symmetry(batch_valueTargetsNCHW, symm) if include_qvalues: - batch_qValueTargetsNCMove = apply_symmetry_policy(batch_qValueTargetsNCMove, symm, pos_len) + batch_qValueTargetsNCMove = apply_symmetry_policy(batch_qValueTargetsNCMove, symm, pos_len_x, pos_len_y) batch_binaryInputNCHW = batch_binaryInputNCHW.contiguous() batch_policyTargetsNCMove = batch_policyTargetsNCMove.contiguous() @@ -125,14 +116,14 @@ def load_npz_file(npz_file): yield batch -def apply_symmetry_policy(tensor, symm, pos_len): +def apply_symmetry_policy(tensor, symm, pos_len_x, pos_len_y): """Same as apply_symmetry but also handles the pass index""" batch_size = tensor.shape[0] channels = tensor.shape[1] - tensor_without_pass = tensor[:,:,:-1].view((batch_size, channels, pos_len, pos_len)) + tensor_without_pass = tensor[:,:,:-1].view((batch_size, channels, pos_len_x, pos_len_y)) tensor_transformed = apply_symmetry(tensor_without_pass, symm) return torch.cat(( - tensor_transformed.reshape(batch_size, channels, pos_len*pos_len), + tensor_transformed.reshape(batch_size, channels, pos_len_x * pos_len_y), tensor[:,:,-1:] ), dim=2) diff --git a/python/katago/train/load_model.py b/python/katago/train/load_model.py index 6642c758f..66fd3cba9 100644 --- a/python/katago/train/load_model.py +++ b/python/katago/train/load_model.py @@ -9,6 +9,8 @@ from collections import defaultdict +from katago.train.model_pytorch import Game, parse_game + # defaultdict and float constructor are used in the ckpt for running metrics if packaging.version.parse(torch.__version__) > packaging.version.parse("2.4.0"): torch.serialization.add_safe_globals([defaultdict]) @@ -39,7 +41,7 @@ def load_swa_model_state_dict(state_dict): return swa_model_state_dict -def load_model(checkpoint_file, use_swa, device, pos_len=19, verbose=False): +def load_model(checkpoint_file, use_swa, device, pos_len_x=19, pos_len_y=19, verbose=False): from ..train.model_pytorch import Model from torch.optim.swa_utils import AveragedModel @@ -54,7 +56,10 @@ def load_model(checkpoint_file, use_swa, device, pos_len=19, verbose=False): model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config,pos_len) + effective_pos_len_x = state_dict.get("pos_len_x", pos_len_x) + effective_pos_len_y = state_dict.get("pos_len_y", pos_len_y) + games = [parse_game(s) for s in state_dict.get("games", [Game.GO.name])] + model = Model(model_config, effective_pos_len_x, effective_pos_len_y, games) model.initialize() # Strip off any "module." from when the model was saved with DDP or other things diff --git a/python/katago/train/metrics_pytorch.py b/python/katago/train/metrics_pytorch.py index eb938c554..9b52a4567 100644 --- a/python/katago/train/metrics_pytorch.py +++ b/python/katago/train/metrics_pytorch.py @@ -25,14 +25,15 @@ class Metrics: def __init__(self, batch_size: int, world_size: int, raw_model: Model): self.n = batch_size self.world_size = world_size - self.pos_len = raw_model.pos_len - self.pos_area = raw_model.pos_len * raw_model.pos_len - self.policy_len = raw_model.pos_len * raw_model.pos_len + 1 + self.pos_len_x = raw_model.pos_len_x + self.pos_len_y = raw_model.pos_len_y + self.pos_area = raw_model.pos_len_x * raw_model.pos_len_y + self.policy_len = raw_model.pos_len_x * raw_model.pos_len_y + 1 self.value_len = 3 self.num_td_values = 3 self.num_futurepos_values = 2 self.num_seki_logits = 4 - self.scorebelief_len = 2 * (self.pos_len*self.pos_len + EXTRA_SCORE_DISTR_RADIUS) + self.scorebelief_len = 2 * (self.pos_len_x * self.pos_len_y + EXTRA_SCORE_DISTR_RADIUS) self.scoremean_multiplier = raw_model.scoremean_multiplier @@ -120,9 +121,9 @@ def loss_ownership_samplewise(self, pred_pretanh, target, weight, mask, mask_sum # This uses a formulation where each batch element cares about its average loss. # In particular this means that ownership loss predictions on small boards "count more" per spot. # Not unlike the way that policy and value loss are also equal-weighted by batch element. - assert pred_pretanh.shape == (self.n, 1, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_pretanh.shape == (self.n, 1, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) pred_logits = pred_pretanh.view(self.n,self.pos_area) * 2.0 target_probs = (1.0 + target.view(self.n,self.pos_area)) / 2.0 @@ -134,9 +135,9 @@ def loss_ownership_samplewise(self, pred_pretanh, target, weight, mask, mask_sum def loss_scoring_samplewise(self, pred_scoring, target, weight, mask, mask_sum_hw, global_weight): - assert pred_scoring.shape == (self.n, 1, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_scoring.shape == (self.n, 1, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) loss = torch.sum(torch.square(pred_scoring.squeeze(1) - target) * mask, dim=(1,2)) / mask_sum_hw @@ -154,9 +155,9 @@ def loss_futurepos_samplewise(self, pred_pretanh, target, weight, mask, mask_sum # causing some scaling with board size. So, I dunno, let's compromise and scale by sqrt(boardarea). # Also, the further out targets should be weighted a little less due to them being higher entropy # due to simply being farther in the future, so multiply by [1,0.25]. - assert pred_pretanh.shape == (self.n, self.num_futurepos_values, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.num_futurepos_values, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_pretanh.shape == (self.n, self.num_futurepos_values, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.num_futurepos_values, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) loss = torch.square(torch.tanh(pred_pretanh) - target) * mask.unsqueeze(1) loss = loss * constant_like([1.0,0.25], loss).view(1,2,1,1) @@ -166,10 +167,10 @@ def loss_futurepos_samplewise(self, pred_pretanh, target, weight, mask, mask_sum def loss_seki_samplewise(self, pred_logits, target, target_ownership, weight, mask, mask_sum_hw, global_weight, is_training, skip_moving_update): assert self.num_seki_logits == 4 - assert pred_logits.shape == (self.n, self.num_seki_logits, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert target_ownership.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_logits.shape == (self.n, self.num_seki_logits, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert target_ownership.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) owned_target = torch.square(target_ownership) diff --git a/python/katago/train/model_pytorch.py b/python/katago/train/model_pytorch.py index f4e0a93d1..d233752ec 100644 --- a/python/katago/train/model_pytorch.py +++ b/python/katago/train/model_pytorch.py @@ -6,6 +6,8 @@ import torch.nn.init import packaging import packaging.version +from argparse import ArgumentTypeError +from enum import Enum from typing import List, Dict, Optional, Set from ..train import modelconfigs @@ -1384,7 +1386,7 @@ def forward(self, x, mask, mask_sum_hw, mask_sum:float, extra_outputs: Optional[ class ValueHead(torch.nn.Module): - def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation, pos_len): + def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation, pos_len_x, pos_len_y): super(ValueHead, self).__init__() self.activation = activation self.conv1 = torch.nn.Conv2d(c_in, c_v1, kernel_size=1, padding="same", bias=False) @@ -1407,8 +1409,9 @@ def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation self.conv_futurepos = torch.nn.Conv2d(c_in, 2, kernel_size=1, padding="same", bias=False) self.conv_seki = torch.nn.Conv2d(c_in, 4, kernel_size=1, padding="same", bias=False) - self.pos_len = pos_len - self.scorebelief_mid = self.pos_len*self.pos_len + EXTRA_SCORE_DISTR_RADIUS + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y + self.scorebelief_mid = self.pos_len_x * self.pos_len_y + EXTRA_SCORE_DISTR_RADIUS self.scorebelief_len = self.scorebelief_mid * 2 self.num_scorebeliefs = num_scorebeliefs self.c_sv2 = c_sv2 @@ -1602,8 +1605,21 @@ def forward(self, input_meta, extra_outputs: Optional[ExtraOutputs]): return self.out_scale * self.linear_output_to_trunk(x) +class Game(Enum): + GO = 0 + DOTS = 1 + +def parse_game(value: str) -> Game: + value = value.upper() + if value == "GO": + return Game.GO + elif value == "DOTS": + return Game.DOTS + else: + raise ArgumentTypeError(f"Game must be 'go' or 'dots', got '{value}'") + class Model(torch.nn.Module): - def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): + def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int, games=None): super(Model, self).__init__() self.config = config @@ -1620,7 +1636,9 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.c_sv2 = config["sbv2_num_channels"] self.num_scorebeliefs = config["num_scorebeliefs"] self.num_total_blocks = len(self.block_kind) - self.pos_len = pos_len + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y + self.games = games or [Game.GO] if config["version"] <= 12: self.td_score_multiplier = 20.0 @@ -1661,7 +1679,7 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): else: self.metadata_encoder = None - self.bin_input_shape = [22, pos_len, pos_len] + self.bin_input_shape = [22, pos_len_x, pos_len_y] self.global_input_shape = [19] self.blocks = torch.nn.ModuleList() @@ -1770,16 +1788,8 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.config, self.activation, ) - self.value_head = ValueHead( - self.c_trunk, - self.c_v1, - self.c_v2, - self.c_sv2, - self.num_scorebeliefs, - self.config, - self.activation, - self.pos_len, - ) + self.value_head = ValueHead(self.c_trunk, self.c_v1, self.c_v2, self.c_sv2, self.num_scorebeliefs, self.config, + self.activation, self.pos_len_x, self.pos_len_y) if self.has_intermediate_head: self.norm_intermediate_trunkfinal = NormMask(self.c_trunk, self.config, fixup_use_gamma=False, is_last_batchnorm=True) self.act_intermediate_trunkfinal = act(self.activation) @@ -1790,16 +1800,9 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.config, self.activation, ) - self.intermediate_value_head = ValueHead( - self.c_trunk, - self.c_v1, - self.c_v2, - self.c_sv2, - self.num_scorebeliefs, - self.config, - self.activation, - self.pos_len, - ) + self.intermediate_value_head = ValueHead(self.c_trunk, self.c_v1, self.c_v2, self.c_sv2, + self.num_scorebeliefs, self.config, self.activation, + self.pos_len_x, self.pos_len_y) @property def device(self): diff --git a/python/katago/train/modelconfigs.py b/python/katago/train/modelconfigs.py index 976be7178..e24efdb8c 100644 --- a/python/katago/train/modelconfigs.py +++ b/python/katago/train/modelconfigs.py @@ -46,14 +46,14 @@ def get_version(config: ModelConfig): def get_num_bin_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: # Dots game uses the same number of spatial features return 22 else: assert(False) def get_num_global_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: # Dots game uses the same number of global features return 19 else: assert(False) diff --git a/python/play.py b/python/play.py index 543542482..10e69ab56 100644 --- a/python/play.py +++ b/python/play.py @@ -58,13 +58,13 @@ np.set_printoptions(linewidth=150) torch.set_printoptions(precision=7,sci_mode=False,linewidth=100000,edgeitems=1000,threshold=1000000) -model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=True) +model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len, pos_len_y=pos_len, verbose=True) if swa_model is not None: model = swa_model model_config = model.config model.eval() -features = Features(model_config, pos_len) +features = Features(model_config, pos_len_x, pos_len_y) # Moves ---------------------------------------------------------------- @@ -413,9 +413,9 @@ def add_attention_visualizations(extra_output_name, extra_output): def get_board_matrix_str(matrix, scale, formatstr): ret = "" - matrix = matrix.reshape([features.pos_len,features.pos_len]) - for y in range(features.pos_len): - for x in range(features.pos_len): + matrix = matrix.reshape([features.pos_len_x, features.pos_len_y]) + for y in range(features.pos_len_y): + for x in range(features.pos_len_x): ret += formatstr % (scale * matrix[y,x]) ret += " " ret += "\n" @@ -456,7 +456,7 @@ def get_policy_matrix_str(matrix, gs, scale, formatstr): ret = '' if command[0] == "boardsize": - if int(command[1]) > features.pos_len: + if int(command[1]) > features.pos_len_x: print("Warning: Trying to set incompatible boardsize %s (!= %d)" % (command[1], N), file=sys.stderr) ret = None board_size = int(command[1]) diff --git a/python/selfplay/synchronous_loop.sh b/python/selfplay/synchronous_loop.sh index 255862aba..a894deab9 100755 --- a/python/selfplay/synchronous_loop.sh +++ b/python/selfplay/synchronous_loop.sh @@ -54,21 +54,21 @@ mkdir -p "$BASEDIR"/gatekeepersgf # you have strong hardware or are later into a run you may want to reduce the overhead by scaling # these numbers up and doing more games and training per cycle, exporting models less frequently, etc. -NUM_GAMES_PER_CYCLE=500 # Every cycle, play this many games -NUM_THREADS_FOR_SHUFFLING=8 -NUM_TRAIN_SAMPLES_PER_EPOCH=100000 # Training will proceed in chunks of this many rows, subject to MAX_TRAIN_PER_DATA. +NUM_GAMES_PER_CYCLE=400 # Every cycle, play this many games +NUM_THREADS_FOR_SHUFFLING=16 +NUM_TRAIN_SAMPLES_PER_EPOCH=50000 # Training will proceed in chunks of this many rows, subject to MAX_TRAIN_PER_DATA. MAX_TRAIN_PER_DATA=8 # On average, train only this many times on each data row. Larger numbers may cause overfitting. NUM_TRAIN_SAMPLES_PER_SWA=80000 # Stochastic weight averaging frequency. BATCHSIZE=128 # For lower-end GPUs 64 or smaller may be needed to avoid running out of GPU memory. -SHUFFLE_MINROWS=100000 # Require this many rows at the very start before beginning training. +SHUFFLE_MINROWS=50000 # Require this many rows at the very start before beginning training. MAX_TRAIN_SAMPLES_PER_CYCLE=500000 # Each cycle will do at most this many training steps. -TAPER_WINDOW_SCALE=50000 # Parameter setting the scale at which the shuffler will make the training window grow sublinearly. +TAPER_WINDOW_SCALE=25000 # Parameter setting the scale at which the shuffler will make the training window grow sublinearly. SHUFFLE_KEEPROWS=600000 # Needs to be larger than MAX_TRAIN_SAMPLES_PER_CYCLE, so the shuffler samples enough rows each cycle for the training to use. # Paths to the selfplay and gatekeeper configs that contain board sizes, rules, search parameters, etc. # See cpp/configs/training/README.md for some notes on other selfplay configs. -SELFPLAY_CONFIG="$GITROOTDIR"/cpp/configs/training/selfplay1.cfg -GATING_CONFIG="$GITROOTDIR"/cpp/configs/training/gatekeeper1.cfg +SELFPLAY_CONFIG="$GITROOTDIR"/cpp/configs/training/selfplay1_dots.cfg +GATING_CONFIG="$GITROOTDIR"/cpp/configs/training/gatekeeper1_dots.cfg # Copy all the relevant scripts and configs and the katago executable to a dated directory. # For archival and logging purposes - you can look back and see exactly the python code on a particular date diff --git a/python/selfplay/train.sh b/python/selfplay/train.sh index 40bf72279..79e92fcce 100755 --- a/python/selfplay/train.sh +++ b/python/selfplay/train.sh @@ -77,7 +77,8 @@ time python3 ./train.py \ -latestdatadir "$BASEDIR"/shuffleddata/ \ -exportdir "$BASEDIR"/"$EXPORT_SUBDIR" \ -exportprefix "$TRAININGNAME" \ - -pos-len 19 \ + -pos-len 39 \ + -games DOTS \ -batch-size "$BATCHSIZE" \ -model-kind "$MODELKIND" \ $EXTRAFLAG \ diff --git a/python/shuffle.py b/python/shuffle.py index 22f93a4d0..eda21e6c3 100755 --- a/python/shuffle.py +++ b/python/shuffle.py @@ -362,7 +362,11 @@ def get_numpy_npz_headers(filename): wasbad = True print("WARNING: bad file, skipping it: %s (bad array %s)" % (filename,subfilename)) else: - (shape, is_fortran, dtype) = np.lib.format._read_array_header(npyfile,version) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(npyfile) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(npyfile) + (shape, is_fortran, dtype) = header npzheaders[subfilename] = (shape, is_fortran, dtype) if wasbad: return None diff --git a/python/test.py b/python/test.py index b7e54837a..9a3a93ecf 100644 --- a/python/test.py +++ b/python/test.py @@ -17,7 +17,7 @@ from torch.optim.swa_utils import AveragedModel from katago.train import modelconfigs -from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder +from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder, Game, parse_game from katago.train.metrics_pytorch import Metrics from katago.train import data_processing_pytorch from katago.train.load_model import load_model @@ -36,12 +36,15 @@ parser.add_argument('-config', help='Path to model.config.json', required=False) parser.add_argument('-checkpoint', help='Checkpoint to test', required=False) parser.add_argument('-pos-len', help='Spatial length of expected training data', type=int, required=True) + parser.add_argument('-pos-len-x', help='Spatial width of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + parser.add_argument('-pos-len-y', help='Spatial height of expected training data. If undefined, `-pos-len` is used', type=int, required=False) parser.add_argument('-batch-size', help='Batch size to use for testing', type=int, required=True) parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False) parser.add_argument('-max-batches', help='Maximum number of batches for testing', type=int, required=False) parser.add_argument('-gpu-idx', help='GPU idx', type=int, required=False) parser.add_argument('-print-norm', help='Names of outputs to print norms comma separated', type=str, required=False) parser.add_argument('-list-available-outputs', help='Print names of outputs available', action="store_true", required=False) + parser.add_argument('-games', help='Games to train: ' + ', '.join(e.name for e in Game) + ' (GO by default)', type=parse_game, nargs="+", required=False) args = vars(parser.parse_args()) @@ -51,6 +54,9 @@ def main(args): config_file = args["config"] checkpoint_file = args["checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len + games = args["games"] or [Game.GO] batch_size = args["batch_size"] use_swa = args["use_swa"] max_batches = args["max_batches"] @@ -109,11 +115,11 @@ def main(args): model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config,pos_len) + model = Model(model_config,pos_len_x,pos_len_y,games) model.initialize() model.to(device) else: - model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=True) + model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len_x, pos_len_y=pos_len_y, verbose=True) model_config = model.config metrics_obj = Metrics(batch_size,world_size,model) @@ -188,17 +194,11 @@ def log_metrics(prefix, metric_sums, metric_weights, metrics, sum_norms, num_sam num_samples_tested = 0 total_inference_time = 0.0 is_first_batch = True - for batch in data_processing_pytorch.read_npz_training_data( - val_files, - batch_size, - world_size, - rank, - pos_len, - device, - randomize_symmetries=True, - include_meta=model.get_has_metadata_encoder(), - model_config=model_config, - ): + for batch in data_processing_pytorch.read_npz_training_data(val_files, batch_size, world_size, rank, + pos_len_x, pos_len_y, + device, randomize_symmetries=True, + include_meta=model.get_has_metadata_encoder(), + model_config=model_config): if max_batches is not None and num_batches_tested >= max_batches: break diff --git a/python/train.py b/python/train.py index 3ee3f3c29..d612ffedf 100755 --- a/python/train.py +++ b/python/train.py @@ -31,7 +31,7 @@ from torch.cuda.amp import GradScaler, autocast from katago.train import modelconfigs -from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder +from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder, parse_game, Game from katago.train.metrics_pytorch import Metrics from katago.utils.push_back_generator import PushBackGenerator from katago.train import load_model @@ -64,8 +64,11 @@ optional_args.add_argument('-exportprefix', help='Prefix to append to names of models', required=False) optional_args.add_argument('-initial-checkpoint', help='If no training checkpoint exists, initialize from this checkpoint', required=False) - required_args.add_argument('-pos-len', help='Spatial edge length of expected training data, e.g. 19 for 19x19 Go', type=int, required=True) required_args.add_argument('-batch-size', help='Per-GPU batch size to use for training', type=int, required=True) + required_args.add_argument('-pos-len', help='Spatial edge length of expected training data, e.g. 19 for 19x19 Go', type=int, required=True) + optional_args.add_argument('-pos-len-x', help='Spatial width of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + optional_args.add_argument('-pos-len-y', help='Spatial height of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + optional_args.add_argument('-games', help='Games to train: ' + ", ".join(e.name for e in Game) + ' (GO by default)', type=parse_game, nargs="+", required=False) optional_args.add_argument('-samples-per-epoch', help='Number of data samples to consider as one epoch', type=int, required=False) optional_args.add_argument('-model-kind', help='String name for what model config to use', required=False) optional_args.add_argument('-lr-scale', help='LR multiplier on the hardcoded schedule', type=float, required=False) @@ -153,6 +156,9 @@ def main(rank: int, world_size: int, args, multi_gpu_device_ids, readpipes, writ initial_checkpoint = args["initial_checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len + games = args["games"] or [Game.GO] batch_size = args["batch_size"] samples_per_epoch = args["samples_per_epoch"] model_kind = args["model_kind"] @@ -310,6 +316,9 @@ def save(ddp_model, swa_model, optimizer, metrics_obj, running_metrics, train_st if rank == 0: state_dict = {} state_dict["model"] = ddp_model.state_dict() + state_dict["pos_len_x"] = getattr(ddp_model, "pos_len_x", 19) + state_dict["pos_len_y"] = getattr(ddp_model, "pos_len_y", 19) + state_dict["games"] = [g.name for g in getattr(ddp_model, "games", [Game.GO])] state_dict["optimizer"] = optimizer.state_dict() state_dict["metrics"] = metrics_obj.state_dict() state_dict["running_metrics"] = running_metrics @@ -443,7 +452,7 @@ def load(): assert model_kind is not None, "Model kind is none or unspecified but the model is being created fresh" model_config = modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len) + raw_model = Model(model_config,pos_len_x,pos_len_y,games) raw_model.initialize() raw_model.to(device) @@ -478,7 +487,7 @@ def load(): state_dict = torch.load(path_to_load_from, map_location=device) model_config = state_dict["config"] if "config" in state_dict else modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len) + raw_model = Model(model_config,pos_len_x,pos_len_y,games) raw_model.initialize() train_state = {} @@ -1045,17 +1054,11 @@ def detensorify_metrics(metrics): logging.info("This subepoch, using files: " + str(train_files_to_use)) logging.info("Currently up to data row " + str(train_state["total_num_data_rows"])) lookahead_counter = 0 - for batch in data_processing_pytorch.read_npz_training_data( - train_files_to_use, - batch_size, - world_size, - rank, - pos_len=pos_len, - device=device, - randomize_symmetries=True, - include_meta=raw_model.get_has_metadata_encoder(), - model_config=model_config - ): + for batch in data_processing_pytorch.read_npz_training_data(train_files_to_use, batch_size, world_size, + rank, pos_len_x, pos_len_y, + device=device, randomize_symmetries=True, + include_meta=raw_model.get_has_metadata_encoder(), + model_config=model_config): optimizer.zero_grad(set_to_none=True) extra_outputs = None # if raw_model.get_has_metadata_encoder(): @@ -1253,17 +1256,12 @@ def detensorify_metrics(metrics): val_metric_weights = defaultdict(float) val_samples = 0 t0 = time.perf_counter() - for batch in data_processing_pytorch.read_npz_training_data( - val_files, - batch_size, - world_size=1, # Only the main process validates - rank=0, # Only the main process validates - pos_len=pos_len, - device=device, - randomize_symmetries=True, - include_meta=raw_model.get_has_metadata_encoder(), - model_config=model_config - ): + for batch in data_processing_pytorch.read_npz_training_data(val_files, batch_size, world_size=1, rank=0, + pos_len_x=pos_len_x, pos_len_y=pos_len_y, + device=device, + randomize_symmetries=True, + include_meta=raw_model.get_has_metadata_encoder(), + model_config=model_config): model_outputs = ddp_model( batch["binaryInputNCHW"], batch["globalInputNC"],