From 498870f4041bf310e591d492d207c619aafb1a26 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 14 Nov 2025 00:20:56 +0100 Subject: [PATCH 01/42] [CI] Introduce GitHub CI that builds the project on Windows, Linux and macOS (OpenCL) Also, it runs `runtests` and publishes build artifacts fix #1124 --- .github/workflows/build.yml | 152 ++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 .github/workflows/build.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 000000000..115539b81 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,152 @@ +name: Build and Test + +on: + push: + branches: [ master ] + pull_request: + workflow_dispatch: + +jobs: + build-linux: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y cmake build-essential zlib1g-dev libzip-dev opencl-headers ocl-icd-opencl-dev + + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + key: ${{ runner.os }}-cmake-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake- + + - name: Configure CMake + working-directory: cpp + # Use '-DCMAKE_CXX_FLAGS_RELEASE="-s"' to strip debug symbols. Otherwise, the executable is too big + run: | + cmake . -DUSE_BACKEND=OPENCL -DNO_GIT_REVISION=1 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS_RELEASE="-s" + + - name: Build + working-directory: cpp + run: | + make -j$(nproc) + + - name: Run tests + run: | + cpp/katago runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-linux-opencl + path: cpp/katago + + build-macos: + runs-on: macos-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + brew install zlib libzip opencl-headers + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + cpp/build.ninja + cpp/.ninja_deps + cpp/.ninja_log + key: ${{ runner.os }}-cmake-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake- + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G Ninja -DUSE_BACKEND=OPENCL -DNO_GIT_REVISION=1 -DCMAKE_BUILD_TYPE=Release + - name: Build + working-directory: cpp + run: | + ninja + - name: Run tests + run: | + cpp/katago runtests + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-macos-opencl + path: cpp/katago + + build-windows: + runs-on: windows-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup MSVC + uses: microsoft/setup-msbuild@v2 + + - name: Cache vcpkg packages + uses: actions/cache@v4 + with: + path: | + ${{ env.VCPKG_INSTALLATION_ROOT }}/installed + ${{ env.VCPKG_INSTALLATION_ROOT }}/packages + key: ${{ runner.os }}-vcpkg-${{ hashFiles('**/vcpkg.json') }}-opencl + restore-keys: | + ${{ runner.os }}-vcpkg- + + - name: Install vcpkg dependencies + run: | + vcpkg install zlib:x64-windows libzip:x64-windows opencl:x64-windows + + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G "Visual Studio 17 2022" -A x64 ` + -DUSE_BACKEND=OPENCL ` + -DNO_GIT_REVISION=1 ` + -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" + + - name: Build + working-directory: cpp + run: | + cmake --build . --config Release -j 4 + + - name: Copy required DLLs + working-directory: cpp + run: | + $vcpkgRoot = $env:VCPKG_INSTALLATION_ROOT + Copy-Item "$vcpkgRoot/installed/x64-windows/bin/*.dll" -Destination "Release/" -ErrorAction SilentlyContinue + + - name: Run tests + run: | + cpp/Release/katago.exe runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-windows-opencl + path: cpp/Release/katago.exe From b551dd417f070a4f951abf33e4cbfa0f5b15cc45 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 14:58:40 +0100 Subject: [PATCH 02/42] First draft of Dots game implementation https://en.wikipedia.org/wiki/Dots_(game) --- cpp/CMakeLists.txt | 6 + cpp/book/book.cpp | 6 +- cpp/command/analysis.cpp | 4 +- cpp/command/benchmark.cpp | 2 +- cpp/command/evalsgf.cpp | 10 +- cpp/command/gtp.cpp | 37 +- cpp/command/runtests.cpp | 11 + cpp/command/selfplay.cpp | 4 +- cpp/command/startposes.cpp | 6 +- cpp/command/writetrainingdata.cpp | 6 +- cpp/core/config_parser.cpp | 8 + cpp/core/config_parser.h | 1 + cpp/dataio/sgf.cpp | 145 ++++-- cpp/dataio/sgf.h | 4 +- cpp/dataio/trainingwrite.cpp | 102 ++-- cpp/game/board.cpp | 404 ++++++++++----- cpp/game/board.h | 181 +++++-- cpp/game/boardhistory.cpp | 753 ++++++++++++++++----------- cpp/game/boardhistory.h | 10 +- cpp/game/common.h | 34 ++ cpp/game/dotsfield.cpp | 766 ++++++++++++++++++++++++++++ cpp/game/graphhash.cpp | 19 +- cpp/game/rules.cpp | 437 ++++++++++++---- cpp/game/rules.h | 101 +++- cpp/neuralnet/modelversion.cpp | 18 +- cpp/neuralnet/modelversion.h | 7 +- cpp/neuralnet/nneval.cpp | 22 +- cpp/neuralnet/nneval.h | 5 +- cpp/neuralnet/nninputs.cpp | 179 +++++-- cpp/neuralnet/nninputs.h | 57 +++ cpp/neuralnet/nninputsdots.cpp | 89 ++++ cpp/neuralnet/opencltuner.cpp | 2 +- cpp/program/play.cpp | 132 +++-- cpp/program/play.h | 4 + cpp/program/playutils.cpp | 6 +- cpp/program/setup.cpp | 6 +- cpp/search/asyncbot.cpp | 2 +- cpp/search/localpattern.cpp | 68 ++- cpp/search/localpattern.h | 9 + cpp/search/patternbonustable.cpp | 2 +- cpp/search/search.cpp | 30 +- cpp/search/search.h | 12 +- cpp/search/searchexplorehelpers.cpp | 1 - cpp/tests/testboardbasic.cpp | 67 ++- cpp/tests/testdotsbasic.cpp | 584 +++++++++++++++++++++ cpp/tests/testdotsextra.cpp | 424 +++++++++++++++ cpp/tests/testdotsstartposes.cpp | 263 ++++++++++ cpp/tests/testdotsutils.cpp | 39 ++ cpp/tests/testdotsutils.h | 60 +++ cpp/tests/testnninputs.cpp | 139 ++--- cpp/tests/testrules.cpp | 17 +- cpp/tests/tests.h | 11 + cpp/tests/testsearchcommon.cpp | 3 +- cpp/tests/testsgf.cpp | 104 +++- cpp/tests/testtrainingwrite.cpp | 3 +- 55 files changed, 4345 insertions(+), 1077 deletions(-) create mode 100644 cpp/game/common.h create mode 100644 cpp/game/dotsfield.cpp create mode 100644 cpp/neuralnet/nninputsdots.cpp create mode 100644 cpp/tests/testdotsbasic.cpp create mode 100644 cpp/tests/testdotsextra.cpp create mode 100644 cpp/tests/testdotsstartposes.cpp create mode 100644 cpp/tests/testdotsutils.cpp create mode 100644 cpp/tests/testdotsutils.h diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c523f28b0..bd2eb4520 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -231,6 +231,7 @@ add_executable(katago core/threadtest.cpp core/timer.cpp game/board.cpp + game/dotsfield.cpp game/rules.cpp game/boardhistory.cpp game/graphhash.cpp @@ -242,6 +243,7 @@ add_executable(katago dataio/homedata.cpp dataio/files.cpp neuralnet/nninputs.cpp + neuralnet/nninputsdots.cpp neuralnet/sgfmetadata.cpp neuralnet/modelversion.cpp neuralnet/nneval.cpp @@ -280,6 +282,10 @@ add_executable(katago ${GIT_HEADER_FILE_ALWAYS_UPDATED} tests/testboardarea.cpp tests/testboardbasic.cpp + tests/testdotsutils.cpp + tests/testdotsbasic.cpp + tests/testdotsstartposes.cpp + tests/testdotsextra.cpp tests/testbook.cpp tests/testcommon.cpp tests/testconfig.cpp diff --git a/cpp/book/book.cpp b/cpp/book/book.cpp index 32b30962f..72e3fd5c6 100644 --- a/cpp/book/book.cpp +++ b/cpp/book/book.cpp @@ -124,7 +124,7 @@ static Hash128 getExtraPosHash(const Board& board) { for(int y = 0; y()); int repBound = params["repBound"].get(); diff --git a/cpp/command/analysis.cpp b/cpp/command/analysis.cpp index e4cd37e92..696d28cad 100644 --- a/cpp/command/analysis.cpp +++ b/cpp/command/analysis.cpp @@ -882,14 +882,14 @@ int MainCmds::analysis(const vector& args) { if(input.find("rules") != input.end()) { if(input["rules"].is_string()) { string s = input["rules"].get(); - if(!Rules::tryParseRules(s,rules)) { + if(!Rules::tryParseRules(s, rules, input.value("dots", false))) { reportErrorForId(rbase.id, "rules", "Could not parse rules: " + s); continue; } } else if(input["rules"].is_object()) { string s = input["rules"].dump(); - if(!Rules::tryParseRules(s,rules)) { + if(!Rules::tryParseRules(s, rules, input.value("dots", false))) { reportErrorForId(rbase.id, "rules", "Could not parse rules: " + s); continue; } diff --git a/cpp/command/benchmark.cpp b/cpp/command/benchmark.cpp index 3100fb1b1..d39d8d940 100644 --- a/cpp/command/benchmark.cpp +++ b/cpp/command/benchmark.cpp @@ -643,7 +643,7 @@ int MainCmds::genconfig(const vector& args, const string& firstCommand) string prompt = "What rules should KataGo use by default for play and analysis?\n" "(chinese, japanese, korean, tromp-taylor, aga, chinese-ogs, new-zealand, bga, stone-scoring, aga-button):\n"; - promptAndParseInput(prompt, [&](const string& line) { configRules = Rules::parseRules(line); }); + promptAndParseInput(prompt, [&](const string& line) { configRules = Rules::parseRules(line, sgf->isDots); }); // TODO: probably incorrect for Dots game? } cout << endl; diff --git a/cpp/command/evalsgf.cpp b/cpp/command/evalsgf.cpp index e561d1545..29c534ba4 100644 --- a/cpp/command/evalsgf.cpp +++ b/cpp/command/evalsgf.cpp @@ -172,14 +172,14 @@ int MainCmds::evalsgf(const vector& args) { return 1; } - //Parse rules ------------------------------------------------------------------- - Rules defaultRules = Rules::getTrompTaylorish(); - Player perspective = Setup::parseReportAnalysisWinrates(cfg,P_BLACK); - //Parse sgf file and board ------------------------------------------------------------------ std::unique_ptr sgf = CompactSgf::loadFile(sgfFile); + //Parse rules ------------------------------------------------------------------- + Rules defaultRules = Rules::getDefaultOrTrompTaylorish(sgf->isDots); + Player perspective = Setup::parseReportAnalysisWinrates(cfg,P_BLACK); + Board board; Player nextPla; BoardHistory hist; @@ -219,7 +219,7 @@ int MainCmds::evalsgf(const vector& args) { [](const string& msg) { cout << msg << endl; } ); if(overrideRules != "") { - initialRules = Rules::parseRules(overrideRules); + initialRules = Rules::parseRules(overrideRules, initialRules.isDots); } // Set up once now for error catcihng diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 0420bf766..297fb2d89 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -198,7 +198,7 @@ static bool noWhiteStonesOnBoard(const Board& board) { for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { Loc loc = Location::getLoc(x,y,board.x_size); - if(board.colors[loc] == P_WHITE) + if(board.getColor(loc) == P_WHITE) return false; } } @@ -1328,7 +1328,7 @@ struct GTPEngine { for(int y = 0; y& args) { //Defaults to 7.5 komi, gtp will generally override this const bool loadKomiFromCfg = false; Rules initialRules = Setup::loadSingleRules(cfg,loadKomiFromCfg); - logger.write("Using " + initialRules.toStringNoKomiMaybeNice() + " rules initially, unless GTP/GUI overrides this"); + logger.write("Using " + initialRules.toStringNoSgfDefinedPropertiesMaybeNice() + " rules initially, unless GTP/GUI overrides this"); if(startupPrintMessageToStderr && !logger.isLoggingToStderr()) { - cerr << "Using " + initialRules.toStringNoKomiMaybeNice() + " rules initially, unless GTP/GUI overrides this" << endl; + cerr << "Using " + initialRules.toStringNoSgfDefinedPropertiesMaybeNice() + " rules initially, unless GTP/GUI overrides this" << endl; } bool isForcingKomi = false; float forcedKomi = 0; @@ -2362,7 +2362,8 @@ int MainCmds::gtp(const vector& args) { bool parseSuccess = false; Rules newRules; try { - newRules = Rules::parseRulesWithoutKomi(rest,engine->getCurrentRules().komi); + Rules currentRules = engine->getCurrentRules(); + newRules = Rules::parseRulesWithoutKomi(rest, currentRules.komi, currentRules.isDots); parseSuccess = true; } catch(const StringError& err) { @@ -2376,9 +2377,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } @@ -2406,9 +2407,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } } @@ -2423,19 +2424,19 @@ int MainCmds::gtp(const vector& args) { else { string s = Global::toLower(Global::trim(pieces[0])); if(s == "chinese") { - newRules = Rules::parseRulesWithoutKomi("chinese-kgs",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("chinese-kgs", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "aga") { - newRules = Rules::parseRulesWithoutKomi("aga",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("aga", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "new_zealand") { - newRules = Rules::parseRulesWithoutKomi("new_zealand",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("new_zealand", engine->getCurrentRules().komi); parseSuccess = true; } else if(s == "japanese") { - newRules = Rules::parseRulesWithoutKomi("japanese",engine->getCurrentRules().komi); + newRules = Rules::parseRulesWithoutKomi("japanese", engine->getCurrentRules().komi); parseSuccess = true; } else { @@ -2450,9 +2451,9 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = error; } - logger.write("Changed rules to " + newRules.toStringNoKomiMaybeNice()); + logger.write("Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice()); if(!logger.isLoggingToStderr()) - cerr << "Changed rules to " + newRules.toStringNoKomiMaybeNice() << endl; + cerr << "Changed rules to " + newRules.toStringNoSgfDefinedPropertiesMaybeNice() << endl; } } @@ -3243,7 +3244,7 @@ int MainCmds::gtp(const vector& args) { for(int y = 0; y& args) { for(int y = 0; y& args) { Board::initHash(); ScoreValue::initTables(); + Tests::runDotsFieldTests(); + Tests::runDotsGroundingTests(); + Tests::runDotsPosHashTests(); + Tests::runDotsStartPosTests(); + + Tests::runDotsStressTests(); + + Tests::runDotsSymmetryTests(); + Tests::runDotsTerritoryTests(); + Tests::runDotsCapturingTests(); + BSearch::runTests(); Rand::runTests(); DateTime::runTests(); diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index b03fb8c7d..cc8521148 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -93,10 +93,12 @@ int MainCmds::selfplay(const vector& args) { //Width and height of the board to use when writing data, typically 19 const int dataBoardLen = cfg.getInt("dataBoardLen",3,Board::MAX_LEN); + + const bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); const int inputsVersion = cfg.contains("inputsVersion") ? cfg.getInt("inputsVersion",0,10000) : - NNModelVersion::getInputsVersion(NNModelVersion::defaultModelVersion); + NNModelVersion::getInputsVersion(dotsGame ? NNModelVersion::defaultModelVersionForDots : NNModelVersion::defaultModelVersion); //Max number of games that we will allow to be queued up and not written out const int maxDataQueueSize = cfg.getInt("maxDataQueueSize",1,1000000); const int maxRowsPerTrainFile = cfg.getInt("maxRowsPerTrainFile",1,100000000); diff --git a/cpp/command/startposes.cpp b/cpp/command/startposes.cpp index d96cb61e8..6ee18b225 100644 --- a/cpp/command/startposes.cpp +++ b/cpp/command/startposes.cpp @@ -502,7 +502,7 @@ int MainCmds::samplesgfs(const vector& args) { //Only log on errors that aren't simply due to ko rules, but quit out regardless suc = hist.makeBoardMoveTolerant(board,sgfMoves[m].loc,sgfMoves[m].pla,preventEncore); if(!suc) - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size, board.isDots())); break; } hist.makeBoardMoveAssumeLegal(board,sgfMoves[m].loc,sgfMoves[m].pla,NULL,preventEncore); @@ -1469,7 +1469,7 @@ int MainCmds::dataminesgfs(const vector& args) { //Only log on errors that aren't simply due to ko rules, but quit out regardless suc = hist.makeBoardMoveTolerant(board,sgfMoves[m].loc,sgfMoves[m].pla,preventEncore); if(!suc) - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(sgfMoves[m].loc, board.x_size, board.y_size, board.isDots())); break; } hist.makeBoardMoveAssumeLegal(board,sgfMoves[m].loc,sgfMoves[m].pla,NULL,preventEncore); @@ -1662,7 +1662,7 @@ int MainCmds::dataminesgfs(const vector& args) { for(int i = 0; i& args) { if(dataBoardLen > Board::MAX_LEN) throw StringError("dataBoardLen > maximum board len, must recompile to increase"); - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); + static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); const int inputsVersion = 7; const int numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V7; const int numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V7; @@ -1878,7 +1878,7 @@ int MainCmds::writetrainingdata(const vector& args) { } bool suc = hist.isLegal(board,move.loc,move.pla); if(!suc) { - logger.write("Illegal move near start in " + fileName + " move " + Location::toString(move.loc, board.x_size, board.y_size) + sizeStr); + logger.write("Illegal move near start in " + fileName + " move " + Location::toString(move.loc, board) + sizeStr); reportSgfDone(false,"MovesIllegalMoveNearStart"); return; } @@ -1943,7 +1943,7 @@ int MainCmds::writetrainingdata(const vector& args) { } bool suc = hist.isLegal(board,move.loc,move.pla); if(!suc) { - logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(move.loc, board.x_size, board.y_size)); + logger.write("Illegal move in " + fileName + " turn " + Global::intToString(m) + " move " + Location::toString(move.loc, board)); reportSgfDone(false,"MovesIllegal"); return; } diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index e78388b63..e21d274b7 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -569,6 +569,13 @@ vector ConfigParser::getStrings(const string& key, const set& po return values; } +bool ConfigParser::getBoolOrDefault(const std::string& key, bool defaultValue) { + if (contains(key)) { + return getBool(key); + } + + return defaultValue; +} bool ConfigParser::getBool(const string& key) { string value = getString(key); @@ -577,6 +584,7 @@ bool ConfigParser::getBool(const string& key) { throw IOError("Could not parse '" + value + "' as bool for key '" + key + "' in config file " + fileName); return x; } + vector ConfigParser::getBools(const string& key) { vector values = getStrings(key); vector ret; diff --git a/cpp/core/config_parser.h b/cpp/core/config_parser.h index fa05ce5d2..8d0a9191a 100644 --- a/cpp/core/config_parser.h +++ b/cpp/core/config_parser.h @@ -58,6 +58,7 @@ class ConfigParser { std::string firstFoundOrEmpty(const std::vector& possibleKeys) const; std::string getString(const std::string& key); + bool getBoolOrDefault(const std::string& key, bool defaultValue); bool getBool(const std::string& key); enabled_t getEnabled(const std::string& key); int getInt(const std::string& key); diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 5000d3139..5039f184b 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -136,6 +136,25 @@ static void writeSgfLoc(ostream& out, Loc loc, int xSize, int ySize) { out << chars[y]; } +static Rules getRulesFromSgf(const bool dotsGame, const SgfNode& rootNode, const int xSize, const int ySize, const Rules* defaultRules) { + Rules rules; + if (defaultRules == nullptr || rootNode.hasProperty("RU")) { + rules = rootNode.getRulesFromRUTagOrFail(dotsGame); + } else { + rules = *defaultRules; + } + + if (defaultRules == nullptr || rootNode.hasProperty("KM")) { + rules.komi = rootNode.getKomiOrFail(); + } + + vector placementMoves; + rootNode.accumPlacements(placementMoves, xSize, ySize); + rules.startPos = Rules::tryRecognizeStartPos(xSize, ySize, placementMoves, true); + + return rules; +} + bool SgfNode::hasProperty(const char* key) const { if(props == nullptr) return false; @@ -273,13 +292,13 @@ Color SgfNode::getPLSpecifiedColor() const { return C_EMPTY; } -Rules SgfNode::getRulesFromRUTagOrFail() const { +Rules SgfNode::getRulesFromRUTagOrFail(const bool isDots) const { if(!hasProperty("RU")) throw StringError("SGF file does not specify rules"); string s = getSingleProperty("RU"); Rules parsed; - bool suc = Rules::tryParseRules(s,parsed); + bool suc = Rules::tryParseRules(s, parsed, isDots); if(!suc) throw StringError("Could not parse rules in sgf: " + s); return parsed; @@ -399,6 +418,13 @@ static void checkNonEmpty(const vector>& nodes) { throw StringError("Empty sgf"); } +bool Sgf::isDotsGame() const { + if(!nodes[0]->hasProperty("GM")) + return false; + const string& s = nodes[0]->getSingleProperty("GM"); + return s == "40"; +} + XYSize Sgf::getXYSize() const { checkNonEmpty(nodes); int xSize = 0; //Initialize to 0 to suppress spurious clang compiler warning. @@ -515,9 +541,9 @@ bool Sgf::hasRules() const { Rules Sgf::getRulesOrFail() const { checkNonEmpty(nodes); - Rules rules = nodes[0]->getRulesFromRUTagOrFail(); - rules.komi = getKomiOrFail(); - return rules; + + const XYSize size = getXYSize(); + return getRulesFromSgf(isDotsGame(), *nodes[0], size.x, size.y, nullptr); } Player Sgf::getSgfWinner() const { @@ -755,7 +781,7 @@ void Sgf::loadAllUniquePositions( Rand* rand, vector& samples ) const { - std::function f = [&samples](PositionSample& sample, const BoardHistory& hist, const string& comments) { + std::function f = [&samples](PositionSample& sample, const BoardHistory& hist, const string& comments) { (void)hist; (void)comments; samples.push_back(sample); @@ -773,17 +799,20 @@ void Sgf::iterAllUniquePositions( Rand* rand, std::function f ) const { + bool isDots = isDotsGame(); XYSize size = getXYSize(); int xSize = size.x; int ySize = size.y; - Board board(xSize,ySize); + Rules rules = Rules::getDefaultOrTrompTaylorish(isDots); Player nextPla = nodes.size() > 0 ? nodes[0]->getPLSpecifiedColor() : C_EMPTY; if(nextPla == C_EMPTY) nextPla = C_BLACK; - Rules rules = Rules::getTrompTaylorish(); - rules.koRule = Rules::KO_SITUATIONAL; - rules.multiStoneSuicideLegal = true; + if (!isDots) { + rules.koRule = Rules::KO_SITUATIONAL; + rules.multiStoneSuicideLegal = true; + } + Board board(xSize,ySize,rules); BoardHistory hist(board,nextPla,rules,0); PositionSample sampleBuf; @@ -800,17 +829,20 @@ void Sgf::iterAllPositions( Rand* rand, std::function f ) const { + bool isDots = isDotsGame(); XYSize size = getXYSize(); int xSize = size.x; int ySize = size.y; - Board board(xSize,ySize); Player nextPla = nodes.size() > 0 ? nodes[0]->getPLSpecifiedColor() : C_EMPTY; if(nextPla == C_EMPTY) nextPla = C_BLACK; - Rules rules = Rules::getTrompTaylorish(); - rules.koRule = Rules::KO_SITUATIONAL; - rules.multiStoneSuicideLegal = true; + Rules rules = Rules::getDefaultOrTrompTaylorish(isDots); + if (!isDots) { + rules.koRule = Rules::KO_SITUATIONAL; + rules.multiStoneSuicideLegal = true; + } + Board board(xSize,ySize,rules); BoardHistory hist(board,nextPla,rules,0); PositionSample sampleBuf; @@ -860,9 +892,9 @@ void Sgf::iterAllPositionsHelper( int netStonesAdded = 0; if(buf.size() > 0) { for(size_t j = 0; j 0x3FFFFFFF) @@ -1089,6 +1121,9 @@ std::set Sgf::readExcludes(const vector& files) { string Sgf::PositionSample::toJsonLine(const Sgf::PositionSample& sample) { json data; + if (sample.board.rules.isDots) { + data[DOTS_KEY] = "true"; + } data["xSize"] = sample.board.x_size; data["ySize"] = sample.board.y_size; data["board"] = Board::toStringSimple(sample.board,'/'); @@ -1116,9 +1151,10 @@ Sgf::PositionSample Sgf::PositionSample::ofJsonLine(const string& s) { json data = json::parse(s); PositionSample sample; try { + bool isDots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - sample.board = Board::parseBoard(xSize,ySize,data["board"].get(),'/'); + sample.board = Board::parseBoard(xSize, ySize, data["board"].get(), '/', Rules(isDots)); sample.nextPla = PlayerIO::parsePlayer(data["nextPla"].get()); vector moveLocs = data["moveLocs"].get>(); vector movePlas = data["movePlas"].get>(); @@ -1160,7 +1196,7 @@ Sgf::PositionSample Sgf::PositionSample::ofJsonLine(const string& s) { Sgf::PositionSample Sgf::PositionSample::getColorFlipped() const { Sgf::PositionSample other = *this; - Board newBoard(other.board.x_size,other.board.y_size); + Board newBoard(other.board.x_size, other.board.y_size, other.board.rules); for(int y = 0; y < other.board.y_size; y++) { for(int x = 0; x < other.board.x_size; x++) { Loc loc = Location::getLoc(x,y,other.board.x_size); @@ -1215,7 +1251,9 @@ int64_t Sgf::PositionSample::getCurrentTurnNumber() const { } bool Sgf::PositionSample::isEqualForTesting(const Sgf::PositionSample& other, bool checkNumCaptures, bool checkSimpleKo) const { - if(!board.isEqualForTesting(other.board,checkNumCaptures,checkSimpleKo)) + // Skips the rules check because default `koRule` value `KO_POSITIONAL` differs in `Rules` constructor and in `Sgf::iterAllPositions` (`KO_SITUATIONAL`) + // TODO: fix it + if(!board.isEqualForTesting(other.board,checkNumCaptures,checkSimpleKo,false)) return false; if(nextPla != other.nextPla) return false; @@ -1583,6 +1621,7 @@ std::vector> Sgf::loadSgfOrSgfsLogAndIgnoreErrors(const str CompactSgf::CompactSgf(const Sgf& sgf) :fileName(sgf.fileName), rootNode(), + isDots(), placements(), moves(), xSize(), @@ -1591,6 +1630,7 @@ CompactSgf::CompactSgf(const Sgf& sgf) sgfWinner(), hash(sgf.hash) { + isDots = sgf.isDotsGame(); XYSize size = sgf.getXYSize(); xSize = size.x; ySize = size.y; @@ -1616,6 +1656,7 @@ CompactSgf::CompactSgf(Sgf&& sgf) sgfWinner(), hash(sgf.hash) { + isDots = sgf.isDotsGame(); XYSize size = sgf.getXYSize(); xSize = size.x; ySize = size.y; @@ -1666,21 +1707,11 @@ bool CompactSgf::hasRules() const { } Rules CompactSgf::getRulesOrFail() const { - Rules rules = rootNode.getRulesFromRUTagOrFail(); - rules.komi = rootNode.getKomiOrFail(); - return rules; + return getRulesFromSgf(isDots, rootNode, xSize, ySize, nullptr); } Rules CompactSgf::getRulesOrFailAllowUnspecified(const Rules& defaultRules) const { - Rules rules; - if(!hasRules()) - rules = defaultRules; - else - rules = rootNode.getRulesFromRUTagOrFail(); - - if(rootNode.hasProperty("KM")) - rules.komi = rootNode.getKomiOrFail(); - return rules; + return getRulesFromSgf(isDots, rootNode, xSize, ySize, &defaultRules); } Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function f) const { @@ -1705,7 +1736,7 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function placementMoves; + rootNode.accumPlacements(placementMoves, xSize, ySize); + rules.startPos = Rules::tryRecognizeStartPos(xSize, ySize, placementMoves, true); + return rules; } @@ -1754,13 +1790,15 @@ void CompactSgf::setupInitialBoardAndHist(const Rules& initialRules, Board& boar if(moves.size() > 0) nextPla = moves[0].pla; - board = Board(xSize,ySize); - bool suc = board.setStonesFailIfNoLibs(placements); - if(!suc) - throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); + board = Board(xSize,ySize,initialRules); + if (initialRules.startPos == Rules::START_POS_EMPTY) { + bool suc = board.setStonesFailIfNoLibs(placements); + if(!suc) + throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); + } hist = BoardHistory(board,nextPla,initialRules,0); - if(hist.initialTurnNumber < board.numStonesOnBoard()) - hist.initialTurnNumber = board.numStonesOnBoard(); + if (int numStonesOnBoard = board.numStonesOnBoard(); hist.initialTurnNumber < numStonesOnBoard) + hist.initialTurnNumber = numStonesOnBoard; } void CompactSgf::playMovesAssumeLegal(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const { @@ -1790,7 +1828,7 @@ void CompactSgf::playMovesTolerant(Board& board, Player& nextPla, BoardHistory& for(int64_t i = 0; ihandicapForSgf << "]"; - } - else { - BoardHistory histCopy(endHist); - //Always use true for computing the handicap value that goes into an sgf - histCopy.setAssumeMultipleStartingBlackMovesAreHandicap(true); - out << "HA[" << histCopy.computeNumHandicapStones() << "]"; + if (!initialBoard.rules.isDots) { + if(gameData != NULL) { + out << "HA[" << gameData->handicapForSgf << "]"; + } + else { + BoardHistory histCopy(endHist); + //Always use true for computing the handicap value that goes into an sgf + histCopy.setAssumeMultipleStartingBlackMovesAreHandicap(true); + out << "HA[" << histCopy.computeNumHandicapStones() << "]"; + } } out << "KM[" << rules.komi << "]"; - out << "RU[" << (tryNicerRulesString ? rules.toStringNoKomiMaybeNice() : rules.toStringNoKomi()) << "]"; + out << "RU[" << (tryNicerRulesString ? rules.toStringNoSgfDefinedPropertiesMaybeNice() : rules.toStringNoSgfDefinedProps()) << "]"; printGameResult(out,endHist,overrideFinishedWhiteScore); bool hasAB = false; for(int y = 0; y& moves, int xSize, int ySize) const; Color getPLSpecifiedColor() const; - Rules getRulesFromRUTagOrFail() const; + Rules getRulesFromRUTagOrFail(bool isDots) const; Player getSgfWinner() const; float getKomiOrFail() const; float getKomiOrDefault(float defaultKomi) const; @@ -69,6 +69,7 @@ struct Sgf { static std::vector> loadSgfOrSgfsLogAndIgnoreErrors(const std::string& file, Logger& logger); + bool isDotsGame() const; XYSize getXYSize() const; float getKomiOrFail() const; float getKomiOrDefault(float defaultKomi) const; @@ -216,6 +217,7 @@ struct Sgf { struct CompactSgf { std::string fileName; + bool isDots; SgfNode rootNode; std::vector placements; std::vector moves; diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index a8b03bd25..34d9e778c 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -469,8 +469,8 @@ void TrainingWriteBuffers::addRow( SGFMetadata* sgfMeta, Rand& rand ) { - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion < 3 || inputsVersion > 7) + static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); + if(inputsVersion < 3 || inputsVersion > 8) throw StringError("Training write buffers: Does not support input version: " + Global::intToString(inputsVersion)); int posArea = dataXLen*dataYLen; @@ -490,36 +490,13 @@ void TrainingWriteBuffers::addRow( bool inputsUseNHWC = false; float* rowBin = binaryInputNCHWUnpacked; float* rowGlobal = globalInputNC.data + curRows * numGlobalChannels; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V3 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V3 == numGlobalChannels); - NNInputs::fillRowV3(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 4) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V4 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V4 == numGlobalChannels); - NNInputs::fillRowV4(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 5) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V5 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V5 == numGlobalChannels); - NNInputs::fillRowV5(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 6) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V6 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V6 == numGlobalChannels); - NNInputs::fillRowV6(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else if(inputsVersion == 7) { - assert(NNInputs::NUM_FEATURES_SPATIAL_V7 == numBinaryChannels); - assert(NNInputs::NUM_FEATURES_GLOBAL_V7 == numGlobalChannels); - NNInputs::fillRowV7(board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); - } - else - ASSERT_UNREACHABLE; - //Pack bools bitwise into uint8_t + assert(NNInputs::getNumberOfSpatialFeatures(inputsVersion) == numBinaryChannels); + assert(NNInputs::getNumberOfGlobalFeatures(inputsVersion) == numGlobalChannels); + + NNInputs::fillRowVN(inputsVersion, board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); + + // Pack bools bitwise into uint8_t uint8_t* rowBinPacked = binaryInputNCHWPacked.data + curRows * numBinaryChannels * packedBoardArea; for(int c = 0; c #include -#include -#include +#include +#include #include #include "../core/rand.h" @@ -45,16 +45,27 @@ int Location::getY(Loc loc, int x_size) { return (loc / (x_size+1)) - 1; } -void Location::getAdjacentOffsets(short adj_offsets[8], int x_size) -{ - adj_offsets[0] = -(x_size+1); - adj_offsets[1] = -1; - adj_offsets[2] = 1; - adj_offsets[3] = (x_size+1); - adj_offsets[4] = -(x_size+1)-1; - adj_offsets[5] = -(x_size+1)+1; - adj_offsets[6] = (x_size+1)-1; - adj_offsets[7] = (x_size+1)+1; +void Location::getAdjacentOffsets(short adj_offsets[8], int x_size, bool isDots) { + int stride = x_size + 1; + if (isDots) { + adj_offsets[LEFT_TOP_INDEX] = -stride - 1; + adj_offsets[TOP_INDEX] = -stride; + adj_offsets[RIGHT_TOP_INDEX] = -stride + 1; + adj_offsets[RIGHT_INDEX] = +1; + adj_offsets[RIGHT_BOTTOM_INDEX] = +stride + 1; + adj_offsets[BOTTOM_INDEX] = +stride; + adj_offsets[LEFT_BOTTOM_INDEX] = +stride - 1; + adj_offsets[LEFT_INDEX] = -1; + } else { + adj_offsets[0] = -stride; + adj_offsets[1] = -1; + adj_offsets[2] = 1; + adj_offsets[3] = stride; + adj_offsets[4] = -stride-1; + adj_offsets[5] = -stride+1; + adj_offsets[6] = stride-1; + adj_offsets[7] = stride+1; + } } bool Location::isAdjacent(Loc loc0, Loc loc1, int x_size) @@ -90,46 +101,61 @@ bool Location::isNearCentral(Loc loc, int x_size, int y_size) { return x >= (x_size-1)/2-1 && x <= x_size/2+1 && y >= (y_size-1)/2-1 && y <= y_size/2+1; } - -#define FOREACHADJ(BLOCK) {int ADJOFFSET = -(x_size+1); {BLOCK}; ADJOFFSET = -1; {BLOCK}; ADJOFFSET = 1; {BLOCK}; ADJOFFSET = x_size+1; {BLOCK}}; #define ADJ0 (-(x_size+1)) #define ADJ1 (-1) #define ADJ2 (1) #define ADJ3 (x_size+1) -//CONSTRUCTORS AND INITIALIZATION---------------------------------------------------------- +// CONSTRUCTORS AND INITIALIZATION---------------------------------------------------------- + +Board::Base::Base(Player newPla, + const std::vector& rollbackLocations, + const std::vector& rollbackStates, + bool isReal +) { + pla = newPla; + rollback_locations = rollbackLocations; + rollback_states = rollbackStates; + is_real = isReal; +} Board::Board() { - init(DEFAULT_LEN,DEFAULT_LEN); + init(DEFAULT_LEN, DEFAULT_LEN, Rules()); } Board::Board(int x, int y) { - init(x,y); + init(x, y, Rules()); } +Board::Board(int x, int y, const Rules& rules) { + init(x, y, rules); +} -Board::Board(const Board& other) -{ +Board::Board(const Board& other) { x_size = other.x_size; y_size = other.y_size; + rules = other.rules; memcpy(colors, other.colors, sizeof(Color)*MAX_ARR_SIZE); - memcpy(chain_data, other.chain_data, sizeof(ChainData)*MAX_ARR_SIZE); - memcpy(chain_head, other.chain_head, sizeof(Loc)*MAX_ARR_SIZE); - memcpy(next_in_chain, other.next_in_chain, sizeof(Loc)*MAX_ARR_SIZE); - ko_loc = other.ko_loc; + + if (!other.rules.isDots) { + chain_data = other.chain_data; + chain_head = other.chain_head; + next_in_chain = other.next_in_chain; + } + // empty_list = other.empty_list; pos_hash = other.pos_hash; numBlackCaptures = other.numBlackCaptures; numWhiteCaptures = other.numWhiteCaptures; - + numLegalMoves = other.numLegalMoves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); } -void Board::init(int xS, int yS) +void Board::init(const int xS, const int yS, const Rules& initRules) { assert(IS_ZOBRIST_INITALIZED); if(xS < 0 || yS < 0 || xS > MAX_LEN || yS > MAX_LEN) @@ -137,6 +163,7 @@ void Board::init(int xS, int yS) x_size = xS; y_size = yS; + rules = initRules; for(int i = 0; i < MAX_ARR_SIZE; i++) colors[i] = C_WALL; @@ -155,8 +182,20 @@ void Board::init(int xS, int yS) pos_hash = ZOBRIST_SIZE_X_HASH[x_size] ^ ZOBRIST_SIZE_Y_HASH[y_size]; numBlackCaptures = 0; numWhiteCaptures = 0; + numLegalMoves = xS * yS; - Location::getAdjacentOffsets(adj_offsets,x_size); + if (!rules.isDots) { + chain_data.resize(MAX_ARR_SIZE); + chain_head.resize(MAX_ARR_SIZE); + next_in_chain.resize(MAX_ARR_SIZE); + } + + Location::getAdjacentOffsets(adj_offsets, x_size, isDots()); + + const vector placement = Rules::generateStartPos(rules.startPos, x_size, y_size); + for (const Move& move : placement) { + playMoveAssumeLegal(move.loc, move.pla); + } } void Board::initHash() @@ -237,10 +276,13 @@ Hash128 Board::getSitHashWithSimpleKo(Player pla) const { void Board::clearSimpleKoLoc() { ko_loc = NULL_LOC; } -void Board::setSimpleKoLoc(Loc loc) { +void Board::setSimpleKoLoc(const Loc loc) { ko_loc = loc; } +Color Board::getColor(const Loc loc) const { + return static_cast(colors[loc] & ACTIVE_MASK); +} double Board::sqrtBoardArea() const { if (x_size == y_size) { @@ -269,6 +311,10 @@ bool Board::isSuicide(Loc loc, Player pla) const if(loc == PASS_LOC) return false; + if (rules.isDots) { + return isSuicideDots(loc, pla); + } + Player opp = getOpp(pla); FOREACHADJ( Loc adj = loc + ADJOFFSET; @@ -293,6 +339,10 @@ bool Board::isSuicide(Loc loc, Player pla) const //Check if moving here is would be an illegal self-capture bool Board::isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const { + if (rules.isDots) { + return !isMultiStoneSuicideLegal && isSuicideDots(loc, pla); + } + Player opp = getOpp(pla); FOREACHADJ( Loc adj = loc + ADJOFFSET; @@ -317,6 +367,7 @@ bool Board::isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) //Returns a fast lower bound on the number of liberties a new stone placed here would have void Board::getBoundNumLibertiesAfterPlay(Loc loc, Player pla, int& lowerBound, int& upperBound) const { + assert(!isDots()); Player opp = getOpp(pla); int numImmediateLibs = 0; //empty spaces adjacent @@ -354,6 +405,7 @@ void Board::getBoundNumLibertiesAfterPlay(Loc loc, Player pla, int& lowerBound, //Returns the number of liberties a new stone placed here would have, or max if it would be >= max. int Board::getNumLibertiesAfterPlay(Loc loc, Player pla, int max) const { + assert(!isDots()); Player opp = getOpp(pla); int numLibs = 0; @@ -445,32 +497,19 @@ bool Board::isKoBanned(Loc loc) const } bool Board::isOnBoard(Loc loc) const { - return loc >= 0 && loc < MAX_ARR_SIZE && colors[loc] != C_WALL; + return loc >= 0 && loc < MAX_ARR_SIZE && getColor(loc) != C_WALL; } //Check if moving here is illegal. -bool Board::isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const +bool Board::isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal, const bool ignoreKo) const { if(pla != P_BLACK && pla != P_WHITE) return false; return loc == PASS_LOC || ( loc >= 0 && loc < MAX_ARR_SIZE && - (colors[loc] == C_EMPTY) && - !isKoBanned(loc) && - !isIllegalSuicide(loc, pla, isMultiStoneSuicideLegal) - ); -} - -//Check if moving here is illegal, ignoring simple ko -bool Board::isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const -{ - if(pla != P_BLACK && pla != P_WHITE) - return false; - return loc == PASS_LOC || ( - loc >= 0 && - loc < MAX_ARR_SIZE && - (colors[loc] == C_EMPTY) && + getColor(loc) == C_EMPTY && + (rules.isDots || ignoreKo || !isKoBanned(loc)) && !isIllegalSuicide(loc, pla, isMultiStoneSuicideLegal) ); } @@ -478,6 +517,8 @@ bool Board::isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal //Check if this location contains a simple eye for the specified player. bool Board::isSimpleEye(Loc loc, Player pla) const { + assert(!rules.isDots); + if(colors[loc] != C_EMPTY) return false; @@ -509,9 +550,14 @@ bool Board::isSimpleEye(Loc loc, Player pla) const return true; } -bool Board::wouldBeCapture(Loc loc, Player pla) const { - if(colors[loc] != C_EMPTY) +bool Board::wouldBeCapture(const Loc loc, const Player pla) const { + if(getColor(loc) != C_EMPTY) return false; + + if (rules.isDots) { + return wouldBeCaptureDots(loc, pla); + } + Player opp = getOpp(pla); FOREACHADJ( Loc adj = loc + ADJOFFSET; @@ -525,8 +571,11 @@ bool Board::wouldBeCapture(Loc loc, Player pla) const { return false; } - bool Board::wouldBeKoCapture(Loc loc, Player pla) const { + if (isDots()) { + return false; // Ko is not relevant for Dots + } + if(colors[loc] != C_EMPTY) return false; //Check that surounding points are are all opponent owned and exactly one of them is capturable @@ -581,16 +630,16 @@ Loc Board::getKoCaptureLoc(Loc loc, Player pla) const { bool Board::isAdjacentToPla(Loc loc, Player pla) const { FOREACHADJ( Loc adj = loc + ADJOFFSET; - if(colors[adj] == pla) + if(getColor(adj) == pla) return true; ); return false; } bool Board::isAdjacentOrDiagonalToPla(Loc loc, Player pla) const { - for(int i = 0; i<8; i++) { + for(int i = 0; i < 8; i++) { Loc adj = loc + adj_offsets[i]; - if(colors[adj] == pla) + if(getColor(adj) == pla) return true; } return false; @@ -651,9 +700,11 @@ int Board::numStonesOnBoard() const { int num = 0; for(int y = 0; y < y_size; y++) { for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] == C_BLACK || colors[loc] == C_WHITE) + const Loc loc = Location::getLoc(x,y,x_size); + if(const Color color = rules.isDots ? getPlacedDotColor(getState(loc)) : colors[loc]; + color == C_BLACK || color == C_WHITE) { num += 1; + } } } return num; @@ -663,9 +714,10 @@ int Board::numPlaStonesOnBoard(Player pla) const { int num = 0; for(int y = 0; y < y_size; y++) { for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] == pla) + const Loc loc = Location::getLoc(x,y,x_size); + if(const Color color = rules.isDots ? getPlacedDotColor(getState(loc)) : colors[loc]; color == pla) { num += 1; + } } } return num; @@ -697,15 +749,16 @@ bool Board::setStone(Loc loc, Color color) } bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { - if(loc < 0 || loc >= MAX_ARR_SIZE || colors[loc] == C_WALL) + Color colorAtLoc = getColor(loc); + if(loc < 0 || loc >= MAX_ARR_SIZE || colorAtLoc == C_WALL) return false; if(color != C_BLACK && color != C_WHITE && color != C_EMPTY) return false; Loc oldKoLoc = ko_loc; - if(colors[loc] == color) + if(colorAtLoc == color) {} - else if(colors[loc] == C_EMPTY) { + else if(colorAtLoc == C_EMPTY) { if(isSuicide(loc,color) || wouldBeCapture(loc,color)) return false; playMoveAssumeLegal(loc,color); @@ -713,7 +766,7 @@ bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { else if(color == C_EMPTY) removeSingleStone(loc); else { - assert(colors[loc] == getOpp(color)); + assert(colorAtLoc == getOpp(color)); removeSingleStone(loc); if(isSuicide(loc,color) || wouldBeCapture(loc,color)) { playMoveAssumeLegal(loc,getOpp(color)); @@ -753,7 +806,7 @@ bool Board::setStonesFailIfNoLibs(std::vector placements) { //Attempts to play the specified move. Returns true if successful, returns false if the move was illegal. bool Board::playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal) { - if(isLegal(loc,pla,isMultiStoneSuicideLegal)) + if(isLegal(loc, pla, isMultiStoneSuicideLegal, false)) { playMoveAssumeLegal(loc,pla); return true; @@ -762,43 +815,50 @@ bool Board::playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal) } //Plays the specified move, assuming it is legal, and returns a MoveRecord for the move -Board::MoveRecord Board::playMoveRecorded(Loc loc, Player pla) -{ - MoveRecord record; - record.loc = loc; - record.pla = pla; - record.ko_loc = ko_loc; - record.capDirs = 0; +Board::MoveRecord Board::playMoveRecorded(const Loc loc, const Player pla) { + if (rules.isDots) { + return playMoveAssumeLegalDots(loc, pla); + } + + uint8_t capDirs = 0; if(loc != PASS_LOC) { Player opp = getOpp(pla); { int adj = loc + ADJ0; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 0); } + capDirs |= (((uint8_t)1) << 0); } { int adj = loc + ADJ1; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 1); } + capDirs |= (((uint8_t)1) << 1); } { int adj = loc + ADJ2; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 2); } + capDirs |= (((uint8_t)1) << 2); } { int adj = loc + ADJ3; if(colors[adj] == opp && getNumLiberties(adj) == 1) - record.capDirs |= (((uint8_t)1) << 3); } + capDirs |= (((uint8_t)1) << 3); } - if(record.capDirs == 0 && isSuicide(loc,pla)) - record.capDirs = 0x10; + if(capDirs == 0 && isSuicide(loc,pla)) + capDirs = 0x10; } + const Loc rollback_ko_loc = ko_loc; + playMoveAssumeLegal(loc, pla); - return record; + + return MoveRecord(loc, pla, rollback_ko_loc, capDirs); } //Undo the move given by record. Moves MUST be undone in the order they were made. //Undos will NOT typically restore the precise representation in the board to the way it was. The heads of chains //might change, the order of the circular lists might change, etc. -void Board::undo(Board::MoveRecord record) +void Board::undo(MoveRecord& record) { + if (rules.isDots) { + undoDots(record); + return; + } + ko_loc = record.ko_loc; Loc loc = record.loc; @@ -998,8 +1058,12 @@ Hash128 Board::getPosHashAfterMove(Loc loc, Player pla) const { } //Plays the specified move, assuming it is legal. -void Board::playMoveAssumeLegal(Loc loc, Player pla) -{ +void Board::playMoveAssumeLegal(Loc loc, Player pla) { + if (rules.isDots) { + playMoveAssumeLegalDots(loc, pla); + return; + } + //Pass? if(loc == PASS_LOC) { @@ -1218,6 +1282,9 @@ int Board::removeChain(Loc loc) //Remove a single stone, even a stone part of a larger group. void Board::removeSingleStone(Loc loc) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } Player pla = colors[loc]; //Save the entire chain's stone locations @@ -1432,7 +1499,7 @@ int Location::euclideanDistanceSquared(Loc loc0, Loc loc1, int x_size) { return dx*dx + dy*dy; } -//TACTICAL STUFF-------------------------------------------------------------------- +// TACTICAL STUFF-------------------------------------------------------------------- //Helper, find liberties of group at loc. Fills in buf, returns the number of liberties. //bufStart is where to start checking to avoid duplicates. bufIdx is where to start actually writing. @@ -1529,6 +1596,10 @@ bool Board::hasLibertyGainingCaptures(Loc loc) const { } bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, vector& workingMoves) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } + if(loc < 0 || loc >= MAX_ARR_SIZE) return false; if(colors[loc] != C_BLACK && colors[loc] != C_WHITE) @@ -1553,12 +1624,12 @@ bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, //Attacker: A suicide move cannot reduce the defender's liberties //Defender: A suicide move cannot gain liberties bool isMultiStoneSuicideLegal = false; - if(isLegal(move0,opp,isMultiStoneSuicideLegal)) { + if(isLegal(move0, opp, isMultiStoneSuicideLegal, false)) { MoveRecord record = playMoveRecorded(move0,opp); move0Works = searchIsLadderCaptured(loc,true,buf); undo(record); } - if(isLegal(move1,opp,isMultiStoneSuicideLegal)) { + if(isLegal(move1, opp, isMultiStoneSuicideLegal, false)) { MoveRecord record = playMoveRecorded(move1,opp); move1Works = searchIsLadderCaptured(loc,true,buf); undo(record); @@ -1576,6 +1647,10 @@ bool Board::searchIsLadderCapturedAttackerFirst2Libs(Loc loc, vector& buf, } bool Board::searchIsLadderCaptured(Loc loc, bool defenderFirst, vector& buf) { + if (isDots()) { + assert(false && "Not yet implemented for Dots game"); + } + if(loc < 0 || loc >= MAX_ARR_SIZE) return false; if(colors[loc] != C_BLACK && colors[loc] != C_WHITE) @@ -1780,7 +1855,7 @@ bool Board::searchIsLadderCaptured(Loc loc, bool defenderFirst, vector& buf //Illegal move - treat it the same as a failed move, but don't return up a level so that we //loop again and just try the next move. bool isMultiStoneSuicideLegal = false; - if(!isLegal(move,p,isMultiStoneSuicideLegal)) { + if(!isLegal(move, p, isMultiStoneSuicideLegal, false)) { returnValue = isDefender; returnedFromDeeper = false; // if(print) cout << "illegal " << endl; @@ -1807,6 +1882,11 @@ void Board::calculateArea( bool unsafeBigTerritories, bool isMultiStoneSuicideLegal ) const { + if (rules.isDots) { + calculateGroundingWhiteScore(result); + return; + } + std::fill(result,result+MAX_ARR_SIZE,C_EMPTY); calculateAreaForPla(P_BLACK,safeBigTerritories,unsafeBigTerritories,isMultiStoneSuicideLegal,result); calculateAreaForPla(P_WHITE,safeBigTerritories,unsafeBigTerritories,isMultiStoneSuicideLegal,result); @@ -1830,6 +1910,8 @@ void Board::calculateIndependentLifeArea( bool keepStones, bool isMultiStoneSuicideLegal ) const { + assert(!isDots()); + //First, just compute basic area. Color basicArea[MAX_ARR_SIZE]; std::fill(result,result+MAX_ARR_SIZE,C_EMPTY); @@ -1886,6 +1968,7 @@ void Board::calculateAreaForPla( bool isMultiStoneSuicideLegal, Color* result ) const { + assert(!isDots()); Color opp = getOpp(pla); //https://senseis.xmp.net/?BensonsAlgorithm @@ -2319,23 +2402,26 @@ void Board::checkConsistency() const { Hash128 tmp_pos_hash = ZOBRIST_SIZE_X_HASH[x_size] ^ ZOBRIST_SIZE_Y_HASH[y_size]; int emptyCount = 0; for(Loc loc = 0; loc < MAX_ARR_SIZE; loc++) { + const Color color = getColor(loc); int x = Location::getX(loc,x_size); int y = Location::getY(loc,x_size); if(x < 0 || x >= x_size || y < 0 || y >= y_size) { - if(colors[loc] != C_WALL) + if(color != C_WALL) throw StringError(errLabel + "Non-WALL value outside of board legal area"); } else { - if(colors[loc] == C_BLACK || colors[loc] == C_WHITE) { - if(!chainLocChecked[loc]) - checkChainConsistency(loc); - // if(empty_list.contains(loc)) - // throw StringError(errLabel + "Empty list contains filled location"); + if(color == C_BLACK || color == C_WHITE) { + if (!rules.isDots) { + if(!chainLocChecked[loc]) + checkChainConsistency(loc); + // if(empty_list.contains(loc)) + // throw StringError(errLabel + "Empty list contains filled location"); + } - tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][colors[loc]]; + tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][color]; tmp_pos_hash ^= ZOBRIST_BOARD_HASH[loc][C_EMPTY]; } - else if(colors[loc] == C_EMPTY) { + else if(color== C_EMPTY) { // if(!empty_list.contains(loc)) // throw StringError(errLabel + "Empty list doesn't contain empty location"); emptyCount += 1; @@ -2360,23 +2446,29 @@ void Board::checkConsistency() const { // throw StringError(errLabel + "Empty list index for loc in index i is not i"); // } - if(ko_loc != NULL_LOC) { - int x = Location::getX(ko_loc,x_size); - int y = Location::getY(ko_loc,x_size); - if(x < 0 || x >= x_size || y < 0 || y >= y_size) - throw StringError(errLabel + "Invalid simple ko loc"); - if(getNumImmediateLiberties(ko_loc) != 0) - throw StringError(errLabel + "Simple ko loc has immediate liberties"); + if (!rules.isDots) { + if(ko_loc != NULL_LOC) { + int x = Location::getX(ko_loc,x_size); + int y = Location::getY(ko_loc,x_size); + if(x < 0 || x >= x_size || y < 0 || y >= y_size) + throw StringError(errLabel + "Invalid simple ko loc"); + if(getNumImmediateLiberties(ko_loc) != 0) + throw StringError(errLabel + "Simple ko loc has immediate liberties"); + } } short tmpAdjOffsets[8]; - Location::getAdjacentOffsets(tmpAdjOffsets,x_size); - for(int i = 0; i<8; i++) + Location::getAdjacentOffsets(tmpAdjOffsets, x_size, isDots()); + for(int i = 0; i < 8; i++) if(tmpAdjOffsets[i] != adj_offsets[i]) throw StringError(errLabel + "Corrupted adj_offsets array"); } bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const { + return isEqualForTesting(other, checkNumCaptures, checkSimpleKo, true); +} + +bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo, bool checkRules) const { checkConsistency(); other.checkConsistency(); if(x_size != other.x_size) @@ -2395,6 +2487,12 @@ bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool ch if(colors[i] != other.colors[i]) return false; } + if (numLegalMoves != other.numLegalMoves) { + return false; + } + if (checkRules && rules != other.rules) { + return false; + } //We don't require that the chain linked lists are in the same order. //Consistency check ensures that all the linked lists are consistent with colors array, which we checked. return true; @@ -2455,10 +2553,10 @@ Player PlayerIO::parsePlayer(const string& s) { return pla; } -string Location::toStringMach(Loc loc, int x_size) +string Location::toStringMach(Loc loc, int x_size, bool isDots) { if(loc == Board::PASS_LOC) - return string("pass"); + return isDots ? "ground" : "pass"; if(loc == Board::NULL_LOC) return string("null"); char buf[128]; @@ -2466,19 +2564,19 @@ string Location::toStringMach(Loc loc, int x_size) return string(buf); } -string Location::toString(Loc loc, int x_size, int y_size) +string Location::toString(Loc loc, int x_size, int y_size, bool isDots) { if(x_size > 25*25) - return toStringMach(loc,x_size); + return toStringMach(loc, x_size, isDots); if(loc == Board::PASS_LOC) - return string("pass"); + return isDots ? "ground" : "pass"; if(loc == Board::NULL_LOC) return string("null"); const char* xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; int x = getX(loc,x_size); int y = getY(loc,x_size); if(x >= x_size || x < 0 || y < 0 || y >= y_size) - return toStringMach(loc,x_size); + return toStringMach(loc, x_size, isDots); char buf[128]; if(x <= 24) @@ -2489,11 +2587,11 @@ string Location::toString(Loc loc, int x_size, int y_size) } string Location::toString(Loc loc, const Board& b) { - return toString(loc,b.x_size,b.y_size); + return toString(loc, b.x_size, b.y_size, b.rules.isDots); } string Location::toStringMach(Loc loc, const Board& b) { - return toStringMach(loc,b.x_size); + return toStringMach(loc, b.x_size, b.isDots()); } static bool tryParseLetterCoordinate(char c, int& x) { @@ -2514,7 +2612,8 @@ bool Location::tryOfString(const string& str, int x_size, int y_size, Loc& resul string s = Global::trim(str); if(s.length() < 2) return false; - if(Global::isEqualCaseInsensitive(s,string("pass")) || Global::isEqualCaseInsensitive(s,string("pss"))) { + if(Global::isEqualCaseInsensitive(s,string("pass")) || Global::isEqualCaseInsensitive(s,string("pss")) || + Global::isEqualCaseInsensitive(s,string("ground"))) { result = Board::PASS_LOC; return true; } @@ -2618,17 +2717,24 @@ void Board::printBoard(ostream& out, const Board& board, Loc markLoc, const vect if(hist != NULL) out << "MoveNum: " << hist->size() << " "; out << "HASH: " << board.pos_hash << "\n"; - bool showCoords = board.x_size <= 50 && board.y_size <= 50; + bool showCoords = board.isDots() || (board.x_size <= 50 && board.y_size <= 50); if(showCoords) { - const char* xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; - out << " "; - for(int x = 0; x < board.x_size; x++) { - if(x <= 24) { - out << " "; - out << xChar[x]; + if (board.isDots()) { + out << " "; + for(int x = 0; x < board.x_size; x++) { + out << std::left << std::setw(2) << (x + 1) << ' '; } - else { - out << "A" << xChar[x-25]; + } else { + auto xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; + out << " "; + for(int x = 0; x < board.x_size; x++) { + if(x <= 24) { + out << " "; + out << xChar[x]; + } + else { + out << "A" << xChar[x-25]; + } } } out << "\n"; @@ -2637,15 +2743,14 @@ void Board::printBoard(ostream& out, const Board& board, Loc markLoc, const vect for(int y = 0; y < board.y_size; y++) { if(showCoords) { - char buf[16]; - sprintf(buf,"%2d",board.y_size-y); - out << buf << ' '; + out << std::right << std::setw(2) << board.y_size-y << ' '; } for(int x = 0; x < board.x_size; x++) { - Loc loc = Location::getLoc(x,y,board.x_size); - char s = PlayerIO::colorToChar(board.colors[loc]); - if(board.colors[loc] == C_EMPTY && markLoc == loc) + Loc loc = Location::getLoc(x, y , board.x_size); + const Color color = board.getColor(loc); // TODO: probably it makes sense to implement debug printing for Dots game + char s = PlayerIO::colorToChar(color); + if(color == C_EMPTY && markLoc == loc) out << '@'; else out << s; @@ -2662,7 +2767,11 @@ void Board::printBoard(ostream& out, const Board& board, Loc markLoc, const vect } } - if(x < board.x_size-1 && !histMarked) + if (!histMarked && board.isDots()) { + out << ' '; + } + + if(x < board.x_size-1 && (!histMarked || board.isDots())) out << ' '; } out << "\n"; @@ -2681,7 +2790,7 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { Loc loc = Location::getLoc(x,y,board.x_size); - s += PlayerIO::colorToChar(board.colors[loc]); + s += PlayerIO::colorToChar(board.getColor(loc)); } s += lineDelimiter; } @@ -2689,11 +2798,19 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { } Board Board::parseBoard(int xSize, int ySize, const string& s) { - return parseBoard(xSize,ySize,s,'\n'); + return parseBoard(xSize, ySize, s, '\n', Rules()); } -Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimiter) { - Board board(xSize,ySize); +Board Board::parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter) { + return parseBoard(xSize, ySize, s, lineDelimiter, Rules()); +} + +Board Board::parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules) { + return parseBoard(xSize, ySize, s, '\n', rules); +} + +Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimiter, const Rules& rules) { + Board board(xSize,ySize,rules); vector lines = Global::split(Global::trim(s),lineDelimiter); //Throw away coordinate labels line if it exists @@ -2742,22 +2859,34 @@ Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimite return board; } +std::string Board::toString() const { + std::ostringstream oss; + printBoard(oss, *this, NULL_LOC, nullptr); + return oss.str(); +} + nlohmann::json Board::toJson(const Board& board) { nlohmann::json data; + if (board.rules.isDots) { + data[DOTS_KEY] = true; + } data["xSize"] = board.x_size; data["ySize"] = board.y_size; - data["stones"] = Board::toStringSimple(board,'|'); - data["koLoc"] = Location::toString(board.ko_loc,board); + data["stones"] = toStringSimple(board,'|'); + if (!board.isDots()) { + data["koLoc"] = Location::toString(board.ko_loc,board); + } data["numBlackCaptures"] = board.numBlackCaptures; data["numWhiteCaptures"] = board.numWhiteCaptures; return data; } Board Board::ofJson(const nlohmann::json& data) { + bool dots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - Board board = Board::parseBoard(xSize,ySize,data["stones"].get(),'|'); - board.setSimpleKoLoc(Location::ofStringAllowNull(data["koLoc"].get(),board)); + Board board = parseBoard(xSize, ySize, data["stones"].get(), '|', Rules(dots)); + board.setSimpleKoLoc(Location::ofStringAllowNull(data.value("koLoc", "null"),board)); board.numBlackCaptures = data["numBlackCaptures"].get(); board.numWhiteCaptures = data["numWhiteCaptures"].get(); return board; @@ -2808,6 +2937,10 @@ bool Board::simpleRepetitionBoundGt(Loc loc, int bound) const { if(loc == NULL_LOC || loc == PASS_LOC) return false; + if (rules.isDots) { + return false; // TODO: implement for Dots? + } + int count = 0; if(colors[loc] != C_EMPTY) { @@ -2840,3 +2973,14 @@ bool Board::simpleRepetitionBoundGt(Loc loc, int bound) const { return false; } + +Board::MoveRecord::MoveRecord(const Loc initLoc, const Player initPla, const Loc init_ko_loc, const uint8_t initCapDirs) { + loc = initLoc; + pla = initPla; + ko_loc = init_ko_loc; + capDirs = initCapDirs; + + previousState = C_EMPTY; + bases = {}; + emptyBaseInvalidateLocations = {}; +} \ No newline at end of file diff --git a/cpp/game/board.h b/cpp/game/board.h index 4fdb2a259..2c58a140a 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -7,34 +7,46 @@ #ifndef GAME_BOARD_H_ #define GAME_BOARD_H_ +#include + #include "../core/global.h" #include "../core/hash.h" #include "../external/nlohmann_json/json.hpp" +#include "rules.h" #ifndef COMPILE_MAX_BOARD_LEN -#define COMPILE_MAX_BOARD_LEN 19 +#define COMPILE_MAX_BOARD_LEN 39 #endif +#define FOREACHADJ(BLOCK) {int ADJOFFSET = -(x_size+1); {BLOCK}; ADJOFFSET = -1; {BLOCK}; ADJOFFSET = 1; {BLOCK}; ADJOFFSET = x_size+1; {BLOCK}}; + //TYPES AND CONSTANTS----------------------------------------------------------------- +static constexpr int LEFT_TOP_INDEX = 0; +static constexpr int TOP_INDEX = 1; +static constexpr int RIGHT_TOP_INDEX = 2; +static constexpr int RIGHT_INDEX = 3; +static constexpr int RIGHT_BOTTOM_INDEX = 4; +static constexpr int BOTTOM_INDEX = 5; +static constexpr int LEFT_BOTTOM_INDEX = 6; +static constexpr int LEFT_INDEX = 7; + struct Board; -//Player -typedef int8_t Player; -static constexpr Player P_BLACK = 1; -static constexpr Player P_WHITE = 2; +typedef int8_t State; -//Color of a point on the board -typedef int8_t Color; -static constexpr Color C_EMPTY = 0; -static constexpr Color C_BLACK = 1; -static constexpr Color C_WHITE = 2; -static constexpr Color C_WALL = 3; -static constexpr int NUM_BOARD_COLORS = 4; +static constexpr int PLAYER_BITS_COUNT = 2; +static constexpr State ACTIVE_MASK = (1 << PLAYER_BITS_COUNT) - 1; -static inline Color getOpp(Color c) +static Color getOpp(Color c) {return c ^ 3;} +Color getActiveColor(State state); + +Color getPlacedDotColor(State s); + +Color getEmptyTerritoryColor(State s); + //Conversions for players and colors namespace PlayerIO { char colorToChar(Color c); @@ -44,16 +56,13 @@ namespace PlayerIO { Player parsePlayer(const std::string& s); } -//Location of a point on the board -//(x,y) is represented as (x+1) + (y+1)*(x_size+1) -typedef short Loc; namespace Location { Loc getLoc(int x, int y, int x_size); int getX(Loc loc, int x_size); int getY(Loc loc, int x_size); - void getAdjacentOffsets(short adj_offsets[8], int x_size); + void getAdjacentOffsets(short adj_offsets[8], int x_size, bool isDots); bool isAdjacent(Loc loc0, Loc loc1, int x_size); Loc getMirrorLoc(Loc loc, int x_size, int y_size); Loc getCenterLoc(int x_size, int y_size); @@ -62,10 +71,12 @@ namespace Location bool isNearCentral(Loc loc, int x_size, int y_size); int distance(Loc loc0, Loc loc1, int x_size); int euclideanDistanceSquared(Loc loc0, Loc loc1, int x_size); + int getGetBigJumpInitialIndex(Loc loc0, Loc loc1, int x_size); + Loc getNextLocCW(Loc loc0, Loc loc1, int x_size); - std::string toString(Loc loc, int x_size, int y_size); + std::string toString(Loc loc, int x_size, int y_size, bool isDots); std::string toString(Loc loc, const Board& b); - std::string toStringMach(Loc loc, int x_size); + std::string toStringMach(Loc loc, int x_size, bool isDots); std::string toStringMach(Loc loc, const Board& b); bool tryOfString(const std::string& str, int x_size, int y_size, Loc& result); @@ -80,10 +91,16 @@ namespace Location Loc ofStringAllowNull(const std::string& str, const Board& b); std::vector parseSequence(const std::string& str, const Board& b); -} -//Simple structure for storing moves. Not used below, but this is a convenient place to define it. -STRUCT_NAMED_PAIR(Loc,loc,Player,pla,Move); + Loc xm1y(Loc loc); + Loc xm1ym1(Loc loc, int x_size); + Loc xym1(Loc loc, int x_size); + Loc xp1ym1(Loc loc, int x_size); + Loc xp1y(Loc loc); + Loc xp1yp1(Loc loc, int x_size); + Loc xyp1(Loc loc, int x_size); + Loc xm1yp1(Loc loc, int x_size); +} //Fast lightweight board designed for playouts and simulations, where speed is essential. //Simple ko rule only. @@ -105,7 +122,7 @@ struct Board //Location used to indicate an invalid spot on the board. static constexpr Loc NULL_LOC = 0; - //Location used to indicate a pass move is desired. + //Location used to indicate a pass or grounding (Dots game) move is desired. static constexpr Loc PASS_LOC = 1; //Zobrist Hashing------------------------------ @@ -147,23 +164,71 @@ struct Board /* int size_; */ /* }; */ + struct Base { + Player pla{}; + bool is_real{}; + std::vector rollback_locations; + std::vector rollback_states; + + Base() = default; + Base(Player newPla, const std::vector& rollbackLocations, const std::vector& rollbackStates, bool isReal); + }; + //Move data passed back when moves are made to allow for undos struct MoveRecord { Player pla; Loc loc; Loc ko_loc; uint8_t capDirs; //First 4 bits indicate directions of capture, fifth bit indicates suicide + + // Move data for Dots game + State previousState; + std::vector bases; + std::vector emptyBaseInvalidateLocations; + + MoveRecord() = default; + + // Constructor for Go game + MoveRecord( + Loc initLoc, + Player initPla, + Loc init_ko_loc, + uint8_t initCapDirs + ); + + // Constructor for Dots game + MoveRecord( + Loc initLoc, + Player initPla, + State initPreviousState, + const std::vector& initBases, + const std::vector& initEmptyBaseInvalidateLocations + ); }; //Constructors--------------------------------- Board(); //Create Board of size (DEFAULT_LEN,DEFAULT_LEN) - Board(int x, int y); //Create Board of size (x,y) + Board(int x, int y); // Create Board of size (x,y) + Board(int x, int y, const Rules& rules); Board(const Board& other); Board& operator=(const Board&) = default; //Functions------------------------------------ + [[nodiscard]] Color getColor(Loc loc) const; + [[nodiscard]] State getState(Loc loc) const; + void setState(Loc loc, State state); + bool isDots() const; + + template void forEachAdjacent(const Loc loc, Func&& f) const { + const int stride = x_size + 1; + f(loc - stride); + f(loc - 1); + f(loc + 1); + f(loc + stride); + } + double sqrtBoardArea() const; //Gets the number of stones of the chain at loc. Precondition: location must be black or white. @@ -183,10 +248,8 @@ struct Board bool isIllegalSuicide(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; //Check if moving here is illegal due to simple ko bool isKoBanned(Loc loc) const; - //Check if moving here is legal, ignoring simple ko - bool isLegalIgnoringKo(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; //Check if moving here is legal. Equivalent to isLegalIgnoringKo && !isKoBanned - bool isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal) const; + bool isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal, bool ignoreKo) const; //Check if this location is on the board bool isOnBoard(Loc loc) const; //Check if this location contains a simple eye for the specified player. @@ -239,13 +302,13 @@ struct Board //Plays the specified move, assuming it is legal. void playMoveAssumeLegal(Loc loc, Player pla); - //Plays the specified move, assuming it is legal, and returns a MoveRecord for the move + // Plays the specified move, assuming it is legal, and returns a MoveRecord for the move MoveRecord playMoveRecorded(Loc loc, Player pla); //Undo the move given by record. Moves MUST be undone in the order they were made. //Undos will NOT typically restore the precise representation in the board to the way it was. The heads of chains //might change, the order of the circular lists might change, etc. - void undo(MoveRecord record); + void undo(MoveRecord& record); //Get what the position hash would be if we were to play this move and resolve captures and suicides. //Assumes the move is on an empty location. @@ -270,6 +333,7 @@ struct Board //If unsafeBigTerritories, also marks for each pla empty regions bordered by pla stones and no opp stones, regardless. //All other points are marked as C_EMPTY. //[result] must be a buffer of size MAX_ARR_SIZE and will get filled with the result + // For Dots game it just calculates grounding void calculateArea( Color* result, bool nonPassAliveStones, @@ -278,8 +342,9 @@ struct Board bool isMultiStoneSuicideLegal ) const; + int calculateGroundingWhiteScore(Color* result) const; - //Calculates the area (including non pass alive stones, safe and unsafe big territories) + // Calculates the area (including non pass alive stones, safe and unsafe big territories) //However, strips out any "seki" regions. //Seki regions are that are adjacent to any remaining empty regions. //If keepTerritories, then keeps the surrounded territories in seki regions, only strips points for stones. @@ -293,14 +358,20 @@ struct Board bool isMultiStoneSuicideLegal ) const; + void calculateOneMoveCaptureAndBasePositionsForDots(bool isSuicideLegal, std::vector& captures, std::vector& bases) const; + //Run some basic sanity checks on the board state, throws an exception if not consistent, for testing/debugging void checkConsistency() const; //For the moment, only used in testing since it does extra consistency checks. //If we need a version to be used in "prod", we could make an efficient version maybe as operator==. bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const; + bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo, bool checkRules) const; static Board parseBoard(int xSize, int ySize, const std::string& s); static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter); + static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules); + static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter, const Rules& rules); + std::string toString() const; static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); static std::string toStringSimple(const Board& board, char lineDelimiter); static nlohmann::json toJson(const Board& board); @@ -310,13 +381,9 @@ struct Board int x_size; //Horizontal size of board int y_size; //Vertical size of board + Rules rules; Color colors[MAX_ARR_SIZE]; //Color of each location on the board. - //Every chain of stones has one of its stones arbitrarily designated as the head. - ChainData chain_data[MAX_ARR_SIZE]; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. - Loc chain_head[MAX_ARR_SIZE]; //Where is the head of this chain? Undefined if EMPTY or WALL - Loc next_in_chain[MAX_ARR_SIZE]; //Location of next stone in chain. Circular linked list. Undefined if EMPTY or WALL - Loc ko_loc; //A simple ko capture was made here, making it illegal to replay here next move /* PointList empty_list; //List of all empty locations on board */ @@ -326,10 +393,50 @@ struct Board int numBlackCaptures; //Number of b stones captured, informational and used by board history when clearing pos int numWhiteCaptures; //Number of w stones captured, informational and used by board history when clearing pos - short adj_offsets[8]; //Indices 0-3: Offsets to add for adjacent points. Indices 4-7: Offsets for diagonal points. 2 and 3 are +x and +y. + // Offsets to add to get clockwise traverse + short adj_offsets[8]; + + int numLegalMoves; + + //Every chain of stones has one of its stones arbitrarily designated as the head. + std::vector chain_data; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. + std::vector chain_head; //Where is the head of this chain? Undefined if EMPTY or WALL + std::vector next_in_chain; //Location of next stone in chain. Circular linked list. Undefined if EMPTY or WALL private: - void init(int xS, int yS); + + // Dots game data + mutable std::array unconnectedLocationsBuffer = std::array(); + mutable int unconnectedLocationsBufferSize = 0; + mutable std::vector closureOrInvalidateLocsBuffer = std::vector(); + mutable std::vector territoryLocationsBuffer = std::vector(); + mutable std::vector walkStack = std::vector(); + + // Dots game functions + [[nodiscard]] bool wouldBeCaptureDots(Loc loc, Player pla) const; + [[nodiscard]] bool isSuicideDots(Loc loc, Player pla) const; + MoveRecord playMoveAssumeLegalDots(Loc loc, Player pla); + MoveRecord tryPlayMove(Loc loc, Player pla, bool isSuicideLegal); + void undoDots(MoveRecord& moveRecord); + Base captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp); + std::vector tryCapture(Loc loc, Player pla, bool emptyBaseCapturing); + std::vector ground(Player pla, std::vector& emptyBaseInvalidatePositions); + void getUnconnectedLocations(Loc loc, Player pla) const; + void checkAndAddUnconnectedLocation(Player checkPla,Player currentPla,Loc addLoc1,Loc addLoc2) const; + void tryGetCounterClockwiseClosure(Loc initialLoc, Loc startLoc, Player pla); + Base buildBase(const std::vector& closure, Player pla); + void getTerritoryLocations(Player pla, Loc firstLoc, bool grounding, bool& createRealBase, bool& grounded); + Base createBaseAndUpdateStates(Player basePla, bool isReal); + void updateScoreAndHashForTerritory(Loc loc, State state, Player basePla, bool rollback); + void invalidateAdjacentEmptyTerritoryIfNeeded(Loc loc); + void makeMoveAndCalculateCapturesAndBases(Player pla, Loc loc, bool isSuicideLegal, + std::vector& captures, std::vector& bases) const; + void setVisited(Loc loc); + void clearVisited(Loc loc); + void clearVisited(const std::vector& locations); + int calculateGroundingWhiteScore(Player pla, std::unordered_set& nonGroundedLocs) const; + + void init(int xS, int yS, const Rules& initRules); int countHeuristicConnectionLibertiesX2(Loc loc, Player pla) const; bool isLibertyOf(Loc loc, Loc head) const; void mergeChains(Loc loc1, Loc loc2); diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 5b3c35a9b..24be1f053 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -55,10 +55,12 @@ BoardHistory::BoardHistory() isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), isScored(false),isNoResult(false),isResignation(false) { - std::fill(wasEverOccupiedOrPlayed, wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, false); - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); + if (!rules.isDots) { + wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); + superKoBanned.resize(Board::MAX_ARR_SIZE, false); + koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); + secondEncoreStartColors.resize(Board::MAX_ARR_SIZE, C_EMPTY); + } } BoardHistory::~BoardHistory() @@ -95,10 +97,12 @@ BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int e isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), isScored(false),isNoResult(false),isResignation(false) { - std::fill(wasEverOccupiedOrPlayed, wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, false); - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); + if (!rules.isDots) { + wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); + superKoBanned.resize(Board::MAX_ARR_SIZE, false); + koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); + secondEncoreStartColors.resize(Board::MAX_ARR_SIZE, C_EMPTY); + } clear(board,pla,rules,ePhase); } @@ -134,11 +138,11 @@ BoardHistory::BoardHistory(const BoardHistory& other) isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) { - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; + koRecapBlocked = other.koRecapBlocked; + secondEncoreStartColors = other.secondEncoreStartColors; } @@ -158,11 +162,11 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); + std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; consecutiveEndingPasses = other.consecutiveEndingPasses; hashesBeforeBlackPass = other.hashesBeforeBlackPass; hashesBeforeWhitePass = other.hashesBeforeWhitePass; @@ -170,10 +174,10 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) numTurnsThisPhase = other.numTurnsThisPhase; numApproxValidTurnsThisPhase = other.numApproxValidTurnsThisPhase; numConsecValidTurnsThisGame = other.numConsecValidTurnsThisGame; - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); + koRecapBlocked = other.koRecapBlocked; koRecapBlockHash = other.koRecapBlockHash; koCapturesInEncore = other.koCapturesInEncore; - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + secondEncoreStartColors = other.secondEncoreStartColors; whiteBonusScore = other.whiteBonusScore; whiteHandicapBonusScore = other.whiteHandicapBonusScore; hasButton = other.hasButton; @@ -219,11 +223,11 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) { - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; + koRecapBlocked = other.koRecapBlocked; + secondEncoreStartColors = other.secondEncoreStartColors; } BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept @@ -240,11 +244,11 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy(other.recentBoards, other.recentBoards+NUM_RECENT_BOARDS, recentBoards); + std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; - std::copy(other.wasEverOccupiedOrPlayed, other.wasEverOccupiedOrPlayed+Board::MAX_ARR_SIZE, wasEverOccupiedOrPlayed); - std::copy(other.superKoBanned, other.superKoBanned+Board::MAX_ARR_SIZE, superKoBanned); + wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; + superKoBanned = other.superKoBanned; consecutiveEndingPasses = other.consecutiveEndingPasses; hashesBeforeBlackPass = std::move(other.hashesBeforeBlackPass); hashesBeforeWhitePass = std::move(other.hashesBeforeWhitePass); @@ -252,10 +256,10 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept numTurnsThisPhase = other.numTurnsThisPhase; numApproxValidTurnsThisPhase = other.numApproxValidTurnsThisPhase; numConsecValidTurnsThisGame = other.numConsecValidTurnsThisGame; - std::copy(other.koRecapBlocked, other.koRecapBlocked+Board::MAX_ARR_SIZE, koRecapBlocked); + koRecapBlocked = other.koRecapBlocked; koRecapBlockHash = other.koRecapBlockHash; koCapturesInEncore = std::move(other.koCapturesInEncore); - std::copy(other.secondEncoreStartColors, other.secondEncoreStartColors+Board::MAX_ARR_SIZE, secondEncoreStartColors); + secondEncoreStartColors = other.secondEncoreStartColors; whiteBonusScore = other.whiteBonusScore; whiteHandicapBonusScore = other.whiteHandicapBonusScore; hasButton = other.hasButton; @@ -292,22 +296,12 @@ void BoardHistory::clear(const Board& board, Player pla, const Rules& r, int ePh currentRecentBoardIdx = 0; presumedNextMovePla = pla; - - for(int y = 0; y= 0 && encorePhase <= 2); - if(encorePhase > 0) - assert(rules.scoringRule == Rules::SCORING_TERRITORY); - //Update the few parameters that depend on encore - if(encorePhase == 2) - std::copy(board.colors, board.colors+Board::MAX_ARR_SIZE, secondEncoreStartColors); - else - std::fill(secondEncoreStartColors, secondEncoreStartColors+Board::MAX_ARR_SIZE, C_EMPTY); - - //Push hash for the new board state - koHashHistory.push_back(getKoHash(rules,board,pla,encorePhase,koRecapBlockHash)); - - if(rules.scoringRule == Rules::SCORING_TERRITORY) { - //Chill 1 point for every move played + if (!rules.isDots) { for(int y = 0; y= 0 && encorePhase <= 2); + if(encorePhase > 0) + assert(rules.scoringRule == Rules::SCORING_TERRITORY); + //Update the few parameters that depend on encore + if(encorePhase == 2) + std::copy_n(board.colors, Board::MAX_ARR_SIZE, secondEncoreStartColors.begin()); + else + std::fill(secondEncoreStartColors.begin(), secondEncoreStartColors.end(), C_EMPTY); + + //Push hash for the new board state + koHashHistory.push_back(getKoHash(rules,board,pla,encorePhase,koRecapBlockHash)); + + if(rules.scoringRule == Rules::SCORING_TERRITORY) { + //Chill 1 point for every move played + for(int y = 0; ysize() moves of koHashHistory. //ALSO counts the most recent ko hash! bool BoardHistory::koHashOccursInHistory(Hash128 koHash, const KoHashTable* rootKoHashTable) const { + assert(!rules.isDots); + size_t start = 0; size_t koHashHistorySize = koHashHistory.size(); if(rootKoHashTable != NULL && @@ -573,7 +600,13 @@ float BoardHistory::currentSelfKomi(Player pla, double drawEquivalentWinsForWhit } } +int BoardHistory::countGroundingScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) { + return board.calculateGroundingWhiteScore(area); +} + int BoardHistory::countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots == board.isDots() && !rules.isDots); + int score = 0; if(rules.taxRule == Rules::TAX_NONE) { bool nonPassAliveStones = true; @@ -615,6 +648,8 @@ int BoardHistory::countAreaScoreWhiteMinusBlack(const Board& board, Color area[B //ALSO makes area color the points that were not pass alive but were scored for a side. int BoardHistory::countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots == board.isDots() && !rules.isDots); + int score = 0; bool keepTerritories; bool keepStones; @@ -675,7 +710,9 @@ void BoardHistory::setFinalScoreAndWinner(float score) { } void BoardHistory::getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { - if(rules.scoringRule == Rules::SCORING_AREA) + if(rules.isDots) + countGroundingScoreWhiteMinusBlack(board,area); + else if(rules.scoringRule == Rules::SCORING_AREA) countAreaScoreWhiteMinusBlack(board,area); else if(rules.scoringRule == Rules::SCORING_TERRITORY) countTerritoryAreaScoreWhiteMinusBlack(board,area); @@ -684,8 +721,12 @@ void BoardHistory::getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE } void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) { - int boardScore; - if(rules.scoringRule == Rules::SCORING_AREA) + assert(rules.isDots == board.isDots()); + + int boardScore = 0; + if(rules.isDots) + boardScore = countGroundingScoreWhiteMinusBlack(board,area); + else if(rules.scoringRule == Rules::SCORING_AREA) boardScore = countAreaScoreWhiteMinusBlack(board,area); else if(rules.scoringRule == Rules::SCORING_TERRITORY) boardScore = countTerritoryAreaScoreWhiteMinusBlack(board,area); @@ -693,11 +734,12 @@ void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ ASSERT_UNREACHABLE; if(hasButton) { + assert(!board.rules.isDots); hasButton = false; whiteBonusScore += (presumedNextMovePla == P_WHITE ? 0.5f : -0.5f); } - setFinalScoreAndWinner(boardScore + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); isScored = true; isNoResult = false; isResignation = false; @@ -711,42 +753,75 @@ void BoardHistory::endAndScoreGameNow(const Board& board) { } void BoardHistory::endGameIfAllPassAlive(const Board& board) { - int boardScore = 0; - bool nonPassAliveStones = false; - bool safeBigTerritories = false; - bool unsafeBigTerritories = false; - Color area[Board::MAX_ARR_SIZE]; - board.calculateArea( - area, - nonPassAliveStones, safeBigTerritories, unsafeBigTerritories, rules.multiStoneSuicideLegal - ); + assert(rules.isDots == board.isDots()); + + if (rules.isDots) { + bool gameOver = false; + float normalizedWhiteScoreIfGroundingAlive = 0.0f; + + if (board.numLegalMoves == 0) { + // No legal locs to place a dot -> game is over. + gameOver = true; + normalizedWhiteScoreIfGroundingAlive = static_cast(board.numBlackCaptures - board.numWhiteCaptures) + rules.komi; + } else { + Board::MoveRecord moveRecord = const_cast(board).playMoveRecorded(Board::PASS_LOC, presumedNextMovePla); + const float whiteScoreAfterNextPlaGrounding = static_cast(board.numBlackCaptures - board.numWhiteCaptures) + rules.komi; + const_cast(board).undo(moveRecord); + + if (presumedNextMovePla == P_BLACK && whiteScoreAfterNextPlaGrounding < 0.0f || + presumedNextMovePla == P_WHITE && whiteScoreAfterNextPlaGrounding > 0.0f) { + gameOver = true; + normalizedWhiteScoreIfGroundingAlive = whiteScoreAfterNextPlaGrounding; + } + } - for(int y = 0; y 0) { if(moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { @@ -788,7 +871,7 @@ bool BoardHistory::isLegal(const Board& board, Loc moveLoc, Player movePla) cons if(board.isKoBanned(moveLoc)) return false; } - if(!board.isLegalIgnoringKo(moveLoc,movePla,rules.multiStoneSuicideLegal)) + if(!board.isLegal(moveLoc, movePla, rules.multiStoneSuicideLegal, true)) return false; if(superKoBanned[moveLoc]) return false; @@ -797,6 +880,9 @@ bool BoardHistory::isLegal(const Board& board, Loc moveLoc, Player movePla) cons } bool BoardHistory::isPassForKo(const Board& board, Loc moveLoc, Player movePla) const { + assert(rules.isDots == board.isDots()); + if (rules.isDots) return false; + if(encorePhase > 0 && moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc] && board.getChainSize(moveLoc) == 1 && board.getNumLiberties(moveLoc) == 1) return true; @@ -857,6 +943,10 @@ bool BoardHistory::wouldBeSpightlikeEndingPass(Player movePla, Hash128 koHashBef } bool BoardHistory::passWouldEndPhase(const Board& board, Player movePla) const { + assert(rules.isDots == board.isDots()); + if (rules.isDots) return false; + + // TODO: probably add the assert? assert(!board.isDots()); Hash128 koHashBeforeMove = getKoHash(rules, board, movePla, encorePhase, koRecapBlockHash); if(newConsecutiveEndingPassesAfterPass() >= 2 || wouldBeSpightlikeEndingPass(movePla,koHashBeforeMove)) @@ -865,6 +955,9 @@ bool BoardHistory::passWouldEndPhase(const Board& board, Player movePla) const { } bool BoardHistory::passWouldEndGame(const Board& board, Player movePla) const { + if (board.rules.isDots) { + return true; // Pass in Dots game is grounding move that always ends the game + } return passWouldEndPhase(board,movePla) && ( rules.scoringRule == Rules::SCORING_AREA || (rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase >= 2) @@ -880,7 +973,8 @@ bool BoardHistory::shouldSuppressEndGameFromFriendlyPass(const Board& board, Pla bool BoardHistory::isFinalPhase() const { return - rules.scoringRule == Rules::SCORING_AREA + rules.isDots + || rules.scoringRule == Rules::SCORING_AREA || (rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase >= 2); } @@ -888,8 +982,9 @@ bool BoardHistory::isLegalTolerant(const Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; return true; } @@ -897,8 +992,9 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; makeBoardMoveAssumeLegal(board,moveLoc,movePla,NULL); return true; @@ -907,8 +1003,9 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla == presumedNextMovePla && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide regardless of rules - if(!isPassForKo(board, moveLoc, movePla) && !board.isLegalIgnoringKo(moveLoc,movePla,multiStoneSuicideLegal)) + bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool ignoreKo = true; + if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; makeBoardMoveAssumeLegal(board,moveLoc,movePla,NULL,preventEncore); return true; @@ -929,7 +1026,7 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); } - bool moveIsIllegal = !isLegal(board,moveLoc,movePla); + const bool moveIsIllegal = !isLegal(board,moveLoc,movePla); //And if somehow we're making a move after the game was ended, just clear those values and continue. isGameFinished = false; @@ -942,87 +1039,102 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo //Update consecutiveEndingPasses and button bool isSpightlikeEndingPass = false; - if(moveLoc != Board::PASS_LOC) - consecutiveEndingPasses = 0; - else if(hasButton) { - assert(encorePhase == 0 && rules.hasButton); - hasButton = false; - whiteBonusScore += (movePla == P_WHITE ? 0.5f : -0.5f); - consecutiveEndingPasses = 0; - //Taking the button clears all ko hash histories (this is equivalent to not clearing them and treating buttonless - //state as different than buttonful state) - hashesBeforeBlackPass.clear(); - hashesBeforeWhitePass.clear(); - koHashHistory.clear(); - //The first turn idx with history will be the one RESULTING from this move. - firstTurnIdxWithKoHistory = moveHistory.size()+1; - } - else { - //Passes clear ko history in the main phase with spight ko rules and in the encore - //This lifts bans in spight ko rules and lifts 3-fold-repetition checking in the encore for no-resultifying infinite cycles - //They also clear in simple ko rules for the purpose of no-resulting long cycles. Long cycles with passes do not no-result. - if(phaseHasSpightlikeEndingAndPassHistoryClearing()) { + bool wasPassForKo = false; + + if (rules.isDots) { + //Dots game + board.playMoveAssumeLegal(moveLoc, movePla); + if (moveLoc == Board::PASS_LOC) { + isScored = true; + isNoResult = false; + isResignation = false; + isGameFinished = true; + isPastNormalPhaseEnd = false; + const auto whiteMinusBlackScore = static_cast(board.numBlackCaptures - board.numWhiteCaptures); + setFinalScoreAndWinner(whiteMinusBlackScore + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + } + } else { + if(moveLoc != Board::PASS_LOC) + consecutiveEndingPasses = 0; + else if(hasButton) { + assert(encorePhase == 0 && rules.hasButton); + hasButton = false; + whiteBonusScore += (movePla == P_WHITE ? 0.5f : -0.5f); + consecutiveEndingPasses = 0; + //Taking the button clears all ko hash histories (this is equivalent to not clearing them and treating buttonless + //state as different than buttonful state) + hashesBeforeBlackPass.clear(); + hashesBeforeWhitePass.clear(); koHashHistory.clear(); //The first turn idx with history will be the one RESULTING from this move. firstTurnIdxWithKoHistory = moveHistory.size()+1; - //Does not clear hashesBeforeBlackPass or hashesBeforeWhitePass. Passes lift ko bans, but - //still repeated positions after pass end the game or phase, which these arrays are used to check. } + else { + //Passes clear ko history in the main phase with spight ko rules and in the encore + //This lifts bans in spight ko rules and lifts 3-fold-repetition checking in the encore for no-resultifying infinite cycles + //They also clear in simple ko rules for the purpose of no-resulting long cycles. Long cycles with passes do not no-result. + if(phaseHasSpightlikeEndingAndPassHistoryClearing()) { + koHashHistory.clear(); + //The first turn idx with history will be the one RESULTING from this move. + firstTurnIdxWithKoHistory = moveHistory.size()+1; + //Does not clear hashesBeforeBlackPass or hashesBeforeWhitePass. Passes lift ko bans, but + //still repeated positions after pass end the game or phase, which these arrays are used to check. + } - Hash128 koHashBeforeThisMove = getKoHash(rules,board,movePla,encorePhase,koRecapBlockHash); - consecutiveEndingPasses = newConsecutiveEndingPassesAfterPass(); - //Check if we have a game-ending pass BEFORE updating hashesBeforeBlackPass and hashesBeforeWhitePass - isSpightlikeEndingPass = wouldBeSpightlikeEndingPass(movePla,koHashBeforeThisMove); + Hash128 koHashBeforeThisMove = getKoHash(rules,board,movePla,encorePhase,koRecapBlockHash); + consecutiveEndingPasses = newConsecutiveEndingPassesAfterPass(); + //Check if we have a game-ending pass BEFORE updating hashesBeforeBlackPass and hashesBeforeWhitePass + isSpightlikeEndingPass = wouldBeSpightlikeEndingPass(movePla,koHashBeforeThisMove); - //Update hashesBeforeBlackPass and hashesBeforeWhitePass - if(movePla == P_BLACK) - hashesBeforeBlackPass.push_back(koHashBeforeThisMove); - else if(movePla == P_WHITE) - hashesBeforeWhitePass.push_back(koHashBeforeThisMove); - else - ASSERT_UNREACHABLE; - } - - //Handle pass-for-ko moves in the encore. Pass for ko lifts a ko recapture block and does nothing else. - bool wasPassForKo = false; - if(encorePhase > 0 && moveLoc != Board::PASS_LOC) { - if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc]) { - setKoRecapBlocked(moveLoc,false); - wasPassForKo = true; - //Clear simple ko loc just in case - //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. - board.clearSimpleKoLoc(); + //Update hashesBeforeBlackPass and hashesBeforeWhitePass + if(movePla == P_BLACK) + hashesBeforeBlackPass.push_back(koHashBeforeThisMove); + else if(movePla == P_WHITE) + hashesBeforeWhitePass.push_back(koHashBeforeThisMove); + else + ASSERT_UNREACHABLE; } - else { - Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); - if(koCaptureLoc != Board::NULL_LOC && koRecapBlocked[koCaptureLoc] && board.colors[koCaptureLoc] == getOpp(movePla)) { - setKoRecapBlocked(koCaptureLoc,false); + + //Handle pass-for-ko moves in the encore. Pass for ko lifts a ko recapture block and does nothing else. + if(encorePhase > 0 && moveLoc != Board::PASS_LOC) { + if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc]) { + setKoRecapBlocked(moveLoc,false); wasPassForKo = true; //Clear simple ko loc just in case //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. board.clearSimpleKoLoc(); } - } - } - //Otherwise handle regular moves - if(!wasPassForKo) { - board.playMoveAssumeLegal(moveLoc,movePla); - - if(encorePhase > 0) { - //Update ko recapture blocks and record that this was a ko capture - if(board.ko_loc != Board::NULL_LOC) { - setKoRecapBlocked(moveLoc,true); - koCapturesInEncore.push_back(EncoreKoCapture(posHashBeforeMove,moveLoc,movePla)); - //Clear simple ko loc now that we've absorbed the ko loc information into the korecap blocks - //Once we have that, the simple ko loc plays no further role in game state or legality - board.clearSimpleKoLoc(); + else { + Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); + if(koCaptureLoc != Board::NULL_LOC && koRecapBlocked[koCaptureLoc] && board.colors[koCaptureLoc] == getOpp(movePla)) { + setKoRecapBlocked(koCaptureLoc,false); + wasPassForKo = true; + //Clear simple ko loc just in case + //Since we aren't otherwise touching the board, from the board's perspective a player will be moving twice in a row. + board.clearSimpleKoLoc(); + } } - //Unmark all ko recap blocks not on stones - for(int y = 0; y 0) { + //Update ko recapture blocks and record that this was a ko capture + if(board.ko_loc != Board::NULL_LOC) { + setKoRecapBlocked(moveLoc,true); + koCapturesInEncore.push_back(EncoreKoCapture(posHashBeforeMove,moveLoc,movePla)); + //Clear simple ko loc now that we've absorbed the ko loc information into the korecap blocks + //Once we have that, the simple ko loc plays no further role in game state or legality + board.clearSimpleKoLoc(); + } + //Unmark all ko recap blocks not on stones + for(int y = 0; y 0) { - //During the encore, only one capture of each ko in a given position by a given player - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - for(size_t i = 0; i 0) { + //During the encore, only one capture of each ko in a given position by a given player + std::fill(superKoBanned.begin(), superKoBanned.end(), false); + for(size_t i = 0; i= 2 || isSpightlikeEndingPass) { - if(rules.scoringRule == Rules::SCORING_AREA) { - assert(encorePhase <= 0); - endAndScoreGameNow(board); + //Handicap bonus score + if(movePla == P_WHITE && moveLoc != Board::PASS_LOC) + whiteHasMoved = true; + if(assumeMultipleStartingBlackMovesAreHandicap && !whiteHasMoved && movePla == P_BLACK && rules.whiteHandicapBonusRule != Rules::WHB_ZERO) { + whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); } - else if(rules.scoringRule == Rules::SCORING_TERRITORY) { - if(encorePhase >= 2) + + //Phase transitions and game end + if(consecutiveEndingPasses >= 2 || isSpightlikeEndingPass) { + if(rules.scoringRule == Rules::SCORING_AREA) { + assert(encorePhase <= 0); endAndScoreGameNow(board); - else { - if(preventEncore) { - isPastNormalPhaseEnd = true; - //Cap at 1 - do include just the single pass here by itself since the single pass by itself - //absent any history of passes before that should be valid still. - numApproxValidTurnsThisPhase = std::min(numApproxValidTurnsThisPhase,1); - numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); - } + } + else if(rules.scoringRule == Rules::SCORING_TERRITORY) { + if(encorePhase >= 2) + endAndScoreGameNow(board); else { - encorePhase += 1; - numTurnsThisPhase = 0; - numApproxValidTurnsThisPhase = 0; - if(encorePhase == 2) - std::copy(board.colors, board.colors+Board::MAX_ARR_SIZE, secondEncoreStartColors); - - std::fill(superKoBanned, superKoBanned+Board::MAX_ARR_SIZE, false); - consecutiveEndingPasses = 0; - hashesBeforeBlackPass.clear(); - hashesBeforeWhitePass.clear(); - std::fill(koRecapBlocked, koRecapBlocked+Board::MAX_ARR_SIZE, false); - koRecapBlockHash = Hash128(); - koCapturesInEncore.clear(); - - koHashHistory.clear(); - koHashHistory.push_back(getKoHash(rules,board,getOpp(movePla),encorePhase,koRecapBlockHash)); - //The first ko hash history is the one for the move we JUST appended to the move history earlier. - firstTurnIdxWithKoHistory = moveHistory.size(); + if(preventEncore) { + isPastNormalPhaseEnd = true; + //Cap at 1 - do include just the single pass here by itself since the single pass by itself + //absent any history of passes before that should be valid still. + numApproxValidTurnsThisPhase = std::min(numApproxValidTurnsThisPhase,1); + numConsecValidTurnsThisGame = std::min(numConsecValidTurnsThisGame,1); + } + else { + encorePhase += 1; + numTurnsThisPhase = 0; + numApproxValidTurnsThisPhase = 0; + if(encorePhase == 2) + std::copy_n(board.colors, Board::MAX_ARR_SIZE, secondEncoreStartColors.begin()); + + std::fill(superKoBanned.begin(), superKoBanned.end(), false); + consecutiveEndingPasses = 0; + hashesBeforeBlackPass.clear(); + hashesBeforeWhitePass.clear(); + std::fill(koRecapBlocked.begin(), koRecapBlocked.end(), false); + koRecapBlockHash = Hash128(); + koCapturesInEncore.clear(); + + koHashHistory.clear(); + koHashHistory.push_back(getKoHash(rules,board,getOpp(movePla),encorePhase,koRecapBlockHash)); + //The first ko hash history is the one for the move we JUST appended to the move history earlier. + firstTurnIdxWithKoHistory = moveHistory.size(); + } } } + else + ASSERT_UNREACHABLE; } - else - ASSERT_UNREACHABLE; - } - //Break long cycles with no-result - if(moveLoc != Board::PASS_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { - if(numberOfKoHashOccurrencesInHistory(koHashHistory[koHashHistory.size()-1], rootKoHashTable) >= 3) { - isNoResult = true; - isGameFinished = true; + //Break long cycles with no-result + if(moveLoc != Board::PASS_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { + if(numberOfKoHashOccurrencesInHistory(koHashHistory[koHashHistory.size()-1], rootKoHashTable) >= 3) { + isNoResult = true; + isGameFinished = true; + } } } - } @@ -1181,6 +1296,9 @@ Hash128 BoardHistory::getSituationAndSimpleKoHash(const Board& board, Player nex //Note that board.pos_hash also incorporates the size of the board. Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; + if (board.isDots()) { + assert(board.ko_loc == Board::NULL_LOC); + } if(board.ko_loc != Board::NULL_LOC) hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; return hash; @@ -1190,6 +1308,9 @@ Hash128 BoardHistory::getSituationAndSimpleKoAndPrevPosHash(const Board& board, //Note that board.pos_hash also incorporates the size of the board. Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; + if (board.isDots()) { + assert(board.ko_loc == Board::NULL_LOC); + } if(board.ko_loc != Board::NULL_LOC) hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; @@ -1209,37 +1330,39 @@ Hash128 BoardHistory::getSituationRulesAndKoHash(const Board& board, const Board Hash128 hash = board.pos_hash; hash ^= Board::ZOBRIST_PLAYER_HASH[nextPlayer]; - assert(hist.encorePhase >= 0 && hist.encorePhase <= 2); - hash ^= Board::ZOBRIST_ENCORE_HASH[hist.encorePhase]; - - if(hist.encorePhase == 0) { - if(board.ko_loc != Board::NULL_LOC) - hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; - for(int y = 0; y= 0 && hist.encorePhase <= 2); + hash ^= Board::ZOBRIST_ENCORE_HASH[hist.encorePhase]; + + if(hist.encorePhase == 0) { + if(board.ko_loc != Board::NULL_LOC) + hash ^= Board::ZOBRIST_KO_LOC_HASH[board.ko_loc]; + for(int y = 0; y wasEverOccupiedOrPlayed; //Locations where the next player is not allowed to play due to superko - bool superKoBanned[Board::MAX_ARR_SIZE]; + std::vector superKoBanned; //Number of consecutive passes made that count for ending the game or phase int consecutiveEndingPasses; @@ -68,7 +68,7 @@ struct BoardHistory { int numConsecValidTurnsThisGame; //Ko-recapture-block locations for territory scoring in encore - bool koRecapBlocked[Board::MAX_ARR_SIZE]; + std::vector koRecapBlocked; Hash128 koRecapBlockHash; //Hash contribution from ko-recap-block locations in encore. //Used to implement once-only rules for ko captures in encore @@ -76,7 +76,7 @@ struct BoardHistory { std::vector koCapturesInEncore; //State of the grid as of the start of encore phase 2 for territory scoring - Color secondEncoreStartColors[Board::MAX_ARR_SIZE]; + std::vector secondEncoreStartColors; //Amount that should be added to komi float whiteBonusScore; @@ -168,6 +168,7 @@ struct BoardHistory { bool isLegalTolerant(const Board& board, Loc moveLoc, Player movePla) const; //Slightly expensive, check if the entire game is all pass-alive-territory, and if so, declare the game finished + // For Dots game it's Grounding alive void endGameIfAllPassAlive(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); @@ -200,6 +201,7 @@ struct BoardHistory { private: bool koHashOccursInHistory(Hash128 koHash, const KoHashTable* rootKoHashTable) const; void setKoRecapBlocked(Loc loc, bool b); + static int countGroundingScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]); int countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; int countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; void setFinalScoreAndWinner(float score); diff --git a/cpp/game/common.h b/cpp/game/common.h new file mode 100644 index 000000000..18cc632f2 --- /dev/null +++ b/cpp/game/common.h @@ -0,0 +1,34 @@ +#ifndef GAME_COMMON_H +#define GAME_COMMON_H +#include +#include "../core/global.h" + +const std::string DOTS_KEY = "dots"; +const std::string DOTS_CAPTURE_EMPTY_BASE_KEY = "dotsCaptureEmptyBase"; +const std::string DOTS_CAPTURE_EMPTY_BASES_KEY = "dotsCaptureEmptyBases"; +const std::string START_POS_KEY = "startPos"; +const std::string START_POSES_KEY = "startPoses"; + +// Player +typedef int8_t Player; +static constexpr Player P_BLACK = 1; +static constexpr Player P_WHITE = 2; + +//Color of a point on the board +typedef int8_t Color; +static constexpr Color C_EMPTY = 0; +static constexpr Color C_BLACK = 1; +static constexpr Color C_WHITE = 2; +static constexpr Color C_WALL = 3; +static constexpr int NUM_BOARD_COLORS = 4; + +typedef int8_t State; + +//Location of a point on the board +//(x,y) is represented as (x+1) + (y+1)*(x_size+1) +typedef short Loc; + +//Simple structure for storing moves. This is a convenient place to define it. +STRUCT_NAMED_PAIR(Loc,loc,Player,pla,Move); + +#endif diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp new file mode 100644 index 000000000..bf4e30ff1 --- /dev/null +++ b/cpp/game/dotsfield.cpp @@ -0,0 +1,766 @@ +#include "board.h" +#include + +#include "../program/play.h" + +using namespace std; + +static constexpr int PLACED_PLAYER_SHIFT = PLAYER_BITS_COUNT; +static constexpr int EMPTY_TERRITORY_SHIFT = PLACED_PLAYER_SHIFT + PLAYER_BITS_COUNT; +static constexpr int TERRITORY_FLAG_SHIFT = EMPTY_TERRITORY_SHIFT + PLAYER_BITS_COUNT; +static constexpr int VISITED_FLAG_SHIFT = TERRITORY_FLAG_SHIFT + 1; + +static constexpr State TERRITORY_FLAG = 1 << TERRITORY_FLAG_SHIFT; +static constexpr State VISITED_FLAG = static_cast(1 << VISITED_FLAG_SHIFT); +static constexpr State INVALIDATE_TERRITORY_MASK = ~(ACTIVE_MASK | ACTIVE_MASK << EMPTY_TERRITORY_SHIFT); +static constexpr State INVALIDATE_VISITED_MASK = ~VISITED_FLAG; + +Loc Location::xm1y(Loc loc) { + return loc - 1; +} + +Loc Location::xm1ym1(Loc loc, int x_size) { + return loc - 1 - (x_size+1); +} + +Loc Location::xym1(Loc loc, int x_size) { + return loc - (x_size+1); +} + +Loc Location::xp1ym1(Loc loc, int x_size) { + return loc + 1 - (x_size+1); +} + +Loc Location::xp1y(Loc loc) { + return loc + 1; +} + +Loc Location::xp1yp1(Loc loc, int x_size) { + return loc + 1 + (x_size+1); +} + +Loc Location::xyp1(Loc loc, int x_size) { + return loc + (x_size+1); +} + +Loc Location::xm1yp1(Loc loc, int x_size) { + return loc - 1 + (x_size+1); +} + +inline int Location::getGetBigJumpInitialIndex(const Loc loc0, const Loc loc1, const int x_size) { + const int diff = loc1 - loc0; + const int stride = x_size + 1; + + if (diff == -1 || diff == -1 - stride) { + return RIGHT_TOP_INDEX; + } + + if(diff == -stride || diff == +1 - stride) { + return RIGHT_BOTTOM_INDEX; + } + + if(diff == 1 || diff == 1 + stride) { + return LEFT_BOTTOM_INDEX; + } + + if(diff == stride || diff == -1 + stride) { + return LEFT_TOP_INDEX; + } + + return -1; +} + +inline Loc Location::getNextLocCW(const Loc loc0, const Loc loc1, const int x_size) { + const int diff = loc1 - loc0; + const int stride = x_size + 1; + + if (diff == -1) return xm1ym1(loc0, x_size); + if (diff == -1 - stride) return xym1(loc0, x_size); + if (diff == -stride) return xp1ym1(loc0, x_size); + if (diff == +1 - stride) return xp1y(loc0); + if (diff == +1) return xp1yp1(loc0, x_size); + if (diff == +1 + stride) return xyp1(loc0, x_size); + if (diff == +stride) return xm1yp1(loc0, x_size); + if (diff == -1 + stride) return xm1y(loc0); + + assert(false && "Incorrect locations"); + return 0; +} + +Color getActiveColor(const State state) { + return static_cast(state & ACTIVE_MASK); +} + +inline bool isVisited(const State s) { + return (s & VISITED_FLAG) == VISITED_FLAG; +} + +inline bool isTerritory(const State s) { + return (s & TERRITORY_FLAG) == TERRITORY_FLAG; +} + +Color getPlacedDotColor(const State s) { + return static_cast(s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK); +} + +inline bool isPlaced(const State s, const Player pla) { + return (s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK) == pla; +} + +inline bool isActive(const State s, const Player pla) { + return (s & ACTIVE_MASK) == pla; +} + +inline State setTerritoryAndActivePlayer(const State s, const Player pla) { + return static_cast(TERRITORY_FLAG | (s & INVALIDATE_TERRITORY_MASK | pla)); +} + +inline Color getEmptyTerritoryColor(const State s) { + return static_cast(s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK); +} + +inline bool isWithinEmptyTerritory(const State s, const Player pla) { + return (s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK) == pla; +} + +inline State Board::getState(const Loc loc) const { + return colors[loc]; +} + +inline void Board::setState(const Loc loc, const State state) { + colors[loc] = state; +} + +bool Board::isDots() const { + return rules.isDots; +} + +inline void Board::setVisited(const Loc loc) { + colors[loc] = static_cast(colors[loc] | VISITED_FLAG); +} + +inline void Board::clearVisited(const Loc loc) { + colors[loc] = static_cast(colors[loc] & INVALIDATE_VISITED_MASK); +} + +inline void Board::clearVisited(const vector& locations) { + for (const Loc& loc : locations) { + clearVisited(loc); + } +} + +int Board::calculateGroundingWhiteScore(Color* result) const { + auto nonGroundedLocs = unordered_set(); + + const int whiteScoreAfterBlackGrounding = calculateGroundingWhiteScore(P_BLACK, nonGroundedLocs); + const int whiteScoreAfterWhiteGrounding = calculateGroundingWhiteScore(P_WHITE, nonGroundedLocs); + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + if (const Color color = getColor(loc); color == C_EMPTY || nonGroundedLocs.count(loc) > 0) { + result[loc] = C_EMPTY; + } else { + result[loc] = color; // Fill only grounded locs + } + } + } + + return whiteScoreAfterBlackGrounding + whiteScoreAfterWhiteGrounding; +} + +int Board::calculateGroundingWhiteScore(Player pla, unordered_set& nonGroundedLocs) const { + auto emptyBaseInvalidateLocations = vector(); + const auto bases = const_cast(this)->ground(pla, emptyBaseInvalidateLocations); + auto moveRecord = MoveRecord(PASS_LOC, pla, getState(PASS_LOC), bases, emptyBaseInvalidateLocations); + for (Base& base : moveRecord.bases) { + for (Loc& loc : base.rollback_locations) { + nonGroundedLocs.insert(loc); + } + } + + const int whiteScore = numBlackCaptures - numWhiteCaptures; + + const_cast(this)->undoDots(moveRecord); + return whiteScore; +} + +Board::MoveRecord::MoveRecord( + const Loc initLoc, + const Player initPla, + const State initPreviousState, + const vector& initBases, + const vector& initEmptyBaseInvalidateLocations +) { + ko_loc = NULL_LOC; + capDirs = 0; + + loc = initLoc; + pla = initPla; + previousState = initPreviousState; + bases = initBases; + emptyBaseInvalidateLocations = initEmptyBaseInvalidateLocations; +} + +bool Board::isSuicideDots(const Loc loc, const Player pla) const { + const State state = getState(loc); + if (Player opp = getOpp(pla); getActiveColor(state) == C_EMPTY && getEmptyTerritoryColor(state) == opp) { + return !wouldBeCaptureDots(loc, pla); + } + + return false; +} + +bool Board::wouldBeCaptureDots(const Loc loc, const Player pla) const { + // TODO: optimize and get rid of `const_cast` + auto moveRecord = const_cast(this)->tryPlayMove(loc, pla, false); + + bool result = false; + + if (moveRecord.pla != C_EMPTY) { + for (const Base& base : moveRecord.bases) { + if (base.is_real && base.pla == pla) { + result = true; + break; + } + } + + const_cast(this)->undoDots(moveRecord); + } + + return result; +} + +Board::MoveRecord Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { + MoveRecord result = tryPlayMove(loc, pla, true); + assert(result.pla == pla); + return std::move(result); +} + +Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLegal) { + State originalState = getState(loc); + + vector bases; + vector initEmptyBaseInvalidateLocations; + + if (loc == PASS_LOC) { + initEmptyBaseInvalidateLocations = vector(); + bases = ground(pla, initEmptyBaseInvalidateLocations); + } else { + colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); + const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; + pos_hash ^= hashValue; + numLegalMoves--; + + bases = tryCapture(loc, pla, false); + + const Color opp = getOpp(pla); + if (bases.empty()) { + if (getEmptyTerritoryColor(originalState) == opp) { + if (isSuicideLegal) { + bases.push_back(captureWhenEmptyTerritoryBecomesRealBase(loc, opp)); + } else { + colors[loc] = originalState; + pos_hash ^= hashValue; + numLegalMoves++; + return {}; + } + } + } else { + if (isWithinEmptyTerritory(originalState, opp)) { + invalidateAdjacentEmptyTerritoryIfNeeded(loc); + initEmptyBaseInvalidateLocations = vector(closureOrInvalidateLocsBuffer); + } + } + } + + return {loc, pla, originalState, bases, initEmptyBaseInvalidateLocations}; +} + +void Board::undoDots(MoveRecord& moveRecord) { + for (auto it = moveRecord.bases.rbegin(); it != moveRecord.bases.rend(); ++it) { + for (size_t index = 0; index < it->rollback_locations.size(); index++) { + const State rollbackState = it->rollback_states[index]; + const Loc rollbackLocation = it->rollback_locations[index]; + setState(rollbackLocation, rollbackState); + if (it->is_real) { + updateScoreAndHashForTerritory(rollbackLocation, rollbackState, it->pla, true); + } + } + } + + const bool isGrounding = moveRecord.loc == PASS_LOC; + + const Player emptyTerritoryPlayer = isGrounding ? moveRecord.pla : getOpp(moveRecord.pla); + for (const Loc& loc : moveRecord.emptyBaseInvalidateLocations) { + setState(loc, static_cast(emptyTerritoryPlayer << EMPTY_TERRITORY_SHIFT)); + } + + if (!isGrounding) { + setState(moveRecord.loc, moveRecord.previousState); + pos_hash ^= ZOBRIST_BOARD_HASH[moveRecord.loc][moveRecord.pla]; + numLegalMoves++; + } +} + +Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, const Player opp) { + Loc loc = initLoc; + + // Searching for an opponent dot that makes a closure that contains the `initialPosition`. + // The closure always exists, otherwise there is an error in previous calculations. + while (loc > 0) { + loc = Location::xm1y(loc); + + // Try to peek an active opposite player dot + if (getColor(loc) != opp) continue; + + vector oppBases = tryCapture(loc, opp, true); + // The found base always should be real and include the `iniLoc` + for (const Base& oppBase : oppBases) { + if (oppBase.is_real) { + return oppBase; + } + } + } + + assert(false && "Opp empty territory should be enclosed by an outer closure"); + return {}; +} + +vector Board::tryCapture(const Loc loc, const Player pla, const bool emptyBaseCapturing) { + getUnconnectedLocations(loc, pla); + auto currentClosures = vector>(); + + if (const int minNumberOfConnections = emptyBaseCapturing ? 1 : 2; + unconnectedLocationsBufferSize < minNumberOfConnections) return {}; + + for (int index = 0; index < unconnectedLocationsBufferSize; index++) { + Loc unconnectedLoc = unconnectedLocationsBuffer[index]; + + // Optimization: it doesn't make sense to check the latest unconnected dot + // when all previous connections form minimal bases + // because the latest always forms a base with maximal square that should be dropped + if (const size_t closuresSize = currentClosures.size(); + closuresSize > 0 && closuresSize == unconnectedLocationsBuffer.size() - 1) { + break; + } + + tryGetCounterClockwiseClosure(loc, unconnectedLoc, pla); + + // Sort the given closures in ascending order + if (!closureOrInvalidateLocsBuffer.empty()) { + bool added = false; + + auto newClosure = vector(closureOrInvalidateLocsBuffer); + + for (auto it = currentClosures.begin(); it != currentClosures.end(); ++it) { + if (closureOrInvalidateLocsBuffer.size() < it->size()) { + currentClosures.insert(it, newClosure); + added = true; + break; + } + } + + if (!added) { + currentClosures.emplace_back(newClosure); + } + } + } + + auto resultBases = vector(); + for (const vector& currentClosure: currentClosures) { + resultBases.push_back(buildBase(currentClosure, pla)); + } + return std::move(resultBases); +} + +vector Board::ground(const Player pla, vector& emptyBaseInvalidatePositions) { + auto processedLocs = vector(); + const Color opp = getOpp(pla); + auto resultBases = vector(); + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + + if (const State state = getState(loc); !isVisited(state) && isActive(state, pla)) { + bool createRealBase = false; + bool grounded = false; + getTerritoryLocations(pla, loc, true, createRealBase, grounded); + assert(createRealBase); + + if (!grounded) { + for (const Loc& territoryLoc : territoryLocationsBuffer) { + invalidateAdjacentEmptyTerritoryIfNeeded(territoryLoc); + for (const Loc& invalidateLoc : closureOrInvalidateLocsBuffer) { + emptyBaseInvalidatePositions.push_back(invalidateLoc); + } + } + + resultBases.push_back(createBaseAndUpdateStates(opp, createRealBase)); + } + + for (const Loc& territoryLoc : territoryLocationsBuffer) { + processedLocs.push_back(territoryLoc); + setVisited(territoryLoc); + } + } + } + } + + clearVisited(processedLocs); + + return resultBases; +} + +void Board::getUnconnectedLocations(const Loc loc, const Player pla) const { + const Loc xm1y = loc + adj_offsets[LEFT_INDEX]; + const Loc xym1 = loc + adj_offsets[TOP_INDEX]; + const Loc xp1y = loc + adj_offsets[RIGHT_INDEX]; + const Loc xyp1 = loc + adj_offsets[BOTTOM_INDEX]; + + unconnectedLocationsBufferSize = 0; + checkAndAddUnconnectedLocation(getColor(xp1y), pla, loc + adj_offsets[RIGHT_BOTTOM_INDEX], xyp1); + checkAndAddUnconnectedLocation(getColor(xyp1), pla, loc + adj_offsets[LEFT_BOTTOM_INDEX], xm1y); + checkAndAddUnconnectedLocation(getColor(xm1y), pla, loc + adj_offsets[LEFT_TOP_INDEX], xym1); + checkAndAddUnconnectedLocation(getColor(xym1), pla, loc + adj_offsets[RIGHT_TOP_INDEX], xp1y); +} + +inline void Board::checkAndAddUnconnectedLocation(const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { + if (checkPla != currentPla) { + if (getColor(addLoc1) == currentPla) { + unconnectedLocationsBuffer[unconnectedLocationsBufferSize++] = addLoc1; + } else if (getColor(addLoc2) == currentPla) { + unconnectedLocationsBuffer[unconnectedLocationsBufferSize++] = addLoc2; + } + } +} + +void Board::tryGetCounterClockwiseClosure(const Loc initialLoc, const Loc startLoc, const Player pla) { + closureOrInvalidateLocsBuffer.clear(); + closureOrInvalidateLocsBuffer.push_back(initialLoc); + setVisited(initialLoc); + closureOrInvalidateLocsBuffer.push_back(startLoc); + setVisited(startLoc); + + Loc currentLoc = startLoc; + Loc nextLoc = initialLoc; + Loc loc; + + do { + const int initialIndex = Location::getGetBigJumpInitialIndex(currentLoc, nextLoc, x_size); + int currentIndex = initialIndex; + + bool breakSearchingLoop = false; + + do { + loc = currentLoc + adj_offsets[currentIndex++]; + if (currentIndex == 8) currentIndex = 0; + + const State state = getState(loc); + const Color activeColor = getActiveColor(state); + if (activeColor == C_WALL) { + // Optimization: there is no need to walk anymore because the border can't enclosure anything + breakSearchingLoop = true; + break; + } + + if(activeColor == pla) { + if(loc == initialLoc) { + breakSearchingLoop = true; + break; + } + + if(isVisited(state)) { + // Remove trailing dots + Loc lastLoc; + do { + lastLoc = closureOrInvalidateLocsBuffer.back(); + closureOrInvalidateLocsBuffer.pop_back(); + clearVisited(lastLoc); + } while(lastLoc != loc); + } + + closureOrInvalidateLocsBuffer.push_back(loc); + setVisited(loc); + nextLoc = currentLoc; + currentLoc = loc; + break; + } + } while (currentIndex != initialIndex); + + if (breakSearchingLoop) { + break; + } + } + while (true); + + if (loc != initialLoc || closureOrInvalidateLocsBuffer.size() < 4) { + clearVisited(closureOrInvalidateLocsBuffer); + closureOrInvalidateLocsBuffer.clear(); + return; + } + + int square = 0; + const int stride = x_size + 1; + + const Loc prevLoc = closureOrInvalidateLocsBuffer.back(); + // Store the previously calculated coordinates because division is an expensive operation + int prevX = prevLoc % stride; + int prevY = prevLoc / stride; + + for (const Loc& l : closureOrInvalidateLocsBuffer) { + const int x = l % stride; + const int y = l / stride; + + square += prevY * x - y * prevX; + + prevX = x; + prevY = y; + clearVisited(l); + } + + if (square <= 0) { + closureOrInvalidateLocsBuffer.clear(); + } +} + +Board::Base Board::buildBase(const vector& closure, const Player pla) { + for (const Loc& closureLoc : closure) { + setVisited(closureLoc); + } + + const Loc territoryFirstLoc = Location::getNextLocCW(closure.at(1), closure.at(0), x_size); + bool createRealBase; + bool grounded = false; + getTerritoryLocations(pla, territoryFirstLoc, false, createRealBase, grounded); + assert(!grounded); + clearVisited(closure); + + return createBaseAndUpdateStates(pla, createRealBase); +} + +void Board::getTerritoryLocations(const Player pla, const Loc firstLoc, const bool grounding, bool& createRealBase, bool& grounded) { + walkStack.clear(); + territoryLocationsBuffer.clear(); + + createRealBase = grounding ? false : rules.dotsCaptureEmptyBases; + grounded = false; + const Player opp = getOpp(pla); + + State state = getState(firstLoc); + if (const Color activeColor = getActiveColor(state); activeColor == C_WALL) { + assert(grounding); + grounded = true; + } else { + bool legalLoc = false; + if (grounding) { + createRealBase = true; + legalLoc = activeColor == pla; + } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + legalLoc = true; // If no grounding, empty locations can be handled as well + } + + if (legalLoc) { + territoryLocationsBuffer.push_back(firstLoc); + setVisited(firstLoc); + walkStack.push_back(firstLoc); + } + } + + while (!walkStack.empty()) { + const Loc loc = walkStack.back(); + walkStack.pop_back(); + + FOREACHADJ( + Loc adj = loc + ADJOFFSET; + + state = getState(adj); + if (!isVisited(state)) { + const Color activeColor = getActiveColor(state); + if (activeColor == C_WALL) { + assert(grounding); + grounded = true; + } else { + bool isAdjLegal = false; + if (grounding) { + createRealBase = true; + isAdjLegal = activeColor == pla; + } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + isAdjLegal = true; // If no grounding, empty locations can be handled as well + } + + if (isAdjLegal) { + territoryLocationsBuffer.push_back(adj); + setVisited(adj); + walkStack.push_back(adj); + } + } + } + ) + } + + clearVisited(territoryLocationsBuffer); +} + +Board::Base Board::createBaseAndUpdateStates(Player basePla, bool isReal) { + auto rollbackLocations = vector(); + rollbackLocations.reserve(territoryLocationsBuffer.size()); + auto rollbackStates = vector(); + rollbackStates.reserve(territoryLocationsBuffer.size()); + + for (const Loc& territoryLoc : territoryLocationsBuffer) { + State state = getState(territoryLoc); + + if (Player activePlayer = getActiveColor(state); activePlayer != basePla) { + State newState; + if (isReal) { + updateScoreAndHashForTerritory(territoryLoc, state, basePla, false); + newState = setTerritoryAndActivePlayer(state, basePla); + } else { + newState = static_cast(basePla << EMPTY_TERRITORY_SHIFT); + } + + rollbackLocations.push_back(territoryLoc); + rollbackStates.push_back(state); + setState(territoryLoc, newState); + } + } + + return {basePla, rollbackLocations, rollbackStates, isReal}; +} + +void Board::updateScoreAndHashForTerritory(const Loc loc, const State state, const Player basePla, const bool rollback) { + const Color currentColor = getActiveColor(state); + const Player baseOppPla = getOpp(basePla); + + if (isPlaced(state, baseOppPla)) { + // The `getTerritoryPositions` never returns positions inside already owned territory, + // so there is no need to check for the territory flag. + if (basePla == P_BLACK) { + if (!rollback) { + numWhiteCaptures++; + } else { + numWhiteCaptures--; + } + } else { + if (!rollback) { + numBlackCaptures++; + } else { + numBlackCaptures--; + } + } + } else if (isPlaced(state, basePla) && isActive(state, baseOppPla) && rules.dotsFreeCapturedDots) { + // No diff for the territory of the current player + if (basePla == P_BLACK) { + if (!rollback) { + numBlackCaptures--; + } else { + numBlackCaptures++; + } + } else { + if (!rollback) { + numWhiteCaptures--; + } else { + numWhiteCaptures++; + } + } + } + + if (currentColor == C_EMPTY) { + if (!rollback) { + numLegalMoves--; + } else { + numLegalMoves++; + } + pos_hash ^= ZOBRIST_BOARD_HASH[loc][basePla]; + } else if (currentColor == baseOppPla) { + // Simulate unmaking the opponent move and making the player's move + const auto positionsHash = ZOBRIST_BOARD_HASH[loc]; + pos_hash ^= positionsHash[baseOppPla]; + pos_hash ^= positionsHash[basePla]; + } +} + +void Board::invalidateAdjacentEmptyTerritoryIfNeeded(const Loc loc) { + walkStack.clear(); + walkStack.push_back(loc); + closureOrInvalidateLocsBuffer.clear(); + + while(!walkStack.empty()) { + const Loc lastLoc = walkStack.back(); + walkStack.pop_back(); + + FOREACHADJ( + Loc adj = lastLoc + ADJOFFSET; + + State state = getState(adj); + if (getEmptyTerritoryColor(state) != C_EMPTY && !isVisited(state)) { + closureOrInvalidateLocsBuffer.push_back(adj); + setState(adj, C_EMPTY); + setVisited(adj); + + walkStack.push_back(adj); + } + ) + } + + clearVisited(closureOrInvalidateLocsBuffer); +} + +void Board::makeMoveAndCalculateCapturesAndBases( + const Player pla, + const Loc loc, + const bool isSuicideLegal, + vector& captures, + vector& bases + ) const { + if(isLegal(loc, pla, isSuicideLegal, false)) { + MoveRecord moveRecord = const_cast(this)->playMoveAssumeLegalDots(loc, pla); + + if(!moveRecord.bases.empty()) { + if(moveRecord.bases[0].pla == pla) { + // Better handling of empty bases? + captures[loc] = captures[loc] | moveRecord.bases[0].pla; + } + } + + for(Base& base: moveRecord.bases) { + for(const Loc& rollbackLoc: base.rollback_locations) { + // Consider empty bases as one move bases as well + bases[rollbackLoc] = bases[rollbackLoc] | base.pla; + } + } + + const_cast(this)->undo(moveRecord); + } +} + +void Board::calculateOneMoveCaptureAndBasePositionsForDots(const bool isSuicideLegal, vector& captures, vector& bases) const { + const int fieldSize = (x_size + 1) * (y_size + 1); + captures.resize(fieldSize); + bases.resize(fieldSize); + + for (int y = 0; y < y_size; y++) { + for (int x = 0; x < x_size; x++) { + const Loc loc = Location::getLoc(x, y, x_size); + + const State state = getState(loc); + const Color emptyTerritoryColor = getEmptyTerritoryColor(state); + if (emptyTerritoryColor != C_EMPTY) { + bases[loc] = bases[loc] | emptyTerritoryColor; + } + + // It doesn't make sense to calculate capturing when dot placed into own empty territory + if (emptyTerritoryColor != P_BLACK) { + makeMoveAndCalculateCapturesAndBases(P_BLACK, loc, isSuicideLegal, captures, bases); + } + + if (emptyTerritoryColor != P_WHITE) { + makeMoveAndCalculateCapturesAndBases(P_WHITE, loc, isSuicideLegal, captures, bases); + } + } + } +} diff --git a/cpp/game/graphhash.cpp b/cpp/game/graphhash.cpp index 3c5cb0a6a..988cf433e 100644 --- a/cpp/game/graphhash.cpp +++ b/cpp/game/graphhash.cpp @@ -2,6 +2,7 @@ Hash128 GraphHash::getStateHash(const BoardHistory& hist, Player nextPlayer, double drawEquivalentWinsForWhite) { const Board& board = hist.getRecentBoard(0); + assert(hist.rules.isDots == board.isDots()); Hash128 hash = BoardHistory::getSituationRulesAndKoHash(board, hist, nextPlayer, drawEquivalentWinsForWhite); // Fold in whether a pass ends this phase @@ -12,11 +13,13 @@ Hash128 GraphHash::getStateHash(const BoardHistory& hist, Player nextPlayer, dou if(hist.isGameFinished) hash ^= Board::ZOBRIST_GAME_IS_OVER; - // Fold in consecutive pass count. Probably usually redundant with history tracking. Use some standard LCG constants. - static constexpr uint64_t CONSECPASS_MULT0 = 2862933555777941757ULL; - static constexpr uint64_t CONSECPASS_MULT1 = 3202034522624059733ULL; - hash.hash0 += CONSECPASS_MULT0 * (uint64_t)hist.consecutiveEndingPasses; - hash.hash1 += CONSECPASS_MULT1 * (uint64_t)hist.consecutiveEndingPasses; + if (!hist.rules.isDots) { + // Fold in consecutive pass count. Probably usually redundant with history tracking. Use some standard LCG constants. + static constexpr uint64_t CONSECPASS_MULT0 = 2862933555777941757ULL; + static constexpr uint64_t CONSECPASS_MULT1 = 3202034522624059733ULL; + hash.hash0 += CONSECPASS_MULT0 * (uint64_t)hist.consecutiveEndingPasses; + hash.hash1 += CONSECPASS_MULT1 * (uint64_t)hist.consecutiveEndingPasses; + } return hash; } @@ -43,8 +46,10 @@ Hash128 GraphHash::getGraphHashFromScratch(const BoardHistory& histOrig, Player Hash128 graphHash = Hash128(); for(size_t i = 0; i +#include "board.h" + using namespace std; using json = nlohmann::json; -Rules::Rules() { - //Defaults if not set - closest match to TT rules - koRule = KO_POSITIONAL; - scoringRule = SCORING_AREA; - taxRule = TAX_NONE; - multiStoneSuicideLegal = true; - hasButton = false; - whiteHandicapBonusRule = WHB_ZERO; - friendlyPassOk = false; - komi = 7.5f; +const Rules Rules::DEFAULT_DOTS = Rules(true); +const Rules Rules::DEFAULT_GO = Rules(false); + +Rules::Rules() : Rules(false) {} + +Rules::Rules(const bool initIsDots, const int startPos, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : + Rules(initIsDots, startPos, 0, 0, 0, true, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} + +Rules::Rules(const bool initIsDots) : Rules( + initIsDots, + initIsDots ? START_POS_EMPTY : 0, + initIsDots ? 0 : KO_POSITIONAL, + initIsDots ? 0 : SCORING_AREA, + initIsDots ? 0 : TAX_NONE, + true, + false, + initIsDots ? 0 : WHB_ZERO, + false, + initIsDots ? 0.0f : 7.5f, + false, + initIsDots + ) { } Rules::Rules( @@ -28,53 +42,67 @@ Rules::Rules( int whbRule, bool pOk, float km -) - :koRule(kRule), - scoringRule(sRule), - taxRule(tRule), - multiStoneSuicideLegal(suic), - hasButton(button), - whiteHandicapBonusRule(whbRule), - friendlyPassOk(pOk), - komi(km) -{} - -Rules::~Rules() { +) : Rules(false, 0, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { } +Rules::Rules( + bool isDots, + int startPosRule, + int kRule, + int sRule, + int tRule, + bool suic, + bool button, + int whbRule, + bool pOk, + float km, + bool dotsCaptureEmptyBases, + bool dotsFreeCapturedDots +) + : isDots(isDots), + startPos(startPosRule), + dotsCaptureEmptyBases(dotsCaptureEmptyBases), + dotsFreeCapturedDots(dotsFreeCapturedDots), + koRule(kRule), + scoringRule(sRule), + taxRule(tRule), + multiStoneSuicideLegal(suic), + hasButton(button), + whiteHandicapBonusRule(whbRule), + friendlyPassOk(pOk), + komi(km) +{ + initializeIfNeeded(); +} + +Rules::~Rules() = default; + bool Rules::operator==(const Rules& other) const { - return - koRule == other.koRule && - scoringRule == other.scoringRule && - taxRule == other.taxRule && - multiStoneSuicideLegal == other.multiStoneSuicideLegal && - hasButton == other.hasButton && - whiteHandicapBonusRule == other.whiteHandicapBonusRule && - friendlyPassOk == other.friendlyPassOk && - komi == other.komi; + return equals(other, false); } bool Rules::operator!=(const Rules& other) const { - return - koRule != other.koRule || - scoringRule != other.scoringRule || - taxRule != other.taxRule || - multiStoneSuicideLegal != other.multiStoneSuicideLegal || - hasButton != other.hasButton || - whiteHandicapBonusRule != other.whiteHandicapBonusRule || - friendlyPassOk != other.friendlyPassOk || - komi != other.komi; + return !equals(other, false); } -bool Rules::equalsIgnoringKomi(const Rules& other) const { +bool Rules::equalsIgnoringSgfDefinedProps(const Rules& other) const { + return equals(other, true); +} + +bool Rules::equals(const Rules& other, const bool ignoreSgfDefinedProps) const { return + (ignoreSgfDefinedProps ? true : isDots == other.isDots) && + (ignoreSgfDefinedProps ? true : startPos == other.startPos) && koRule == other.koRule && scoringRule == other.scoringRule && taxRule == other.taxRule && multiStoneSuicideLegal == other.multiStoneSuicideLegal && hasButton == other.hasButton && whiteHandicapBonusRule == other.whiteHandicapBonusRule && - friendlyPassOk == other.friendlyPassOk; + friendlyPassOk == other.friendlyPassOk && + (ignoreSgfDefinedProps ? true : komi == other.komi) && + dotsCaptureEmptyBases == other.dotsCaptureEmptyBases && + dotsFreeCapturedDots == other.dotsFreeCapturedDots; } bool Rules::gameResultWillBeInteger() const { @@ -82,6 +110,14 @@ bool Rules::gameResultWillBeInteger() const { return komiIsInteger != hasButton; } +Rules Rules::getDefault(const bool isDots) { + return isDots ? DEFAULT_DOTS : DEFAULT_GO; +} + +Rules Rules::getDefaultOrTrompTaylorish(const bool isDots) { + return isDots ? DEFAULT_DOTS : getTrompTaylorish(); +} + Rules Rules::getTrompTaylorish() { Rules rules; rules.koRule = KO_POSITIONAL; @@ -112,6 +148,26 @@ bool Rules::komiIsIntOrHalfInt(float komi) { return std::isfinite(komi) && komi * 2 == (int)(komi * 2); } +set Rules::startPosStrings() { + initializeIfNeeded(); + return { + startPosIdToName[START_POS_EMPTY], + startPosIdToName[START_POS_CROSS], + startPosIdToName[START_POS_CROSS_2], + startPosIdToName[START_POS_CROSS_4] + }; +} + +int Rules::getNumOfStartPosStones() const { + switch (startPos) { + case START_POS_EMPTY: return 0; + case START_POS_CROSS: return 4; + case START_POS_CROSS_2: return 8; + case START_POS_CROSS_4: return 16; + default: throw std::range_error("Invalid start pos: " + std::to_string(startPos)); + } +} + set Rules::koRuleStrings() { return {"SIMPLE","POSITIONAL","SITUATIONAL","SPIGHT"}; } @@ -125,6 +181,15 @@ set Rules::whiteHandicapBonusRuleStrings() { return {"0","N","N-1"}; } +int Rules::parseStartPos(const string& s) { + initializeIfNeeded(); + if (const auto it = startPosNameToId.find(s); it != startPosNameToId.end()) { + return it->second; + } + + throw IOError("Rules::parseStartPos: Invalid dots start pos rule: " + s); +} + int Rules::parseKoRule(const string& s) { if(s == "SIMPLE") return Rules::KO_SIMPLE; else if(s == "POSITIONAL") return Rules::KO_POSITIONAL; @@ -132,17 +197,20 @@ int Rules::parseKoRule(const string& s) { else if(s == "SPIGHT") return Rules::KO_SPIGHT; else throw IOError("Rules::parseKoRule: Invalid ko rule: " + s); } + int Rules::parseScoringRule(const string& s) { if(s == "AREA") return Rules::SCORING_AREA; else if(s == "TERRITORY") return Rules::SCORING_TERRITORY; else throw IOError("Rules::parseScoringRule: Invalid scoring rule: " + s); } + int Rules::parseTaxRule(const string& s) { if(s == "NONE") return Rules::TAX_NONE; else if(s == "SEKI") return Rules::TAX_SEKI; else if(s == "ALL") return Rules::TAX_ALL; else throw IOError("Rules::parseTaxRule: Invalid tax rule: " + s); } + int Rules::parseWhiteHandicapBonusRule(const string& s) { if(s == "0") return Rules::WHB_ZERO; else if(s == "N") return Rules::WHB_N; @@ -150,6 +218,15 @@ int Rules::parseWhiteHandicapBonusRule(const string& s) { else throw IOError("Rules::parseWhiteHandicapBonusRule: Invalid whiteHandicapBonus rule: " + s); } +string Rules::writeStartPosRule(int startPosRule) { + initializeIfNeeded(); + if (const auto it = startPosIdToName.find(startPosRule); it != startPosIdToName.end()) { + return it->second; + } + + return "UNKNOWN"; +} + string Rules::writeKoRule(int koRule) { if(koRule == Rules::KO_SIMPLE) return string("SIMPLE"); if(koRule == Rules::KO_POSITIONAL) return string("POSITIONAL"); @@ -157,6 +234,7 @@ string Rules::writeKoRule(int koRule) { if(koRule == Rules::KO_SPIGHT) return string("SPIGHT"); return string("UNKNOWN"); } + string Rules::writeScoringRule(int scoringRule) { if(scoringRule == Rules::SCORING_AREA) return string("AREA"); if(scoringRule == Rules::SCORING_TERRITORY) return string("TERRITORY"); @@ -176,38 +254,42 @@ string Rules::writeWhiteHandicapBonusRule(int whiteHandicapBonusRule) { } ostream& operator<<(ostream& out, const Rules& rules) { - out << "ko" << Rules::writeKoRule(rules.koRule) - << "score" << Rules::writeScoringRule(rules.scoringRule) - << "tax" << Rules::writeTaxRule(rules.taxRule) - << "sui" << rules.multiStoneSuicideLegal; - if(rules.hasButton) - out << "button" << rules.hasButton; - if(rules.whiteHandicapBonusRule != Rules::WHB_ZERO) - out << "whb" << Rules::writeWhiteHandicapBonusRule(rules.whiteHandicapBonusRule); - if(rules.friendlyPassOk) - out << "fpok" << rules.friendlyPassOk; - out << "komi" << rules.komi; + out << rules.toString(); return out; } -string Rules::toStringNoKomi() const { - ostringstream out; - out << "ko" << Rules::writeKoRule(koRule) - << "score" << Rules::writeScoringRule(scoringRule) - << "tax" << Rules::writeTaxRule(taxRule) - << "sui" << multiStoneSuicideLegal; - if(hasButton) - out << "button" << hasButton; - if(whiteHandicapBonusRule != WHB_ZERO) - out << "whb" << Rules::writeWhiteHandicapBonusRule(whiteHandicapBonusRule); - if(friendlyPassOk) - out << "fpok" << friendlyPassOk; - return out.str(); +string Rules::toString() const { + return toString(true); } -string Rules::toString() const { +string Rules::toStringNoSgfDefinedProps() const { + return toString(false); +} + +string Rules::toString(const bool includeSgfDefinedProperties) const { ostringstream out; - out << (*this); + if (!isDots) { + out << "ko" << writeKoRule(koRule) + << "score" << writeScoringRule(scoringRule) + << "tax" << writeTaxRule(taxRule); + } else { + out << DOTS_CAPTURE_EMPTY_BASE_KEY << dotsCaptureEmptyBases; + } + if (includeSgfDefinedProperties && startPos != START_POS_EMPTY) { + out << START_POS_KEY << writeStartPosRule(startPos); + } + out << "sui" << multiStoneSuicideLegal; + if (!isDots) { + if (hasButton != DEFAULT_GO.hasButton) + out << "button" << hasButton; + if (whiteHandicapBonusRule != DEFAULT_GO.whiteHandicapBonusRule) + out << "whb" << writeWhiteHandicapBonusRule(whiteHandicapBonusRule); + if (friendlyPassOk != DEFAULT_GO.friendlyPassOk) + out << "fpok" << friendlyPassOk; + } + if (includeSgfDefinedProperties) { + out << "komi" << komi; + } return out.str(); } @@ -215,16 +297,25 @@ string Rules::toString() const { //which is the default for parsing and if not otherwise specified json Rules::toJsonHelper(bool omitKomi, bool omitDefaults) const { json ret; - ret["ko"] = writeKoRule(koRule); - ret["scoring"] = writeScoringRule(scoringRule); - ret["tax"] = writeTaxRule(taxRule); + if (isDots) + ret[DOTS_KEY] = true; + if (!omitDefaults || startPos != START_POS_EMPTY) + ret[START_POS_KEY] = writeStartPosRule(startPos); ret["suicide"] = multiStoneSuicideLegal; - if(!omitDefaults || hasButton) - ret["hasButton"] = hasButton; - if(!omitDefaults || whiteHandicapBonusRule != WHB_ZERO) - ret["whiteHandicapBonus"] = writeWhiteHandicapBonusRule(whiteHandicapBonusRule); - if(!omitDefaults || friendlyPassOk != false) - ret["friendlyPassOk"] = friendlyPassOk; + if (!isDots) { + ret["ko"] = writeKoRule(koRule); + ret["scoring"] = writeScoringRule(scoringRule); + ret["tax"] = writeTaxRule(taxRule); + if(!omitDefaults || hasButton != DEFAULT_GO.hasButton) + ret["hasButton"] = hasButton; + if(!omitDefaults || whiteHandicapBonusRule != DEFAULT_GO.whiteHandicapBonusRule) + ret["whiteHandicapBonus"] = writeWhiteHandicapBonusRule(whiteHandicapBonusRule); + if(!omitDefaults || friendlyPassOk != DEFAULT_GO.friendlyPassOk) + ret["friendlyPassOk"] = friendlyPassOk; + } else { + if (!omitDefaults || dotsCaptureEmptyBases != DEFAULT_DOTS.dotsCaptureEmptyBases) + ret[DOTS_CAPTURE_EMPTY_BASE_KEY] = dotsCaptureEmptyBases; + } if(!omitKomi) ret["komi"] = komi; return ret; @@ -258,7 +349,8 @@ Rules Rules::updateRules(const string& k, const string& v, Rules oldRules) { Rules rules = oldRules; string key = Global::trim(k); string value = Global::trim(Global::toUpper(v)); - if(key == "ko") rules.koRule = Rules::parseKoRule(value); + if(key == DOTS_KEY) rules.isDots = Global::stringToBool(value); + else if(key == "ko") rules.koRule = Rules::parseKoRule(value); else if(key == "score") rules.scoringRule = Rules::parseScoringRule(value); else if(key == "scoring") rules.scoringRule = Rules::parseScoringRule(value); else if(key == "tax") rules.taxRule = Rules::parseTaxRule(value); @@ -270,10 +362,12 @@ Rules Rules::updateRules(const string& k, const string& v, Rules oldRules) { return rules; } -static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { - Rules rules; +static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) { + auto rules = Rules(isDots); string lowercased = Global::trim(Global::toLower(sOrig)); - if(lowercased == "japanese" || lowercased == "korean") { + if(lowercased == DOTS_KEY) { + rules = Rules::DEFAULT_DOTS; + } else if(lowercased == "japanese" || lowercased == "korean") { rules.scoringRule = Rules::SCORING_TERRITORY; rules.koRule = Rules::KO_SIMPLE; rules.taxRule = Rules::TAX_SEKI; @@ -381,7 +475,7 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { } else if(sOrig.length() > 0 && sOrig[0] == '{') { //Default if not specified - rules = Rules::getTrompTaylorish(); + rules = Rules::getDefaultOrTrompTaylorish(isDots); bool komiSpecified = false; bool taxSpecified = false; try { @@ -389,11 +483,13 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { string s; for(json::iterator iter = input.begin(); iter != input.end(); ++iter) { string key = iter.key(); - if(key == "ko") + if (key == START_POS_KEY) + rules.startPos = Rules::parseStartPos(iter.value().get()); + else if (key == DOTS_CAPTURE_EMPTY_BASE_KEY) + rules.dotsCaptureEmptyBases = iter.value().get(); + else if(key == "ko") rules.koRule = Rules::parseKoRule(iter.value().get()); - else if(key == "score") - rules.scoringRule = Rules::parseScoringRule(iter.value().get()); - else if(key == "scoring") + else if(key == "score" || key == "scoring") rules.scoringRule = Rules::parseScoringRule(iter.value().get()); else if(key == "tax") { rules.taxRule = Rules::parseTaxRule(iter.value().get()); taxSpecified = true; @@ -443,7 +539,7 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { }; //Default if not specified - rules = Rules::getTrompTaylorish(); + rules = Rules::getDefaultOrTrompTaylorish(isDots); string s = sOrig; s = Global::trim(s); @@ -528,6 +624,12 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { else throw IOError("Could not parse rules: " + sOrig); continue; } + if(startsWithAndStrip(s,DOTS_CAPTURE_EMPTY_BASE_KEY)) { + if(startsWithAndStrip(s,"1")) rules.dotsCaptureEmptyBases = true; + else if(startsWithAndStrip(s,"0")) rules.dotsCaptureEmptyBases = false; + else throw IOError("Could not parse rules: " + sOrig); + continue; + } //Unknown rules format else throw IOError("Could not parse rules: " + sOrig); @@ -547,48 +649,173 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi) { } Rules Rules::parseRules(const string& sOrig) { - return parseRulesHelper(sOrig,true); + return parseRules(sOrig,false); } + +Rules Rules::parseRules(const string& sOrig, bool isDots) { + return parseRulesHelper(sOrig,true,isDots); +} + Rules Rules::parseRulesWithoutKomi(const string& sOrig, float komi) { - Rules rules = parseRulesHelper(sOrig,false); + return parseRulesWithoutKomi(sOrig,komi,false); +} + +Rules Rules::parseRulesWithoutKomi(const string& sOrig, float komi, bool isDots) { + Rules rules = parseRulesHelper(sOrig,false,isDots); rules.komi = komi; return rules; } -bool Rules::tryParseRules(const string& sOrig, Rules& buf) { +bool Rules::tryParseRules(const string& sOrig, Rules& buf, bool isDots) { Rules rules; - try { rules = parseRulesHelper(sOrig,true); } + try { rules = parseRulesHelper(sOrig,true,isDots); } catch(const StringError&) { return false; } buf = rules; return true; } -bool Rules::tryParseRulesWithoutKomi(const string& sOrig, Rules& buf, float komi) { + +bool Rules::tryParseRulesWithoutKomi(const string& sOrig, Rules& buf, float komi, bool isDots) { Rules rules; - try { rules = parseRulesHelper(sOrig,false); } + try { rules = parseRulesHelper(sOrig,false,isDots); } catch(const StringError&) { return false; } rules.komi = komi; buf = rules; return true; } -string Rules::toStringNoKomiMaybeNice() const { - if(equalsIgnoringKomi(parseRulesHelper("TrompTaylor",false))) +string Rules::toStringNoSgfDefinedPropertiesMaybeNice() const { + if(equalsIgnoringSgfDefinedProps(parseRulesHelper(DOTS_KEY, false, isDots))) + return "Dots"; + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("TrompTaylor",false, isDots))) return "TrompTaylor"; - if(equalsIgnoringKomi(parseRulesHelper("Japanese",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Japanese",false, isDots))) return "Japanese"; - if(equalsIgnoringKomi(parseRulesHelper("Chinese",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Chinese",false, isDots))) return "Chinese"; - if(equalsIgnoringKomi(parseRulesHelper("Chinese-OGS",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("Chinese-OGS",false, isDots))) return "Chinese-OGS"; - if(equalsIgnoringKomi(parseRulesHelper("AGA",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("AGA",false, isDots))) return "AGA"; - if(equalsIgnoringKomi(parseRulesHelper("StoneScoring",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("StoneScoring",false, isDots))) return "StoneScoring"; - if(equalsIgnoringKomi(parseRulesHelper("NewZealand",false))) + if(equalsIgnoringSgfDefinedProps(parseRulesHelper("NewZealand",false, isDots))) return "NewZealand"; - return toStringNoKomi(); + return toStringNoSgfDefinedProps(); +} + +std::vector Rules::generateStartPos(const int startPos, const int x_size, const int y_size) { + std::vector moves; + switch (startPos) { + case START_POS_EMPTY: + break; + case START_POS_CROSS: + if (x_size >= 2 && y_size >= 2) { + // Obey notago implementation for odd height + addCross((x_size + 1) / 2 - 1, y_size / 2 - 1, x_size, false, moves); + } + break; + case START_POS_CROSS_2: + if (x_size >= 4 && y_size >= 2) { + const int middleX = (x_size + 1) / 2; + const int middleY = y_size / 2 - 1; // Obey notago implementation for odd height + addCross(middleX - 2, middleY, x_size, false, moves); + addCross(middleX, middleY, x_size, true, moves); + } + break; + case START_POS_CROSS_4: + if (x_size >= 4 && y_size >= 4) { + int offsetX; + int offsetY; + if (x_size == 39 && y_size == 32) { + offsetX = 11; + offsetY = 10; + } else { + offsetX = (x_size - 3) / 3; + offsetY = (y_size - 3) / 3; + } + // Consider the index and size of the cross + const int sideOffsetX = (x_size - 1) - (offsetX + 1); + const int sideOffsetY = (y_size - 1) - (offsetY + 1); + addCross(offsetX, offsetY, x_size, false, moves); + addCross(sideOffsetX, offsetY, x_size, false, moves); + addCross(sideOffsetX, sideOffsetY, x_size, false, moves); + addCross(offsetX, sideOffsetY, x_size, false, moves); + } + break; + default: + throw std::range_error("Unsupported " + START_POS_KEY + ": " + to_string(startPos)); + } + + return moves; +} + +void Rules::addCross(const int x, const int y, const int x_size, const bool rotate90, std::vector& moves) { + Player pla; + Player opp; + + // The end move should always be P_WHITE + if (!rotate90) { + pla = P_WHITE; + opp = P_BLACK; + } else { + pla = P_BLACK; + opp = P_WHITE; + } + + const auto tailMove = Move(Location::getLoc(x, y + 1, x_size), pla); + + if (!rotate90) { + moves.push_back(tailMove); + } + + moves.emplace_back(Location::getLoc(x + 1, y + 1, x_size), opp); + moves.emplace_back(Location::getLoc(x + 1, y, x_size), pla); + moves.emplace_back(Location::getLoc(x, y, x_size), opp); + + if (rotate90) { + moves.push_back(tailMove); + } } +int Rules::tryRecognizeStartPos(int size_x, int size_y, vector& placementMoves, const bool emptyIfFailed) { + if(placementMoves.empty()) return START_POS_EMPTY; + + int result = emptyIfFailed ? START_POS_EMPTY : -1; + + // Sort locs because initial pos is invariant to moves order + auto sortByLoc = [&](vector& moves) { + std::sort(moves.begin(), moves.end(), [](const Move& move1, const Move& move2) { return move1.loc < move2.loc; }); + }; + + sortByLoc(placementMoves); + + auto generateStartPosSortAndCompare = [&](const int startPos) -> bool { + auto startPosMoves = generateStartPos(startPos, size_x, size_y); + + if(startPosMoves.size() != placementMoves.size()) { + return false; + } + + sortByLoc(startPosMoves); + + for(size_t i = 0; i < placementMoves.size(); i++) { + if(placementMoves[i].loc != startPosMoves[i].loc || placementMoves[i].pla != startPosMoves[i].pla) + return false; + } + + return true; + }; + + if(generateStartPosSortAndCompare(START_POS_CROSS)) { + result = START_POS_CROSS; + } else if(generateStartPosSortAndCompare(START_POS_CROSS_2)) { + result = START_POS_CROSS_2; + } else if(generateStartPosSortAndCompare(START_POS_CROSS_4)) { + result = START_POS_CROSS_4; + } + + return result; +} const Hash128 Rules::ZOBRIST_KO_RULE_HASH[4] = { Hash128(0x3cc7e0bf846820f6ULL, 0x1fb7fbde5fc6ba4eULL), //Based on sha256 hash of Rules::KO_SIMPLE @@ -618,3 +845,9 @@ const Hash128 Rules::ZOBRIST_BUTTON_HASH = //Based on sha256 hash of Rules::ZO const Hash128 Rules::ZOBRIST_FRIENDLY_PASS_OK_HASH = //Based on sha256 hash of Rules::ZOBRIST_FRIENDLY_PASS_OK_HASH Hash128(0x0113655998ef0a25ULL, 0x99c9d04ecd964874ULL); +const Hash128 Rules::ZOBRIST_DOTS_GAME_HASH = + Hash128(0xcdbfab9c91da83a9ULL, 0x6c2f198b2742181full); + +const Hash128 Rules::ZOBRIST_DOTS_CAPTURE_EMPTY_BASES_HASH = + Hash128(0x469afde424e960deull, 0x59f6138cebc753afull); + diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 1fffab834..72d085c68 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -1,12 +1,21 @@ #ifndef GAME_RULES_H_ #define GAME_RULES_H_ +#include "common.h" #include "../core/global.h" #include "../core/hash.h" #include "../external/nlohmann_json/json.hpp" struct Rules { + const static Rules DEFAULT_DOTS; + const static Rules DEFAULT_GO; + + static constexpr int START_POS_EMPTY = 0; + static constexpr int START_POS_CROSS = 1; + static constexpr int START_POS_CROSS_2 = 2; + static constexpr int START_POS_CROSS_4 = 3; + int startPos; static const int KO_SIMPLE = 0; static const int KO_POSITIONAL = 1; @@ -23,25 +32,30 @@ struct Rules { static const int TAX_ALL = 2; int taxRule; - bool multiStoneSuicideLegal; - bool hasButton; - static const int WHB_ZERO = 0; static const int WHB_N = 1; static const int WHB_N_MINUS_ONE = 2; int whiteHandicapBonusRule; - //Mostly an informational value - doesn't affect the actual implemented rules, but GTP or Analysis may, at a - //high level, use this info to adjust passing behavior - whether it's okay to pass without capturing dead stones. - //Only relevant for area scoring. - bool friendlyPassOk; - float komi; //Min and max acceptable komi in various places involving user input validation static constexpr float MIN_USER_KOMI = -150.0f; static constexpr float MAX_USER_KOMI = 150.0f; + bool isDots; + + bool dotsCaptureEmptyBases; + bool dotsFreeCapturedDots; // TODO: Implement later + bool multiStoneSuicideLegal; // Works as just suicide in Dots Game + bool hasButton; + //Mostly an informational value - doesn't affect the actual implemented rules, but GTP or Analysis may, at a + //high level, use this info to adjust passing behavior - whether it's okay to pass without capturing dead stones. + //Only relevant for area scoring. + bool friendlyPassOk; + Rules(); + Rules(bool initIsDots, int startPos, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); + explicit Rules(bool initIsDots); Rules( int koRule, int scoringRule, @@ -57,9 +71,12 @@ struct Rules { bool operator==(const Rules& other) const; bool operator!=(const Rules& other) const; - bool equalsIgnoringKomi(const Rules& other) const; - bool gameResultWillBeInteger() const; + [[nodiscard]] bool equalsIgnoringSgfDefinedProps(const Rules& other) const; + [[nodiscard]] bool equals(const Rules& other, bool ignoreSgfDefinedProps) const; + [[nodiscard]] bool gameResultWillBeInteger() const; + static Rules getDefault(bool isDots); + static Rules getDefaultOrTrompTaylorish(bool isDots); static Rules getTrompTaylorish(); static Rules getSimpleTerritory(); @@ -67,28 +84,45 @@ struct Rules { static std::set scoringRuleStrings(); static std::set taxRuleStrings(); static std::set whiteHandicapBonusRuleStrings(); + static int parseStartPos(const std::string& s); static int parseKoRule(const std::string& s); static int parseScoringRule(const std::string& s); static int parseTaxRule(const std::string& s); static int parseWhiteHandicapBonusRule(const std::string& s); + static std::string writeStartPosRule(int startPosRule); static std::string writeKoRule(int koRule); static std::string writeScoringRule(int scoringRule); static std::string writeTaxRule(int taxRule); static std::string writeWhiteHandicapBonusRule(int whiteHandicapBonusRule); static bool komiIsIntOrHalfInt(float komi); + static std::set startPosStrings(); + int getNumOfStartPosStones() const; - static Rules parseRules(const std::string& str); - static Rules parseRulesWithoutKomi(const std::string& str, float komi); - static bool tryParseRules(const std::string& str, Rules& buf); - static bool tryParseRulesWithoutKomi(const std::string& str, Rules& buf, float komi); + static Rules parseRules(const std::string& sOrig); + static Rules parseRules(const std::string& sOrig, bool isDots); + static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi); + static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi, bool isDots); + static bool tryParseRules(const std::string& sOrig, Rules& buf, bool isDots); + static bool tryParseRulesWithoutKomi(const std::string& sOrig, Rules& buf, float komi, bool isDots); static Rules updateRules(const std::string& key, const std::string& value, Rules priorRules); + static std::vector generateStartPos(int startPos, int x_size, int y_size); + /** + * @param size_x size of field + * @param size_y size of field + * @param placementMoves initial placement moves, they can be sorted, that's why it's not const + * @param emptyIfFailed + * @return -1 if the recognition is failed + */ + static int tryRecognizeStartPos(int size_x, int size_y, std::vector& placementMoves, bool emptyIfFailed); + friend std::ostream& operator<<(std::ostream& out, const Rules& rules); std::string toString() const; - std::string toStringNoKomi() const; - std::string toStringNoKomiMaybeNice() const; + std::string toStringNoSgfDefinedProps() const; + std::string toString(bool includeSgfDefinedProperties) const; + std::string toStringNoSgfDefinedPropertiesMaybeNice() const; std::string toJsonString() const; std::string toJsonStringNoKomi() const; std::string toJsonStringNoKomiMaybeOmitStuff() const; @@ -102,8 +136,43 @@ struct Rules { static const Hash128 ZOBRIST_MULTI_STONE_SUICIDE_HASH; static const Hash128 ZOBRIST_BUTTON_HASH; static const Hash128 ZOBRIST_FRIENDLY_PASS_OK_HASH; + static const Hash128 ZOBRIST_DOTS_GAME_HASH; + static const Hash128 ZOBRIST_DOTS_CAPTURE_EMPTY_BASES_HASH; private: + Rules( + bool isDots, + int startPosRule, + int kRule, + int sRule, + int tRule, + bool suic, + bool button, + int whbRule, + bool pOk, + float km, + bool dotsCaptureEmptyBases, + bool dotsFreeCapturedDots + ); + + static inline std::map startPosIdToName; + static inline std::map startPosNameToId; + + static void initializeIfNeeded() { + if (startPosIdToName.empty()) { + startPosIdToName[START_POS_EMPTY] = "EMPTY"; + startPosIdToName[START_POS_CROSS] = "CROSS"; + startPosIdToName[START_POS_CROSS_2] = "CROSS_2"; + startPosIdToName[START_POS_CROSS_4] = "CROSS_4"; + startPosNameToId["EMPTY"] = START_POS_EMPTY; + startPosNameToId["CROSS"] = START_POS_CROSS; + startPosNameToId["CROSS_2"] = START_POS_CROSS_2; + startPosNameToId["CROSS_4"] = START_POS_CROSS_4; + } + } + + static void addCross(int x, int y, int x_size, bool rotate90, std::vector& moves); + nlohmann::json toJsonHelper(bool omitKomi, bool omitDefaults) const; }; diff --git a/cpp/neuralnet/modelversion.cpp b/cpp/neuralnet/modelversion.cpp index e8ece86b1..746c9a27f 100644 --- a/cpp/neuralnet/modelversion.cpp +++ b/cpp/neuralnet/modelversion.cpp @@ -22,16 +22,20 @@ //15 = V7 features, Extra nonlinearity for pass output //16 = V7 features, Q value predictions in the policy head +//17 = V8 features (Dots game) + static void fail(int modelVersion) { throw StringError("NNModelVersion: Model version not currently implemented or supported: " + Global::intToString(modelVersion)); } static_assert(NNModelVersion::oldestModelVersionImplemented == 3, ""); static_assert(NNModelVersion::oldestInputsVersionImplemented == 3, ""); -static_assert(NNModelVersion::latestModelVersionImplemented == 16, ""); -static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); +static_assert(NNModelVersion::latestModelVersionImplemented == 17, ""); +static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); int NNModelVersion::getInputsVersion(int modelVersion) { + if (modelVersion == 17) + return dotsInputsVersion; if(modelVersion >= 8 && modelVersion <= 16) return 7; else if(modelVersion == 7) @@ -48,15 +52,17 @@ int NNModelVersion::getInputsVersion(int modelVersion) { } int NNModelVersion::getNumSpatialFeatures(int modelVersion) { + if(modelVersion == 17) + return NNInputs::NUM_FEATURES_SPATIAL_V_DOTS; if(modelVersion >= 8 && modelVersion <= 16) return NNInputs::NUM_FEATURES_SPATIAL_V7; - else if(modelVersion == 7) + if(modelVersion == 7) return NNInputs::NUM_FEATURES_SPATIAL_V6; - else if(modelVersion == 6) + if(modelVersion == 6) return NNInputs::NUM_FEATURES_SPATIAL_V5; - else if(modelVersion == 5) + if(modelVersion == 5) return NNInputs::NUM_FEATURES_SPATIAL_V4; - else if(modelVersion == 3 || modelVersion == 4) + if(modelVersion == 3 || modelVersion == 4) return NNInputs::NUM_FEATURES_SPATIAL_V3; fail(modelVersion); diff --git a/cpp/neuralnet/modelversion.h b/cpp/neuralnet/modelversion.h index 5961b7bd7..6eeac8c7a 100644 --- a/cpp/neuralnet/modelversion.h +++ b/cpp/neuralnet/modelversion.h @@ -4,9 +4,12 @@ // Model versions namespace NNModelVersion { - constexpr int latestModelVersionImplemented = 16; - constexpr int latestInputsVersionImplemented = 7; + constexpr int latestModelVersionImplemented = 17; + constexpr int latestInputsVersionImplemented = 8; + constexpr int latestGoInputsVersion = 7; + constexpr int dotsInputsVersion = latestInputsVersionImplemented; constexpr int defaultModelVersion = 16; + constexpr int defaultModelVersionForDots = 17; constexpr int oldestModelVersionImplemented = 3; constexpr int oldestInputsVersionImplemented = 3; diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index 595fe78dc..3fcf7d59a 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -66,7 +66,8 @@ NNEvaluator::NNEvaluator( const vector& gpuIdxByServerThr, const string& rSeed, bool doRandomize, - int defaultSymmetry + int defaultSymmetry, + bool dotsGame ) :modelName(mName), modelFileName(mFileName), @@ -106,7 +107,8 @@ NNEvaluator::NNEvaluator( currentDoRandomize(doRandomize), currentDefaultSymmetry(defaultSymmetry), currentBatchSize(maxBatchSz), - queryQueue() + queryQueue(), + dotsGame(dotsGame) { if(nnXLen > NNPos::MAX_BOARD_LEN) throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); @@ -148,7 +150,7 @@ NNEvaluator::NNEvaluator( } else { internalModelName = "random"; - modelVersion = NNModelVersion::defaultModelVersion; + modelVersion = dotsGame ? NNModelVersion::defaultModelVersionForDots : NNModelVersion::defaultModelVersion; inputsVersion = NNModelVersion::getInputsVersion(modelVersion); } @@ -786,19 +788,7 @@ void NNEvaluator::evaluate( if(buf.rowMetaBuf.size() < rowMetaLen) buf.rowMetaBuf.resize(rowMetaLen); - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - NNInputs::fillRowV3(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 4) - NNInputs::fillRowV4(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 5) - NNInputs::fillRowV5(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 6) - NNInputs::fillRowV6(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else if(inputsVersion == 7) - NNInputs::fillRowV7(board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); - else - ASSERT_UNREACHABLE; + NNInputs::fillRowVN(inputsVersion, board, history, nextPlayer, nnInputParams, nnXLen, nnYLen, inputsUseNHWC, buf.rowSpatialBuf.data(), buf.rowGlobalBuf.data()); if(rowMetaLen > 0) { if(sgfMeta == NULL) diff --git a/cpp/neuralnet/nneval.h b/cpp/neuralnet/nneval.h index 04ff7506b..6e87f8f0f 100644 --- a/cpp/neuralnet/nneval.h +++ b/cpp/neuralnet/nneval.h @@ -100,7 +100,8 @@ class NNEvaluator { const std::vector& gpuIdxByServerThread, const std::string& randSeed, bool doRandomize, - int defaultSymmetry + int defaultSymmetry, + bool dotsGame ); ~NNEvaluator(); @@ -269,6 +270,8 @@ class NNEvaluator { //Queued up requests ThreadSafeQueue queryQueue; + bool dotsGame; + public: //Helper, for internal use only void serve(NNServerBuf& buf, Rand& rand, int gpuIdxForThisThread, int serverThreadIdx); diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 5fad7ecc3..9f284b6ed 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -237,8 +237,18 @@ void NNInputs::fillScoring( bool groupTax, float* scoring ) { - if(!groupTax) { - std::fill(scoring, scoring + Board::MAX_ARR_SIZE, 0.0f); + std::fill_n(scoring, Board::MAX_ARR_SIZE, 0.0f); + + if(!groupTax || board.isDots()) { + // TODO: probably it makes sense to implement more accurate scoring for Dots + // That includes dead, empty locations and empty base locations: + // + // Captured enemy's dot: 1.0f + // Captured enemy's empty loc: 0.75f + // Empty base loc: 0.5f + // Empty: 0.0f + // + // Also consider grounding dots? for(int y = 0; y captures; + vector bases; + board.calculateOneMoveCaptureAndBasePositionsForDots(hist.rules.multiStoneSuicideLegal, captures, bases); + + Color grounding[Board::MAX_ARR_SIZE]; + board.calculateGroundingWhiteScore(grounding); + + auto boardString = board.toString(); + (void)boardString; + + for(int y = 0; y modelInfos; { ModelInfoForTuning modelInfo; diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 631b3cf3d..1481a6620 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -1,13 +1,13 @@ #include "../program/play.h" -#include "../core/global.h" #include "../core/fileutils.h" +#include "../core/global.h" #include "../core/timer.h" +#include "../dataio/files.h" #include "../program/playutils.h" #include "../program/setup.h" #include "../search/asyncbot.h" #include "../search/searchnode.h" -#include "../dataio/files.h" #include "../core/test.h" @@ -93,32 +93,38 @@ GameInitializer::GameInitializer(ConfigParser& cfg, Logger& logger, const string } void GameInitializer::initShared(ConfigParser& cfg, Logger& logger) { + dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); - allowedKoRuleStrs = cfg.getStrings("koRules", Rules::koRuleStrings()); - allowedScoringRuleStrs = cfg.getStrings("scoringRules", Rules::scoringRuleStrings()); - allowedTaxRuleStrs = cfg.getStrings("taxRules", Rules::taxRuleStrings()); - allowedMultiStoneSuicideLegals = cfg.getBools("multiStoneSuicideLegals"); - allowedButtons = cfg.getBools("hasButtons"); - - for(size_t i = 0; i < allowedKoRuleStrs.size(); i++) - allowedKoRules.push_back(Rules::parseKoRule(allowedKoRuleStrs[i])); - for(size_t i = 0; i < allowedScoringRuleStrs.size(); i++) - allowedScoringRules.push_back(Rules::parseScoringRule(allowedScoringRuleStrs[i])); - for(size_t i = 0; i < allowedTaxRuleStrs.size(); i++) - allowedTaxRules.push_back(Rules::parseTaxRule(allowedTaxRuleStrs[i])); - - if(allowedKoRules.size() <= 0) - throw IOError("koRules must have at least one value in " + cfg.getFileName()); - if(allowedScoringRules.size() <= 0) - throw IOError("scoringRules must have at least one value in " + cfg.getFileName()); - if(allowedTaxRules.size() <= 0) - throw IOError("taxRules must have at least one value in " + cfg.getFileName()); - if(allowedMultiStoneSuicideLegals.size() <= 0) - throw IOError("multiStoneSuicideLegals must have at least one value in " + cfg.getFileName()); - if(allowedButtons.size() <= 0) - throw IOError("hasButtons must have at least one value in " + cfg.getFileName()); + if (dotsGame) { + if (cfg.containsAny({"koRules", "scoringRules", "taxRules", "hasButtons"})) { + throw IOError("koRules, scoringRules, taxRules, hasButtons are not applicable for Dots game. Please remove them from " + cfg.getFileName()); + } + + allowedCaptureEmtpyBasesRules = cfg.getBools(DOTS_CAPTURE_EMPTY_BASES_KEY); + if (allowedCaptureEmtpyBasesRules.empty()) + throw IOError(DOTS_CAPTURE_EMPTY_BASES_KEY + " must have at least one value in " + cfg.getFileName()); + } else { + allowedKoRuleStrs = cfg.getStrings("koRules", Rules::koRuleStrings()); + allowedScoringRuleStrs = cfg.getStrings("scoringRules", Rules::scoringRuleStrings()); + allowedTaxRuleStrs = cfg.getStrings("taxRules", Rules::taxRuleStrings()); + allowedButtons = cfg.getBools("hasButtons"); + + for(size_t i = 0; i < allowedKoRuleStrs.size(); i++) + allowedKoRules.push_back(Rules::parseKoRule(allowedKoRuleStrs[i])); + for(size_t i = 0; i < allowedScoringRuleStrs.size(); i++) + allowedScoringRules.push_back(Rules::parseScoringRule(allowedScoringRuleStrs[i])); + for(size_t i = 0; i < allowedTaxRuleStrs.size(); i++) + allowedTaxRules.push_back(Rules::parseTaxRule(allowedTaxRuleStrs[i])); + + if(allowedKoRules.size() <= 0) + throw IOError("koRules must have at least one value in " + cfg.getFileName()); + if(allowedScoringRules.size() <= 0) + throw IOError("scoringRules must have at least one value in " + cfg.getFileName()); + if(allowedTaxRules.size() <= 0) + throw IOError("taxRules must have at least one value in " + cfg.getFileName()); + if(allowedButtons.size() <= 0) + throw IOError("hasButtons must have at least one value in " + cfg.getFileName()); - { bool hasAreaScoring = false; for(int i = 0; i or komiAuto=True in config"); - komiMean = cfg.contains("komiMean") ? cfg.getFloat("komiMean",Rules::MIN_USER_KOMI,Rules::MAX_USER_KOMI) : 7.5f; + komiMean = cfg.contains("komiMean") ? cfg.getFloat("komiMean",Rules::MIN_USER_KOMI,Rules::MAX_USER_KOMI) : + dotsGame ? 0.0f : 7.5f; komiStdev = cfg.contains("komiStdev") ? cfg.getFloat("komiStdev",0.0f,60.0f) : 0.0f; handicapProb = cfg.contains("handicapProb") ? cfg.getDouble("handicapProb",0.0,1.0) : 0.0; handicapCompensateKomiProb = cfg.contains("handicapCompensateKomiProb") ? cfg.getDouble("handicapCompensateKomiProb",0.0,1.0) : 0.0; @@ -243,6 +262,10 @@ void GameInitializer::initShared(ConfigParser& cfg, Logger& logger) { startPosesProb = 0.0; if(cfg.contains("startPosesFromSgfDir")) { + if (!allowedStartPosRules.empty()) { + throw StringError("startPosesFromSgfDir is not compatible with " + START_POSES_KEY + ". Please specify only one key."); + } + startPoses.clear(); startPosCumProbs.clear(); startPosesProb = cfg.getDouble("startPosesProb",0.0,1.0); @@ -466,16 +489,22 @@ Rules GameInitializer::createRules() { } Rules GameInitializer::createRulesUnsynchronized() { - Rules rules; - rules.koRule = allowedKoRules[rand.nextUInt((uint32_t)allowedKoRules.size())]; - rules.scoringRule = allowedScoringRules[rand.nextUInt((uint32_t)allowedScoringRules.size())]; - rules.taxRule = allowedTaxRules[rand.nextUInt((uint32_t)allowedTaxRules.size())]; - rules.multiStoneSuicideLegal = allowedMultiStoneSuicideLegals[rand.nextUInt((uint32_t)allowedMultiStoneSuicideLegals.size())]; - - if(rules.scoringRule == Rules::SCORING_AREA) - rules.hasButton = allowedButtons[rand.nextUInt((uint32_t)allowedButtons.size())]; - else - rules.hasButton = false; + auto rules = Rules(dotsGame); + rules.multiStoneSuicideLegal = allowedMultiStoneSuicideLegals[rand.nextUInt(static_cast(allowedMultiStoneSuicideLegals.size()))]; + if (!allowedStartPosRules.empty()) { + rules.startPos = allowedStartPosRules[rand.nextUInt(static_cast(allowedStartPosRules.size()))]; + } + + if (dotsGame) { + rules.dotsCaptureEmptyBases = allowedCaptureEmtpyBasesRules[rand.nextUInt(static_cast(allowedCaptureEmtpyBasesRules.size()))]; + } else { + rules.koRule = allowedKoRules[rand.nextUInt((uint32_t)allowedKoRules.size())]; + rules.scoringRule = allowedScoringRules[rand.nextUInt((uint32_t)allowedScoringRules.size())]; + rules.taxRule = allowedTaxRules[rand.nextUInt((uint32_t)allowedTaxRules.size())]; + rules.hasButton = rules.scoringRule == Rules::SCORING_AREA + ? allowedButtons[rand.nextUInt((uint32_t)allowedButtons.size())] + : false; + } return rules; } @@ -588,9 +617,10 @@ void GameInitializer::createGameSharedUnsynchronized( else { int xSize = allowedBSizes[bSizeIdx].first; int ySize = allowedBSizes[bSizeIdx].second; - board = Board(xSize,ySize); + board = Board(xSize,ySize,rules); pla = P_BLACK; hist.clear(board,pla,rules,0); + hist.setInitialTurnNumber(rules.getNumOfStartPosStones()); extraBlackAndKomi = PlayUtils::chooseExtraBlackAndKomi( komiMean, komiStdev, komiAllowIntegerProb, @@ -1316,7 +1346,7 @@ FinishedGameData* Play::runGame( assert(!(playSettings.forSelfPlay && !clearBotBeforeSearch)); if(extraBlackAndKomi.makeGameFairForEmptyBoard) { - Board b(startBoard.x_size,startBoard.y_size); + Board b(startBoard.x_size, startBoard.y_size, startBoard.rules); Player makeFairPla = P_BLACK; if(playSettings.flipKomiProbWhenNoCompensate != 0.0 && gameRand.nextBool(playSettings.flipKomiProbWhenNoCompensate)) makeFairPla = P_WHITE; @@ -1723,9 +1753,11 @@ FinishedGameData* Play::runGame( assert(gameData->finalFullArea == NULL); assert(gameData->finalOwnership == NULL); assert(gameData->finalSekiAreas == NULL); - gameData->finalFullArea = new Color[Board::MAX_ARR_SIZE]; gameData->finalOwnership = new Color[Board::MAX_ARR_SIZE]; - gameData->finalSekiAreas = new bool[Board::MAX_ARR_SIZE]; + if (!hist.rules.isDots) { + gameData->finalFullArea = new Color[Board::MAX_ARR_SIZE]; + gameData->finalSekiAreas = new bool[Board::MAX_ARR_SIZE]; + } if(hist.isGameFinished && hist.isNoResult) { finalValueTargets.win = 0.0f; @@ -1735,9 +1767,11 @@ FinishedGameData* Play::runGame( //Fill with empty so that we use "nobody owns anything" as the training target. //Although in practice actually the training normally weights by having a result or not, so it doesn't matter what we fill. - std::fill(gameData->finalFullArea,gameData->finalFullArea+Board::MAX_ARR_SIZE,C_EMPTY); - std::fill(gameData->finalOwnership,gameData->finalOwnership+Board::MAX_ARR_SIZE,C_EMPTY); - std::fill(gameData->finalSekiAreas,gameData->finalSekiAreas+Board::MAX_ARR_SIZE,false); + std::fill_n(gameData->finalOwnership, Board::MAX_ARR_SIZE,C_EMPTY); + if (!hist.rules.isDots) { + std::fill_n(gameData->finalFullArea, Board::MAX_ARR_SIZE,C_EMPTY); + std::fill_n(gameData->finalSekiAreas, Board::MAX_ARR_SIZE,false); + } } else { //Relying on this to be idempotent, so that we can get the final territory map @@ -1752,7 +1786,7 @@ FinishedGameData* Play::runGame( finalValueTargets.lead = finalValueTargets.score; //Fill full and seki areas - { + if (!hist.rules.isDots) { board.calculateArea(gameData->finalFullArea, true, true, true, hist.rules.multiStoneSuicideLegal); Color* independentLifeArea = new Color[Board::MAX_ARR_SIZE]; @@ -2415,12 +2449,12 @@ FinishedGameData* GameRunner::runGame( Search* botB; Search* botW; if(botSpecB.botIdx == botSpecW.botIdx) { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, hist.rules.isDots); botW = botB; } else { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B"); - botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W"); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", hist.rules.isDots); + botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", hist.rules.isDots); } if(afterInitialization != nullptr) { if(botSpecB.botIdx == botSpecW.botIdx) { @@ -2469,7 +2503,7 @@ FinishedGameData* GameRunner::runGame( assert(finishedGameData != NULL); Play::maybeForkGame(finishedGameData, forkData, playSettings, gameRand, botB); - if(!usedSekiForkHackPosition) { + if(!hist.rules.isDots && !usedSekiForkHackPosition) { Play::maybeSekiForkGame(finishedGameData, forkData, playSettings, gameInit, gameRand); } Play::maybeHintForkGame(finishedGameData, forkData, otherGameProps); diff --git a/cpp/program/play.h b/cpp/program/play.h index 8cadadf1f..d3935ac37 100644 --- a/cpp/program/play.h +++ b/cpp/program/play.h @@ -134,6 +134,10 @@ class GameInitializer { std::mutex createGameMutex; Rand rand; + bool dotsGame; + std::vector allowedCaptureEmtpyBasesRules; + std::vector allowedStartPosRules; + std::vector allowedKoRuleStrs; std::vector allowedScoringRuleStrs; std::vector allowedTaxRuleStrs; diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index b8ef70e5c..323e76dc6 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -311,7 +311,7 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { if(n > 9) throw StringError("Fixed handicap > 9 is not allowed"); - board = Board(xSize,ySize); + board = Board(xSize,ySize,board.rules); int xCoords[3]; //Corner, corner, side int yCoords[3]; //Corner, corner, side @@ -948,8 +948,8 @@ PlayUtils::BenchmarkResults PlayUtils::benchmarkSearchOnPositionsAndPrint( Rand seedRand; Search* bot = new Search(params,nnEval,nnEval->getLogger(),Global::uint64ToString(seedRand.nextUInt64())); - //Ignore the SGF rules, except for komi. Just use Tromp-taylor. - Rules initialRules = Rules::getTrompTaylorish(); + //Ignore the SGF rules, except for Dots and komi. Just use Tromp-taylor in case of Go. + Rules initialRules = Rules::getDefaultOrTrompTaylorish(sgf.isDots); //Take the komi from the sgf, otherwise ignore the rules in the sgf initialRules.komi = sgf.getRulesOrFailAllowUnspecified(initialRules).komi; diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index a31295145..35a5644c0 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -302,6 +302,7 @@ vector Setup::initializeNNEvaluators( if(disableFP16) useFP16Mode = enabled_t::False; + bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); NNEvaluator* nnEval = new NNEvaluator( nnModelName, nnModelFile, @@ -324,7 +325,8 @@ vector Setup::initializeNNEvaluators( gpuIdxByServerThread, nnRandSeed, (forcedSymmetry >= 0 ? false : nnRandomize), - defaultSymmetry + defaultSymmetry, + dotsGame ); nnEval->spawnServerThreads(); @@ -899,7 +901,7 @@ Rules Setup::loadSingleRules( if(cfg.contains("friendlyPassOk")) throw StringError("Cannot both specify 'rules' and individual rules like friendlyPassOk"); if(cfg.contains("whiteBonusPerHandicapStone")) throw StringError("Cannot both specify 'rules' and individual rules like whiteBonusPerHandicapStone"); - rules = Rules::parseRules(cfg.getString("rules")); + rules = Rules::parseRules(cfg.getString("rules"), cfg.getBoolOrDefault(DOTS_KEY, false)); } else { string koRule = cfg.getString("koRule", Rules::koRuleStrings()); diff --git a/cpp/search/asyncbot.cpp b/cpp/search/asyncbot.cpp index 0c18a5ef5..1a274c5f3 100644 --- a/cpp/search/asyncbot.cpp +++ b/cpp/search/asyncbot.cpp @@ -54,7 +54,7 @@ AsyncBot::AsyncBot( analyzeCallback(), searchBegunCallback() { - search = new Search(params,nnEval,humanEval,l,randSeed); + search = new Search(params,nnEval,humanEval,l,randSeed,false); searchThread = std::thread(searchThreadLoop,this,l); } diff --git a/cpp/search/localpattern.cpp b/cpp/search/localpattern.cpp index 8d075a06c..e66fdaa44 100644 --- a/cpp/search/localpattern.cpp +++ b/cpp/search/localpattern.cpp @@ -48,15 +48,18 @@ void LocalPatternHasher::init(int x, int y, Rand& rand) { LocalPatternHasher::~LocalPatternHasher() { } - Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) const { Hash128 hash = zobristPla[pla]; if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - const int dxi = board.adj_offsets[2]; - const int dyi = board.adj_offsets[3]; - assert(dxi == 1); - assert(dyi == board.x_size+1); + vector captures; + vector bases; + if (board.isDots()) { + // TODO: implement more faster version of `board.calculateOneMoveCaptureAndBasePositionsForDots(true, captures, bases);` + } + + const int dxi = 1; + const int dyi = board.x_size+1; int xRadius = xSize/2; int yRadius = ySize/2; @@ -74,9 +77,8 @@ Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) con int y2 = dy + yCenter; int x2 = dx + xCenter; int xy2 = y2 * xSize + x2; - hash ^= zobristLocalPattern[(int)board.colors[loc2] * xSize * ySize + xy2]; - if((board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) && board.getNumLiberties(loc2) == 1) - hash ^= zobristAtari[xy2]; + + updateHash(hash, board, bases, loc2, xy2, false); } } } @@ -89,10 +91,14 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p Hash128 hash = zobristPla[symPla]; if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - const int dxi = board.adj_offsets[2]; - const int dyi = board.adj_offsets[3]; - assert(dxi == 1); - assert(dyi == board.x_size+1); + vector captures; + vector bases; + if (board.isDots()) { + board.calculateOneMoveCaptureAndBasePositionsForDots(true, captures, bases); + } + + const int dxi = 1; + const int dyi = board.x_size+1; int xRadius = xSize/2; int yRadius = ySize/2; @@ -125,18 +131,38 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p symXY2 = symY2 * xSize + symX2; } - int symColor; - if(board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) - symColor = (int)(flipColors ? getOpp(board.colors[loc2]) : board.colors[loc2]); - else - symColor = (int)board.colors[loc2]; - - hash ^= zobristLocalPattern[symColor * xSize * ySize + symXY2]; - if((board.colors[loc2] == P_BLACK || board.colors[loc2] == P_WHITE) && board.getNumLiberties(loc2) == 1) - hash ^= zobristAtari[symXY2]; + updateHash(hash, board, bases, loc2, symXY2, flipColors); } } } return hash; } + +void LocalPatternHasher::updateHash( + Hash128& hash, + const Board& board, + const vector& bases, + const Loc loc, + const int patternXY, + const bool flipColors) const { + const Color colorAtLoc = board.getColor(loc); + Color newColor; + if (flipColors && (colorAtLoc == P_BLACK || colorAtLoc == P_WHITE)) { + newColor = getOpp(colorAtLoc); + } else { + newColor = colorAtLoc; + } + + hash ^= zobristLocalPattern[static_cast(newColor) * xSize * ySize + patternXY]; + + bool addAtariHash = false; + if(board.isDots()) { + //addAtariHash = bases[loc] != C_EMPTY; + addAtariHash = false; // TODO: implement for Dots + } else { + addAtariHash = (colorAtLoc == P_BLACK || colorAtLoc == P_WHITE) && board.getNumLiberties(loc) == 1; + } + if(addAtariHash) + hash ^= zobristAtari[patternXY]; +} \ No newline at end of file diff --git a/cpp/search/localpattern.h b/cpp/search/localpattern.h index 36ed0544b..67abc3b3a 100644 --- a/cpp/search/localpattern.h +++ b/cpp/search/localpattern.h @@ -23,6 +23,15 @@ struct LocalPatternHasher { //Returns the hash that would occur if symmetry were applied to both board and loc. //So basically, the only thing that changes is the zobrist indexing. Hash128 getHashWithSym(const Board& board, Loc loc, Player pla, int symmetry, bool flipColors) const; + +private: + void updateHash( + Hash128& hash, + const Board& board, + const std::vector& bases, + Loc loc, + int patternXY, + bool flipColors) const; }; #endif //SEARCH_LOCALPATTERN_H diff --git a/cpp/search/patternbonustable.cpp b/cpp/search/patternbonustable.cpp index 6367ecce3..4ce7ae5b3 100644 --- a/cpp/search/patternbonustable.cpp +++ b/cpp/search/patternbonustable.cpp @@ -258,7 +258,7 @@ void PatternBonusTable::avoidRepeatedPosMovesAndDeleteExcessFiles( turnNumber < minTurnNumber || turnNumber > maxTurnNumber || posSample.moves.size() != 0 || // Right now auto pattern avoid expects moveless records - !posSample.board.isLegal(posSample.hintLoc, posSample.nextPla, isMultiStoneSuicideLegal) + !posSample.board.isLegal(posSample.hintLoc, posSample.nextPla, isMultiStoneSuicideLegal, false) ) { numPosesInvalid += 1; continue; diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index b0bf240b4..5dff72cb2 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -66,9 +66,12 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed) - :Search(params,nnEval,NULL,lg,rSeed) + :Search(params,nnEval,NULL,lg,rSeed,false) {} -Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed) +Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed, const bool isDots) + :Search(params,nnEval,NULL,lg,rSeed,isDots) +{} +Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed, const bool isDots) :rootPla(P_BLACK), rootBoard(), rootHistory(), @@ -123,7 +126,12 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, throw StringError("Search::init - humanEval has different nnXLen or nnYLen"); } - rootKoHashTable = new KoHashTable(); + assert(rootHistory.rules.isDots == rootBoard.isDots()); + rootHistory.clear(rootBoard,rootPla,Rules::getDefault(isDots),0); + if (!isDots) { + rootKoHashTable = new KoHashTable(); + rootKoHashTable->recompute(rootHistory); + } rootSafeArea = new Color[Board::MAX_ARR_SIZE]; @@ -138,9 +146,6 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, rootNode = NULL; nodeTable = new SearchNodeTable(params.nodeTableShardsPowerOfTwo); mutexPool = new MutexPool(nodeTable->mutexPool->getNumMutexes()); - - rootHistory.clear(rootBoard,rootPla,Rules(),0); - rootKoHashTable->recompute(rootHistory); } Search::~Search() { @@ -181,7 +186,10 @@ void Search::setPosition(Player pla, const Board& board, const BoardHistory& his plaThatSearchIsFor = C_EMPTY; rootBoard = board; rootHistory = history; - rootKoHashTable->recompute(rootHistory); + assert(rootHistory.rules.isDots == rootBoard.isDots()); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } avoidMoveUntilByLocBlack.clear(); avoidMoveUntilByLocWhite.clear(); } @@ -197,7 +205,9 @@ void Search::setPlayerAndClearHistory(Player pla) { rootHistory.clear(rootBoard,rootPla,rules,rootHistory.encorePhase); rootHistory.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); - rootKoHashTable->recompute(rootHistory); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } //If changing the player alone, don't clear these, leave the user's setting - the user may have tried //to adjust the player or will be calling runWholeSearchAndGetMove with a different player and will @@ -336,7 +346,9 @@ bool Search::makeMove(Loc moveLoc, Player movePla, bool preventEncore) { //Compute these first so we can know if we need to set forceNonTerminal below. rootHistory.makeBoardMoveAssumeLegal(rootBoard,moveLoc,rootPla,rootKoHashTable,preventEncore); rootPla = getOpp(rootPla); - rootKoHashTable->recompute(rootHistory); + if (!rootHistory.rules.isDots) { + rootKoHashTable->recompute(rootHistory); + } if(rootNode != NULL) { SearchNode* child = NULL; diff --git a/cpp/search/search.h b/cpp/search/search.h index 562af59fd..5655be0cf 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -190,10 +190,16 @@ struct Search { Search( SearchParams params, NNEvaluator* nnEval, - NNEvaluator* humanEval, Logger* logger, - const std::string& randSeed - ); + const std::string& randSeed, + bool isDots); + Search( + SearchParams params, + NNEvaluator* nnEval, + NNEvaluator* humanEval, + Logger* lg, + const std::string& rSeed, + bool isDots); ~Search(); Search(const Search&) = delete; diff --git a/cpp/search/searchexplorehelpers.cpp b/cpp/search/searchexplorehelpers.cpp index 59e9b7483..6b1ced47d 100644 --- a/cpp/search/searchexplorehelpers.cpp +++ b/cpp/search/searchexplorehelpers.cpp @@ -639,5 +639,4 @@ void Search::selectBestChildToDescend( thread.shouldCountPlayout = false; } } - } diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index fa62b94e5..a88a682de 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -13,10 +13,10 @@ void Tests::runBoardIOTests() { //============================================================================ { const char* name = "Location parse test"; - auto testLoc = [&out](const char* s, int xSize, int ySize) { + auto testLoc = [&out](const char* s, int xSize, int ySize, bool isDots) { try { Loc loc = Location::ofString(s,xSize,ySize); - out << s << " " << Location::toString(loc,xSize,ySize) << " x " << Location::getX(loc,xSize) << " y " << Location::getY(loc,xSize) << endl; + out << s << " " << Location::toString(loc, xSize, ySize, isDots) << " x " << Location::getX(loc,xSize) << " y " << Location::getY(loc,xSize) << endl; } catch(const StringError& e) { out << e.what() << endl; @@ -34,27 +34,28 @@ void Tests::runBoardIOTests() { out << "----------------------------------" << endl; out << xSize << " " << ySize << endl; - testLoc("A1",xSize,ySize); - testLoc("A0",xSize,ySize); - testLoc("B2",xSize,ySize); - testLoc("b2",xSize,ySize); - testLoc("A",xSize,ySize); - testLoc("B",xSize,ySize); - testLoc("1",xSize,ySize); - testLoc("pass",xSize,ySize); - testLoc("H9",xSize,ySize); - testLoc("I9",xSize,ySize); - testLoc("J9",xSize,ySize); - testLoc("J10",xSize,ySize); - testLoc("K8",xSize,ySize); - testLoc("k19",xSize,ySize); - testLoc("a22",xSize,ySize); - testLoc("y1",xSize,ySize); - testLoc("z1",xSize,ySize); - testLoc("aa1",xSize,ySize); - testLoc("AA26",xSize,ySize); - testLoc("AZ26",xSize,ySize); - testLoc("BC50",xSize,ySize); + testLoc("A1",xSize,ySize,false); + testLoc("A0",xSize,ySize,false); + testLoc("B2",xSize,ySize,false); + testLoc("b2",xSize,ySize,false); + testLoc("A",xSize,ySize,false); + testLoc("B",xSize,ySize,false); + testLoc("1",xSize,ySize,false); + testLoc("pass",xSize,ySize,false); + testLoc("ground",xSize,ySize,true); + testLoc("H9",xSize,ySize,false); + testLoc("I9",xSize,ySize,false); + testLoc("J9",xSize,ySize,false); + testLoc("J10",xSize,ySize,false); + testLoc("K8",xSize,ySize,false); + testLoc("k19",xSize,ySize,false); + testLoc("a22",xSize,ySize,false); + testLoc("y1",xSize,ySize,false); + testLoc("z1",xSize,ySize,false); + testLoc("aa1",xSize,ySize,false); + testLoc("AA26",xSize,ySize,false); + testLoc("AZ26",xSize,ySize,false); + testLoc("BC50",xSize,ySize,false); } } @@ -69,6 +70,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 0 Could not parse board location: I9 J9 J9 x 8 y 0 @@ -92,6 +94,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -115,6 +118,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 0 Could not parse board location: I9 J9 J9 x 8 y 0 @@ -138,6 +142,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -161,6 +166,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -184,6 +190,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 10 Could not parse board location: I9 J9 J9 x 8 y 10 @@ -207,6 +214,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -230,6 +238,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 61 Could not parse board location: I9 J9 J9 x 8 y 61 @@ -253,6 +262,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 17 Could not parse board location: I9 J9 J9 x 8 y 17 @@ -276,6 +286,7 @@ Could not parse board location: A Could not parse board location: B Could not parse board location: 1 pass pass x 0 y -1 +ground ground x 0 y -1 H9 H9 x 7 y 61 Could not parse board location: I9 J9 J9 x 8 y 61 @@ -1917,7 +1928,7 @@ void Tests::runBoardUndoTest() { //Maximum range of board location values when 19x19: int numLocs = (19+1)*(19+2)+1; loc = (Loc)rand.nextUInt(numLocs); - if(boards[n].isLegal(loc,pla,multiStoneSuicideLegal)) + if(boards[n].isLegal(loc, pla, multiStoneSuicideLegal, false)) break; } @@ -2138,7 +2149,7 @@ void Tests::runBoardStressTest() { bool isLegal[numBoards]; bool suc[numBoards]; for(int i = 0; i + +#include "../tests/tests.h" +#include "../tests/testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace std::chrono; +using namespace TestCommon; + +void checkDotsField(const string& description, const string& input, bool captureEmptyBases, bool freeCapturedDots, const std::function& check) { + cout << " " << description << endl; + + auto moveRecords = vector(); + + Board initialBoard = parseDotsField(input, captureEmptyBases, freeCapturedDots, {}); + + Board board = Board(initialBoard); + + BoardWithMoveRecords boardWithMoveRecords = BoardWithMoveRecords(board, moveRecords); + check(boardWithMoveRecords); + + while (!moveRecords.empty()) { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } + testAssert(initialBoard.isEqualForTesting(board, true, true)); +} + +void checkDotsFieldDefault(const string& description, const string& input, const std::function& check) { + checkDotsField(description, input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, check); +} + +void Tests::runDotsFieldTests() { + cout << "Running dots basic tests: " << endl; + + checkDotsFieldDefault("Simple capturing", + R"( +.x. +xox +... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsFieldDefault("Capturing with empty loc inside", + R"( +.oo. +ox.. +.oo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(boardWithMoveRecords.isLegal(2, 1, P_BLACK)); + testAssert(boardWithMoveRecords.isLegal(2, 1, P_WHITE)); + + boardWithMoveRecords.playMove(3, 1, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(!boardWithMoveRecords.isLegal(2, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isLegal(2, 1, P_WHITE)); +}); + + checkDotsFieldDefault("Triple capture", + R"( +.x.x. +xo.ox +.xox. +..x.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 1, P_BLACK); + testAssert(3 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsFieldDefault("Base inside base inside base", + R"( +.xxxxxxx. +x..ooo..x +x.o.x.o.x +x.oxoxo.x +x.o...o.x +x..o.o..x +.xxx.xxx. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(4, 4, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 5, P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(4 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 6, P_BLACK); + testAssert(13 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +}); + + /*checkDotsField("Base inside base inside base don't free captured dots", + R"( +.xxxxxxxxx.. +x..oooooo.x. +x.o.xx...o.x +x.oxo.xo.o.x +x.o.x.o..o.x +x..o....ox.x +x...o.oo...x +.xxxx.xxxxx. +)", true, false, [](const BoardWithMoveRecords& boardWithMoveRecords) { +boardWithMoveRecords.playMove(5, 4, P_BLACK); +testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + +boardWithMoveRecords.playMove(5, 6, P_WHITE); +testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); // Don't free the captured dot +testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); // Ignore owned color dots + +boardWithMoveRecords.playMove(5, 7, P_BLACK); +testAssert(21 == boardWithMoveRecords.board.numWhiteCaptures); // Don't count already counted dots +testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); // Don't free the captured dot +});*/ + + checkDotsFieldDefault("Empty bases and suicide", + R"( +.x..o. +x.xo.o +.x..o. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + // Suicide move is not capture + testAssert(!boardWithMoveRecords.wouldBeCapture(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.wouldBeCapture(1, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.wouldBeCapture(4, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.wouldBeCapture(4, 1, P_BLACK)); + + testAssert(boardWithMoveRecords.isSuicide(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_BLACK)); + boardWithMoveRecords.playMove(1, 1, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(boardWithMoveRecords.isSuicide(4, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_WHITE)); + boardWithMoveRecords.playMove(4, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); +}); + + checkDotsField("Empty bases when they are allowed", + R"( +.x..o. +x.xo.o +...... +)", true, true, [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 2, P_BLACK); + boardWithMoveRecords.playMove(4, 2, P_WHITE); + + // Suicide is not possible in this mode + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_WHITE)); + testAssert(!boardWithMoveRecords.isSuicide(1, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_BLACK)); + testAssert(!boardWithMoveRecords.isSuicide(4, 1, P_WHITE)); + + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +}); + + checkDotsFieldDefault("Capture wins suicide", + R"( +.xo. +xo.o +.xo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(!boardWithMoveRecords.isSuicide(2, 1, P_BLACK)); + boardWithMoveRecords.playMove(2, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsFieldDefault("Single dot doesn't break searching inside empty base", + R"( +.oooo. +o....o +o.o..o +o....o +.oooo. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(4, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + }); + + checkDotsFieldDefault("Ignored already surrounded territory", + R"( +..xxx... +.x...x.. +x..x..x. +x.x.x..x +x..x..x. +.x...x.. +..xxx... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playMove(6, 3, P_WHITE); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); +}); + + checkDotsFieldDefault("Invalidation of empty base locations", + R"( +.oox. +o..ox +.oox. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 1, P_BLACK); + boardWithMoveRecords.playMove(1, 1, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsFieldDefault("Invalidation of empty base locations ignoring borders", + R"( +..xxx.... +.x...x... +x..x..xo. +x.x.x..xo +x..x..xo. +.x...x... +..xxx.... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(6, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(1, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(3, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsFieldDefault("Dangling dots removing", + R"( +.xx.xx. +x..xo.x +x.x.x.x +x..x..x +.x...x. +..x.x.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 5, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + testAssert(!boardWithMoveRecords.isLegal(3, 2, P_BLACK)); + testAssert(!boardWithMoveRecords.isLegal(3, 2, P_WHITE)); + }); + + checkDotsFieldDefault("Recalculate square during dangling dots removing", + R"( +.ooo.. +o...o. +o.o..o +..xo.o +o.o..o +o...o. +.ooo.. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(1, 3, P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + + boardWithMoveRecords.playMove(4, 3, P_BLACK); + testAssert(2 == boardWithMoveRecords.board.numBlackCaptures); + }); + + checkDotsFieldDefault("Base sorting by size", + R"( +..xxx.. +.x...x. +x..x..x +x.xox.x +x.....x +.xx.xx. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 4, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playMove(4, 1, P_WHITE); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); + }); + + checkDotsFieldDefault("Number of legal moves", + R"( +.... +.... +.... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { +testAssert(12 == boardWithMoveRecords.board.numLegalMoves); +}); + + checkDotsFieldDefault("Game over because of absence of legal moves", + R"( +xxxx +xo.x +xx.x +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + testAssert(0 == boardWithMoveRecords.board.numLegalMoves); + }); +} + +void Tests::runDotsGroundingTests() { + cout << "Running dots grounding tests:" << endl; + + checkDotsFieldDefault("Simple", + R"( +..... +.xxo. +..... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(2 == boardWithMoveRecords.board.numBlackCaptures); + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + boardWithMoveRecords.undo(); + } +); + + checkDotsFieldDefault("Draw", +R"( +.x... +.xxo. +...o. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + boardWithMoveRecords.undo(); +} +); + + checkDotsFieldDefault("Bases", +R"( +......... +..xx...x. +.xo.x.xox +..x...... +......... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(3, 3, P_BLACK); + boardWithMoveRecords.playMove(7, 3, P_BLACK); + testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); + + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); +} +); + + checkDotsFieldDefault("Multiple groups", +R"( +...... +xxo..o +.ox... +x...oo +...o.. +...... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playGroundingMove(P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + boardWithMoveRecords.undo(); + + boardWithMoveRecords.playGroundingMove(P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); + testAssert(3 == boardWithMoveRecords.board.numWhiteCaptures); + boardWithMoveRecords.undo(); +} +); + + checkDotsFieldDefault("Invalidate empty territory", +R"( +...... +..oo.. +.o..o. +..oo.. +...... +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { +Board board = boardWithMoveRecords.board; + +State state = boardWithMoveRecords.board.getState(Location::getLoc(2, 2, board.x_size)); +testAssert(C_WHITE == getEmptyTerritoryColor(state)); + +state = boardWithMoveRecords.board.getState(Location::getLoc(3, 2, board.x_size)); +testAssert(C_WHITE == getEmptyTerritoryColor(state)); + +boardWithMoveRecords.playGroundingMove(P_WHITE); +testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +testAssert(6 == boardWithMoveRecords.board.numWhiteCaptures); + +state = boardWithMoveRecords.board.getState(Location::getLoc(2, 2, board.x_size)); +testAssert(C_EMPTY == getEmptyTerritoryColor(state)); + +state = boardWithMoveRecords.board.getState(Location::getLoc(3, 2, board.x_size)); +testAssert(C_EMPTY == getEmptyTerritoryColor(state)); +} +); + + checkDotsFieldDefault("Don't invalidate empty territory for strong connection", +R"( +.x. +x.x +.x. +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { +Board board = boardWithMoveRecords.board; + +boardWithMoveRecords.playGroundingMove(P_BLACK); +testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); +testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); + +State state = boardWithMoveRecords.board.getState(Location::getLoc(1, 1, board.x_size)); +testAssert(C_BLACK == getEmptyTerritoryColor(state)); + +state = boardWithMoveRecords.board.getState(Location::getLoc(0, 0, board.x_size)); +testAssert(C_EMPTY == getEmptyTerritoryColor(state)); +} +); +} + +void Tests::runDotsPosHashTests() { + cout << "Running dots pos hash tests" << endl; + + { + Board dotsFieldWithEmptyBase = parseDotsFieldDefault(R"( +.x. +x.x +.x. +)", { XYMove(1, 1, P_WHITE) }); + + Board dotsFieldWithRealBase = parseDotsFieldDefault(R"( +.x. +xox +... +)", { XYMove(1, 2, P_BLACK) }); + + testAssert(dotsFieldWithEmptyBase.pos_hash == dotsFieldWithRealBase.pos_hash); + } + + { + Board dotsFieldWithSurrounding = parseDotsFieldDefault(R"( +..xxxxxx.. +.x......x. +x..x..o..x +x.xoxoxo.x +x........x +.x......x. +..xxx.xx.. +)", { XYMove(3, 4, P_BLACK), XYMove(6, 4, P_WHITE), XYMove(5, 6, P_BLACK) }); + testAssert(5 == dotsFieldWithSurrounding.numWhiteCaptures); + testAssert(0 == dotsFieldWithSurrounding.numBlackCaptures); + + Board dotsFieldWithErasedTerritory = parseDotsFieldDefault(R"( +..xxxxxx.. +.xxxxxxxx. +xxxxxxxxxx +xxxxxxxxxx +xxxxxxxxxx +.xxxxxxxx. +..xxxxxx.. +)"); + + testAssert(dotsFieldWithErasedTerritory.pos_hash == dotsFieldWithSurrounding.pos_hash); + } +} + + + +void runDotsStressTestsInternal(int x_size, int y_size, int gamesCount, float groundingAfterCoef, float groundingProb, float komi, bool suicideAllowed, bool checkRollback) { + // TODO: add tests with grounding + cout << " Random games" << endl; + cout << " Check rollback: " << boolalpha << checkRollback << endl; +#ifdef NDEBUG + cout << " Build: Release" << endl; +#else + cout << " Build: Debug" << endl; +#endif + cout << " Size: " << x_size << ":" << y_size << endl; + cout << " Komi: " << komi << endl; + cout << " Suicide: " << boolalpha << suicideAllowed << endl; + cout << " Games count: " << gamesCount << endl; + + const auto start = high_resolution_clock::now(); + + Rand rand("runDotsStressTests"); + + Rules rules = Rules(false); + Board initialBoard = Board(x_size, y_size, rules); + + int tryGroundingAfterMove = groundingAfterCoef * initialBoard.numLegalMoves; + + vector randomMoves = vector(); + randomMoves.reserve(initialBoard.numLegalMoves); + + for(int y = 0; y < initialBoard.y_size; y++) { + for(int x = 0; x < initialBoard.x_size; x++) { + randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); + } + } + + int movesCount = 0; + int blackWinsCount = 0; + int whiteWinsCount = 0; + int drawsCount = 0; + + auto moveRecords = vector(); + + for (int n = 0; n < gamesCount; n++) { + rand.shuffle(randomMoves); + + auto board = Board(initialBoard.x_size, initialBoard.y_size, rules); + + Player pla = P_BLACK; + for (Loc loc : randomMoves) { + if (board.isLegal(loc, pla, suicideAllowed, false)) { + Board::MoveRecord moveRecord = board.playMoveRecorded(loc, pla); + movesCount++; + if (checkRollback) { + moveRecords.push_back(moveRecord); + } + pla = getOpp(pla); + } + } + + /*if (suicideAllowed) { + testAssert(0 == board.numLegalMoves); + }*/ + + if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { + whiteWinsCount++; + } else if (whiteScore < 0) { + blackWinsCount++; + } else { + drawsCount++; + } + + if (checkRollback) { + while (!moveRecords.empty()) { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } + + testAssert(initialBoard.isEqualForTesting(board, true, false)); + } + } + + const auto end = high_resolution_clock::now(); + auto durationNs = duration_cast(end - start); + + cout.precision(4); + cout << " Elapsed time: " << duration_cast(durationNs).count() << " ms" << endl; + cout << " Number of games per second: " << static_cast(static_cast(gamesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per second: " << static_cast(static_cast(movesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per game: " << static_cast(static_cast(movesCount) / gamesCount) << endl; + cout << " Time per game: " << static_cast(durationNs.count()) / gamesCount / 1000000 << " ms" << endl; + cout << " Black wins: " << blackWinsCount << " (" << static_cast(blackWinsCount) / gamesCount << ")" << endl; + cout << " White wins: " << whiteWinsCount << " (" << static_cast(whiteWinsCount) / gamesCount << ")" << endl; + cout << " Draws: " << drawsCount << " (" << static_cast(drawsCount) / gamesCount << ")" << endl; +} + +void Tests::runDotsStressTests() { + cout << "Running dots stress tests" << endl; + + cout << " Max territory" << endl; + Board board = Board(39, 32, Rules::DEFAULT_DOTS); + for(int y = 0; y < board.y_size; y++) { + for(int x = 0; x < board.x_size; x++) { + Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; + board.playMoveAssumeLegal(Location::getLoc(x, y, board.x_size), pla); + } + } + testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); + testAssert(0 == board.numLegalMoves); + + //runDotsStressTestsInternal(39, 32, 100000, 0.8f, 0.01f, 0.0f, true, false); + runDotsStressTestsInternal(39, 32, 10000, 0.8f, 0.01f, 0.0f, true, true); +} \ No newline at end of file diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp new file mode 100644 index 000000000..f2a2c4ba2 --- /dev/null +++ b/cpp/tests/testdotsextra.cpp @@ -0,0 +1,424 @@ +#include "../tests/tests.h" +#include "testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace TestCommon; + +void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardInput, const vector& extraMoves, const int symmetry) { + Board transformedBoard = SymmetryHelpers::getSymBoard(initBoard, symmetry); + Board expectedBoard = parseDotsFieldDefault(expectedSymmetryBoardInput); + for (const XYMove& extraMove : extraMoves) { + expectedBoard.playMoveAssumeLegal(SymmetryHelpers::getSymLoc(extraMove.x, extraMove.y, initBoard, symmetry), extraMove.player); + } + expect(SymmetryHelpers::symmetryToString(symmetry).c_str(), Board::toStringSimple(transformedBoard, '\n'), Board::toStringSimple(expectedBoard, '\n')); + testAssert(transformedBoard.isEqualForTesting(expectedBoard, true, true)); +} + +void Tests::runDotsSymmetryTests() { + cout << "Running dots symmetry tests" << endl; + + Board initialBoard = parseDotsFieldDefault(R"( +...ox +..ox. +.o.ox +.xo.. +)"); + initialBoard.playMoveAssumeLegal(Location::getLoc(4, 1, initialBoard.x_size), P_WHITE); + testAssert(1 == initialBoard.numBlackCaptures); + + checkSymmetry(initialBoard, R"( +...ox +..ox. +.o.ox +.xo.. +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_NONE); + + checkSymmetry(initialBoard, R"( +.xo.. +.o.ox +..ox. +...ox +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_FLIP_Y); + + checkSymmetry(initialBoard, R"( +xo... +.xo.. +xo.o. +..ox. +)", +{ XYMove(4, 1, P_WHITE)}, + SymmetryHelpers::SYMMETRY_FLIP_X); + + checkSymmetry(initialBoard, R"( +..ox. +xo.o. +.xo.. +xo... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_FLIP_Y_X); + + checkSymmetry(initialBoard, R"( +.... +..ox +.o.o +oxo. +x.x. +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE); + + checkSymmetry(initialBoard, R"( +.... +xo.. +o.o. +.oxo +.x.x +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_X); + + checkSymmetry(initialBoard, R"( +x.x. +oxo. +.o.o +..ox +.... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y); + + checkSymmetry(initialBoard, R"( +.x.x +.oxo +o.o. +xo.. +.... +)", +{ XYMove(4, 1, P_WHITE)}, +SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); +} + +string getGroundedTerritory(const string& boardData, const int expectedGroundingWhiteScore, const vector& extraMoves) { + Board board = parseDotsFieldDefault(boardData, extraMoves); + + Board copy(board); + Color result[Board::MAX_ARR_SIZE]; + int whiteScore = copy.calculateGroundingWhiteScore(result); + testAssert(expectedGroundingWhiteScore == whiteScore); + + std::ostringstream oss; + + for (int y = 0; y < copy.y_size; y++) { + for (int x = 0; x < copy.x_size; x++) { + Loc loc = Location::getLoc(x, y, copy.x_size); + oss << PlayerIO::colorToChar(result[loc]); + } + oss << endl; + } + + testAssert(board.isEqualForTesting(copy, true, true)); + return oss.str(); +} + +string getGroundedTerritory(const string& boardData, const int expectedGroundingWhiteScore) { + return getGroundedTerritory(boardData, expectedGroundingWhiteScore, vector()); +} + +void Tests::runDotsTerritoryTests() { + expect("Cross", getGroundedTerritory(R"( +...... +...... +..ox.. +..xo.. +...... +...... +)", 0), + R"( +...... +...... +...... +...... +...... +...... +)"); + + expect("Grounded white", getGroundedTerritory(R"( +..o... +..o... +..ox.. +..xo.. +...o.. +...o.. +)", 2), + R"( +..O... +..O... +..O... +...O.. +...O.. +...O.. +)"); + + expect("Grounded black", getGroundedTerritory(R"( +...x.. +...x.. +..ox.. +..xo.. +..x... +..x... +)", -2), + R"( +...X.. +...X.. +...X.. +..X... +..X... +..X... +)"); + + expect("Grounded white and black", getGroundedTerritory(R"( +..ox.. +..ox.. +..ox.. +..xo.. +..xo.. +..xo.. +)", 0), +R"( +..OX.. +..OX.. +..OX.. +..XO.. +..XO.. +..XO.. +)"); + + expect("Ungrounded white base", getGroundedTerritory(R"( +...... +...... +..ox.. +.oxo.. +...... +...... +)", -2, {XYMove(2, 4, P_WHITE)}), +R"( +...... +...... +...... +...... +...... +...... +)"); + + expect("Grounded white base", getGroundedTerritory(R"( +...... +...... +..ox.. +.oxo.. +...... +...... +)", 3, {XYMove(2, 4, P_WHITE), XYMove(2, 5, P_WHITE)}), +R"( +...... +...... +..O... +.OOO.. +..O... +..O... +)"); +} + +std::pair getCapturingAndBases( + const string& boardData, + const bool suicideLegal, + const vector& extraMoves +) { + Board board = parseDotsFieldDefault(boardData, extraMoves); + + Board copy(board); + + vector captures; + vector bases; + copy.calculateOneMoveCaptureAndBasePositionsForDots(suicideLegal, captures, bases); + + std::ostringstream capturesStringStream; + std::ostringstream basesStringStream; + + for (int y = 0; y < copy.y_size; y++) { + for (int x = 0; x < copy.x_size; x++) { + Loc loc = Location::getLoc(x, y, copy.x_size); + Color captureColor = captures[loc]; + if (captureColor == C_WALL) { + capturesStringStream << PlayerIO::colorToChar(P_BLACK) << PlayerIO::colorToChar(P_WHITE); + } else { + capturesStringStream << PlayerIO::colorToChar(captureColor) << " "; + } + + Color baseColor = bases[loc]; + if (baseColor == C_WALL) { + basesStringStream << PlayerIO::colorToChar(P_BLACK) << PlayerIO::colorToChar(P_WHITE); + } else { + basesStringStream << PlayerIO::colorToChar(baseColor) << " "; + } + + if (x < copy.x_size - 1) { + capturesStringStream << " "; + basesStringStream << " "; + } + } + capturesStringStream << endl; + basesStringStream << endl; + } + + // Make sure we didn't change an internal state during calculating + testAssert(board.isEqualForTesting(copy, true, true)); + + return std::pair(capturesStringStream.str(), basesStringStream.str()); +} + +void checkCapturingAndBase( + const string& title, + const string& boardData, + const bool suicideLegal, + const vector& extraMoves, + const string& expectedCaptures, + const string& expectedBases +) { + auto [capturing, bases] = getCapturingAndBases(boardData, suicideLegal, extraMoves); + cout << (" " + title + ": capturing").c_str() << endl; + expect("", capturing, expectedCaptures); + cout << (" " + title + ": bases").c_str() << endl; + expect("", bases, expectedBases); +} + +void Tests::runDotsCapturingTests() { + cout << "Running dots capturing tests" << endl; + + checkCapturingAndBase( + "Two bases", + R"( +.x...o. +xox.oxo +....... +)", true, {}, R"( +. . . . . . . +. . . . . . . +. X . . . O . +)", + R"( +. . . . . . . +. X . . . O . +. . . . . . . +)" +); + + checkCapturingAndBase( + "Overlapping capturing location", + R"( +.x. +xox +... +oxo +.o. +)", true, {}, R"( +. . . +. . . +. XO . +. . . +. . . +)", + R"( +. . . +. X . +. . . +. O . +. . . +)" +); + + checkCapturingAndBase( + "Empty base", + R"( +.x. +x.x +.x. +)", true, {}, R"( +. . . +. . . +. . . +)", +R"( +. . . +. X . +. . . +)" +); + + checkCapturingAndBase( +"Empty base no suicide", +R"( +.x. +x.x +.x. +)", false, {}, R"( +. . . +. . . +. . . +)", +R"( +. . . +. X . +. . . +)" +); + + checkCapturingAndBase( +"Empty base capturing", +R"( +.x. +x.x +... +)", true, {}, R"( +. . . +. . . +. X . +)", +R"( +. . . +. X . +. . . +)" +); + + checkCapturingAndBase( + "Complex example with overlapping of capturing and bases", + R"( +.ooxx. +o.xo.x +ox.ox. +ox.ox. +.o.x.. +)", true, {}, R"( +. . . . . . +. . . . . . +. . . . . . +. . XO . . . +. . XO . . . +)", + R"( +. . . . . . +. O O X X . +. O XO X . . +. O XO X . . +. . . . . . +)" +); +} \ No newline at end of file diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp new file mode 100644 index 000000000..2ff9fc68a --- /dev/null +++ b/cpp/tests/testdotsstartposes.cpp @@ -0,0 +1,263 @@ +#include + +#include "../tests/tests.h" +#include "testdotsutils.h" + +#include "../game/graphhash.h" +#include "../program/playutils.h" + +using namespace std; +using namespace std::chrono; +using namespace TestCommon; + +void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const Board& board) { + std::ostringstream sgfStringStream; + const BoardHistory boardHistory(board, P_BLACK, board.rules, 0); + WriteSgf::writeSgf(sgfStringStream, "black", "white", boardHistory, {}); + const string sgfString = sgfStringStream.str(); + cout << "; Sgf: " << sgfString << endl; + + const auto deserializedSgf = Sgf::parse(sgfString); + const Rules newRules = deserializedSgf->getRulesOrFail(); + testAssert(startPos == newRules.startPos); +} + +void checkStartPos(const string& description, const int startPos, const int x_size, const int y_size, const string& expectedBoard) { + cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; + + auto board = Board(x_size, y_size, Rules(true, startPos, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + + std::ostringstream oss; + Board::printBoard(oss, board, Board::NULL_LOC, nullptr); + + expect(description.c_str(), oss, expectedBoard); + + writeToSgfAndCheckStartPosFromSgfProp(startPos, board); +} + +void checkStartPosNotRecognized(const string& description, const string& inputBoard) { + const Board board = parseDotsFieldDefault(inputBoard); + + cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; + + writeToSgfAndCheckStartPosFromSgfProp(0, board); +} + +void Tests::runDotsStartPosTests() { + cout << "Running dots start pos tests" << endl; + + checkStartPos("Cross on minimal size", Rules::START_POS_CROSS, 2, 2, R"( +HASH: EC100709447890A116AFC8952423E3DD + 1 2 + 2 X O + 1 O X +)"); + + checkStartPosNotRecognized("Not enough dots for cross", R"( +.... +.xo. +.o.. +.... +)"); + + checkStartPosNotRecognized("Extra dots with cross", R"( +.... +.xo. +.ox. +..o. +)"); + + checkStartPosNotRecognized("Reversed cross shouldn't be recognized", R"( +.... +.ox. +.xo. +.... +)"); + + checkStartPos("Cross on odd size", Rules::START_POS_CROSS, 3, 3, R"( +HASH: 3B29F9557D2712A5BC982D218680927D + 1 2 3 + 3 . X O + 2 . O X + 1 . . . +)"); + + checkStartPos("Cross on standard size", Rules::START_POS_CROSS, 39, 32, R"( +HASH: 7881733B5E132EF52D2A74B8FEF83B01 + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +21 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . . X O . . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . . O X . . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +11 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); + + checkStartPos("Double cross on minimal size", Rules::START_POS_CROSS_2, 4, 2, R"( +HASH: 43FD769739F2AA27A8A1DAB1F4278229 + 1 2 3 4 + 2 X O O X + 1 O X X O +)"); + + checkStartPos("Double cross on odd size", Rules::START_POS_CROSS_2, 5, 3, R"( +HASH: AAA969B8135294A3D1ADAA07BEA9A987 + 1 2 3 4 5 + 3 . X O O X + 2 . O X X O + 1 . . . . . +)"); + + checkStartPos("Double cross", Rules::START_POS_CROSS_2, 6, 4, R"( +HASH: D599CEA39B1378D29883145CA4C016FC + 1 2 3 4 5 6 + 4 . . . . . . + 3 . X O O X . + 2 . O X X O . + 1 . . . . . . +)"); + + checkStartPos("Double cross", Rules::START_POS_CROSS_2, 7, 4, R"( +HASH: 249F175819EA8FDE47F8676E655A06DE + 1 2 3 4 5 6 7 + 4 . . . . . . . + 3 . . X O O X . + 2 . . O X X O . + 1 . . . . . . . +)"); + + checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, 39, 32, R"( +HASH: E3384654F950CB67E8EBDE0B86A2DAAD + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +21 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . X O O X . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . O X X O . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +11 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 5, 5, R"( +HASH: 0C2DD637AAE5FA7E1469BF5829BE922B + 1 2 3 4 5 + 5 X O . X O + 4 O X . O X + 3 . . . . . + 2 X O . X O + 1 O X . O X +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 7, 7, R"( +HASH: 89CBCA85E94AF1B6C376E6BCBC443A48 + 1 2 3 4 5 6 7 + 7 . . . . . . . + 6 . X O . X O . + 5 . O X . O X . + 4 . . . . . . . + 3 . X O . X O . + 2 . O X . O X . + 1 . . . . . . . +)"); + + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 8, 8, R"( +HASH: 445D50D7A61C47CE2730BBB97A2B3C96 + 1 2 3 4 5 6 7 8 + 8 . . . . . . . . + 7 . X O . . X O . + 6 . O X . . O X . + 5 . . . . . . . . + 4 . . . . . . . . + 3 . X O . . X O . + 2 . O X . . O X . + 1 . . . . . . . . +)"); + + checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, 39, 32, R"( +HASH: 03758E799934E32DFFAEA834AA838030 + 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 +32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +30 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +29 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +28 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +27 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +26 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +25 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +24 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +23 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +22 . . . . . . . . . . . X O . . . . . . . . . . . . . X O . . . . . . . . . . . +21 . . . . . . . . . . . O X . . . . . . . . . . . . . O X . . . . . . . . . . . +20 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +19 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +18 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +17 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +16 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +15 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +14 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +13 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +12 . . . . . . . . . . . X O . . . . . . . . . . . . . X O . . . . . . . . . . . +11 . . . . . . . . . . . O X . . . . . . . . . . . . . O X . . . . . . . . . . . +10 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 9 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 8 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 7 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 6 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 5 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 4 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 3 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . + 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . +)"); +} \ No newline at end of file diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp new file mode 100644 index 000000000..2f5948a3a --- /dev/null +++ b/cpp/tests/testdotsutils.cpp @@ -0,0 +1,39 @@ +#include "testdotsutils.h" + +using namespace std; + +Board parseDotsFieldDefault(const string& input) { + return parseDotsField(input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, vector()); +} + +Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { + return parseDotsField(input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); +} + +Board parseDotsField(const string& input, const bool captureEmptyBases, + const bool freeCapturedDots, const vector& extraMoves) { + int currentXSize = 0; + int xSize = -1; + int ySize = 0; + for(int i = 0; i <= input.length(); i++) { + if(i == input.length() - 1 || input[i] == '\n') { + if(i > 0) { + if(xSize != -1) { + assert(xSize == currentXSize); + } else { + xSize = currentXSize; + } + currentXSize = 0; + ySize++; + } + } else { + currentXSize++; + } + } + + Board result = Board::parseBoard(xSize, ySize, input, Rules(true, Rules::START_POS_EMPTY, captureEmptyBases, freeCapturedDots)); + for(const XYMove& extraMove : extraMoves) { + result.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, result.x_size), extraMove.player); + } + return result; +} \ No newline at end of file diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h new file mode 100644 index 000000000..d147e7b0b --- /dev/null +++ b/cpp/tests/testdotsutils.h @@ -0,0 +1,60 @@ +#pragma once + +#include "../program/playutils.h" + +using namespace std; + +struct XYMove { + int x; + int y; + Player player; + + XYMove(const int x, const int y, const Player player) : x(x), y(y), player(player) {} + + [[nodiscard]] std::string toString() const { + return "(" + to_string(x) + "," + to_string(y) + "," + PlayerIO::colorToChar(player) + ")"; + } +}; + +struct BoardWithMoveRecords { + Board& board; + vector& moveRecords; + + BoardWithMoveRecords(Board& initBoard, vector& initMoveRecords) : board(initBoard), moveRecords(initMoveRecords) {} + + void playMove(const int x, const int y, const Player player) const { + moveRecords.push_back(board.playMoveRecorded(Location::getLoc(x, y, board.x_size), player)); + } + + void playGroundingMove(const Player player) const { + moveRecords.push_back(board.playMoveRecorded(Board::PASS_LOC, player)); + } + + [[nodiscard]] State getState(const int x, const int y) const { + return board.getState(Location::getLoc(x, y, board.x_size)); + } + + [[nodiscard]] bool isLegal(const int x, const int y, const Player player) const { + return board.isLegal(Location::getLoc(x, y, board.x_size), player, true, false); + } + + [[nodiscard]] bool isSuicide(const int x, const int y, const Player player) const { + return board.isSuicide(Location::getLoc(x, y, board.x_size), player); + } + + [[nodiscard]] bool wouldBeCapture(const int x, const int y, const Player player) const { + return board.wouldBeCapture(Location::getLoc(x, y, board.x_size), player); + } + + void undo() const { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } +}; + +Board parseDotsFieldDefault(const string& input); + +Board parseDotsFieldDefault(const string& input, const vector& extraMoves); + +Board parseDotsField(const string& input, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); + diff --git a/cpp/tests/testnninputs.cpp b/cpp/tests/testnninputs.cpp index 1855d8889..47685f3f7 100644 --- a/cpp/tests/testnninputs.cpp +++ b/cpp/tests/testnninputs.cpp @@ -14,20 +14,7 @@ static void printNNInputHWAndBoard( ostream& out, int inputsVersion, const Board& board, const BoardHistory& hist, int nnXLen, int nnYLen, bool inputsUseNHWC, T* row, int c ) { - int numFeatures; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V3; - else if(inputsVersion == 4) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V4; - else if(inputsVersion == 5) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V5; - else if(inputsVersion == 6) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V6; - else if(inputsVersion == 7) - numFeatures = NNInputs::NUM_FEATURES_SPATIAL_V7; - else - testAssert(false); + const int numFeatures = NNInputs::getNumberOfSpatialFeatures(inputsVersion); out << "Channel: " << c << endl; @@ -67,20 +54,7 @@ static void printNNInputHWAndBoard( template static void printNNInputGlobal(ostream& out, int inputsVersion, T* row, int c) { - int numFeatures; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(inputsVersion == 3) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V3; - else if(inputsVersion == 4) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V4; - else if(inputsVersion == 5) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V5; - else if(inputsVersion == 6) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V6; - else if(inputsVersion == 7) - numFeatures = NNInputs::NUM_FEATURES_GLOBAL_V7; - else - testAssert(false); + const int numFeatures = NNInputs::getNumberOfGlobalFeatures(inputsVersion); (void)numFeatures; out << "Channel: " << c; @@ -117,6 +91,14 @@ static double finalScoreIfGameEndedNow(const BoardHistory& baseHist, const Board //================================================================================================================== //================================================================================================================== +Hash128 fillRowAndGetHash( + int version, + Board& board, const BoardHistory& hist, Player nextPla, MiscNNInputParams nnInputParams, int nnXLen, int nnYLen, bool inputsUseNHWC,float* rowBin, float* rowGlobal + ) { + const Hash128 hash = NNInputs::getHash(board,hist,nextPla,nnInputParams); + NNInputs::fillRowVN(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + return hash; +} void Tests::runNNInputsV3V4Tests() { cout << "Running NN inputs V3V4V5V6 tests" << endl; @@ -124,62 +106,13 @@ void Tests::runNNInputsV3V4Tests() { out << std::setprecision(5); auto allocateRows = [](int version, int nnXLen, int nnYLen, int& numFeaturesBin, int& numFeaturesGlobal, float*& rowBin, float*& rowGlobal) { - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(version == 3) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V3; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V3; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V3 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V3]; - } - else if(version == 4) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V4; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V4; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V4 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V4]; - } - else if(version == 5) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V5; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V5; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V5 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V5]; - } - else if(version == 6) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V6; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V6; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V6 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V6]; - } - else if(version == 7) { - numFeaturesBin = NNInputs::NUM_FEATURES_SPATIAL_V7; - numFeaturesGlobal = NNInputs::NUM_FEATURES_GLOBAL_V7; - rowBin = new float[NNInputs::NUM_FEATURES_SPATIAL_V7 * nnXLen * nnYLen]; - rowGlobal = new float[NNInputs::NUM_FEATURES_GLOBAL_V7]; - } - else - testAssert(false); - }; - - auto fillRows = [](int version, Hash128& hash, - Board& board, const BoardHistory& hist, Player nextPla, MiscNNInputParams nnInputParams, int nnXLen, int nnYLen, bool inputsUseNHWC, - float* rowBin, float* rowGlobal) { - hash = NNInputs::getHash(board,hist,nextPla,nnInputParams); - - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); - if(version == 3) - NNInputs::fillRowV3(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 4) - NNInputs::fillRowV4(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 5) - NNInputs::fillRowV5(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 6) - NNInputs::fillRowV6(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else if(version == 7) - NNInputs::fillRowV7(board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); - else - testAssert(false); + numFeaturesBin = NNInputs::getNumberOfSpatialFeatures(version); + numFeaturesGlobal = NNInputs::getNumberOfGlobalFeatures(version); + rowBin = new float[numFeaturesBin * nnXLen * nnYLen]; + rowGlobal = new float[numFeaturesGlobal]; }; - static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); + static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); int minVersion = 3; int maxVersion = 7; @@ -219,10 +152,11 @@ void Tests::runNNInputsV3V4Tests() { allocateRows(version,nnXLen,nnYLen,numFeaturesBin,numFeaturesGlobal,rowBin,rowGlobal); auto run = [&](bool inputsUseNHWC) { - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + + const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + out << hash << endl; for(int c = 0; c 0 ? getOpp(hist.moveHistory[hist.moveHistory.size()-1].pla) : hist.initialPla; - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); out << hash << endl; printNNInputGlobal(out,version,rowGlobal,5); int c = 18; @@ -893,10 +819,9 @@ o.xoo.x BoardHistory hist(board,nextPla,initialRules,0); auto run = [&](bool inputsUseNHWC) { - Hash128 hash; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,9); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,10); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,11); @@ -960,11 +885,10 @@ o.xoo.x auto test = [&](const Board& board, const BoardHistory& hist, Player nextPla) { bool inputsUseNHWC = true; - Hash128 hash; Board b = board; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - fillRows(version,hash,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + const Hash128 hash = fillRowAndGetHash(version,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); for(int c = 0; cspawnServerThreads(); diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index 4492fb00e..ef0c8d3ef 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -1,7 +1,6 @@ #include "../tests/tests.h" #include -#include #include "../dataio/sgf.h" #include "../search/asyncbot.h" @@ -16,22 +15,29 @@ void Tests::runSgfTests() { auto parseAndPrintSgfLinear = [&out](const string& sgfStr) { std::unique_ptr sgf = CompactSgf::parse(sgfStr); + if (sgf->isDots) { + out << "Dots game" << endl; + } out << "xSize " << sgf->xSize << endl; out << "ySize " << sgf->ySize << endl; out << "depth " << sgf->depth << endl; - out << "komi " << sgf->getRulesOrFailAllowUnspecified(Rules()).komi << endl; + Rules rules = sgf->getRulesOrFailAllowUnspecified(Rules::getDefaultOrTrompTaylorish(sgf->isDots)); + out << "komi " << rules.komi << endl; Board board; BoardHistory hist; - Rules rules; Player pla; - rules = sgf->getRulesOrFailAllowUnspecified(rules); + sgf->setupInitialBoardAndHist(rules,board,pla,hist); - out << "placements" << endl; - for(int i = 0; i < sgf->placements.size(); i++) { - Move move = sgf->placements[i]; - out << PlayerIO::colorToChar(move.pla) << " " << Location::toString(move.loc,board) << endl; + if (rules.startPos == Rules::START_POS_EMPTY) { + out << "placements" << endl; + for(int i = 0; i < sgf->placements.size(); i++) { + Move move = sgf->placements[i]; + out << PlayerIO::colorToChar(move.pla) << " " << Location::toString(move.loc,board) << endl; + } + } else { + out << "startPos " << Rules::writeStartPosRule(rules.startPos) << endl; } out << "moves" << endl; for(int i = 0; i < sgf->moves.size(); i++) { @@ -119,7 +125,7 @@ void Tests::runSgfTests() { sgf->getPlacements(placements, xySize.x, xySize.y); out << "placements " << placements.size() << endl; for(const Move& move: placements) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y) << " "; + out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -127,7 +133,7 @@ void Tests::runSgfTests() { sgf->getMoves(moves, xySize.x, xySize.y); out << "moves " << moves.size() << endl; for(const Move& move: moves) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y) << " "; + out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -151,6 +157,78 @@ void Tests::runSgfTests() { ); }; + { + const char* name = "Basic Dots Sgf parse test"; + string sgfStr = "(;FF[4]GM[40]CA[UTF-8]AP[katago]SZ[10:8]AB[ed][fe]AW[ee][fd];B[ef];W[de];B[df];W[hd];B[ce];W[hf];B[cd];W[ff];B[dc];W[cf];B[hb];W[ic];B[db];W[gg];B[da];W[bg];B[])"; + + parseAndPrintSgfLinear(sgfStr); + string expected = R"( +Dots game +xSize 10 +ySize 8 +depth 18 +komi 0 +startPos CROSS +moves +X E3 +O D4 +X D3 +O H5 +X C4 +O H3 +X C5 +O F3 +X D6 +O C3 +X H7 +O J6 +X D7 +O G2 +X D8 +O B2 +X ground +Initial board hist +pla Black +HASH: BA8A444F3D6E9FC94A3F4A16C7D2DBA0 + 1 2 3 4 5 6 7 8 9 10 + 8 . . . . . . . . . . + 7 . . . . . . . . . . + 6 . . . . . . . . . . + 5 . . . . X O . . . . + 4 . . . . O X . . . . + 3 . . . . . . . . . . + 2 . . . . . . . . . . + 1 . . . . . . . . . . + + +Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 +White bonus score 0 +Presumed next pla Black +Game result 0 Empty 0 0 0 0 +Last moves +Final board hist +pla White +HASH: AB87C4395AA2D7E5D7B069ACBFA701D5 + 1 2 3 4 5 6 7 8 9 10 + 8 . . . X . . . . . . + 7 . . . X . . . O . . + 6 . . . X . . . . O . + 5 . . X X X O . O . . + 4 . . X X X X . . . . + 3 . . O X X O . O . . + 2 . O . . . . O . . . + 1 . . . . . . . . . . + + +Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 +White bonus score 0 +Presumed next pla White +Game result 1 Black -1 1 0 0 +Last moves E3 D4 D3 H5 C4 H3 C5 F3 D6 C3 H7 J6 D7 G2 D8 B2 ground +)"; + expect(name,out,expected); + } + //============================================================================ { const char* name = "Basic Sgf parse test"; @@ -994,7 +1072,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koSIMPLEscoreAREAtaxNONEsui0whbNkomi0 +Rules koSIMPLEscoreAREAtaxNONEsui0whbNfpok1komi0 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -1051,7 +1129,7 @@ Encore phase 0 Turns this phase 13 Approx valid turns this phase 13 Approx consec valid turns this game 13 -Rules koSIMPLEscoreAREAtaxNONEsui0whbNkomi0 +Rules koSIMPLEscoreAREAtaxNONEsui0whbNfpok1komi0 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -1421,7 +1499,7 @@ void Tests::runSgfFileTests() { testAssert(sgf->getXYSize().y == 19); testAssert(sgf->getKomiOrFail() == 6.5f); testAssert(sgf->hasRules() == true); - testAssert(sgf->getRulesOrFail().equalsIgnoringKomi(Rules::parseRules("chinese"))); + testAssert(sgf->getRulesOrFail().equalsIgnoringSgfDefinedProps(Rules::parseRules("chinese"))); testAssert(sgf->getHandicapValue() == 2); testAssert(sgf->getSgfWinner() == C_EMPTY); testAssert(sgf->getPlayerName(P_BLACK) == "testname1"); diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index 7c9e3292f..edc3b4d32 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -51,7 +51,8 @@ static NNEvaluator* startNNEval( gpuIdxByServerThread, seed, nnRandomize, - defaultSymmetry + defaultSymmetry, + false // TODO: Fix for Dots Game ); nnEval->spawnServerThreads(); From f006fb47c15ec716bacfe4f311e960dc1f0087e3 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 13 Dec 2025 19:36:30 +0100 Subject: [PATCH 03/42] Remove `inline` modifier to fix linking issues --- cpp/game/dotsfield.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index bf4e30ff1..4916e8c1c 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -47,7 +47,7 @@ Loc Location::xm1yp1(Loc loc, int x_size) { return loc - 1 + (x_size+1); } -inline int Location::getGetBigJumpInitialIndex(const Loc loc0, const Loc loc1, const int x_size) { +int Location::getGetBigJumpInitialIndex(const Loc loc0, const Loc loc1, const int x_size) { const int diff = loc1 - loc0; const int stride = x_size + 1; @@ -70,7 +70,7 @@ inline int Location::getGetBigJumpInitialIndex(const Loc loc0, const Loc loc1, c return -1; } -inline Loc Location::getNextLocCW(const Loc loc0, const Loc loc1, const int x_size) { +Loc Location::getNextLocCW(const Loc loc0, const Loc loc1, const int x_size) { const int diff = loc1 - loc0; const int stride = x_size + 1; @@ -91,11 +91,11 @@ Color getActiveColor(const State state) { return static_cast(state & ACTIVE_MASK); } -inline bool isVisited(const State s) { +bool isVisited(const State s) { return (s & VISITED_FLAG) == VISITED_FLAG; } -inline bool isTerritory(const State s) { +bool isTerritory(const State s) { return (s & TERRITORY_FLAG) == TERRITORY_FLAG; } @@ -103,31 +103,31 @@ Color getPlacedDotColor(const State s) { return static_cast(s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK); } -inline bool isPlaced(const State s, const Player pla) { +bool isPlaced(const State s, const Player pla) { return (s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK) == pla; } -inline bool isActive(const State s, const Player pla) { +bool isActive(const State s, const Player pla) { return (s & ACTIVE_MASK) == pla; } -inline State setTerritoryAndActivePlayer(const State s, const Player pla) { +State setTerritoryAndActivePlayer(const State s, const Player pla) { return static_cast(TERRITORY_FLAG | (s & INVALIDATE_TERRITORY_MASK | pla)); } -inline Color getEmptyTerritoryColor(const State s) { +Color getEmptyTerritoryColor(const State s) { return static_cast(s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK); } -inline bool isWithinEmptyTerritory(const State s, const Player pla) { +bool isWithinEmptyTerritory(const State s, const Player pla) { return (s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK) == pla; } -inline State Board::getState(const Loc loc) const { +State Board::getState(const Loc loc) const { return colors[loc]; } -inline void Board::setState(const Loc loc, const State state) { +void Board::setState(const Loc loc, const State state) { colors[loc] = state; } @@ -135,15 +135,15 @@ bool Board::isDots() const { return rules.isDots; } -inline void Board::setVisited(const Loc loc) { +void Board::setVisited(const Loc loc) { colors[loc] = static_cast(colors[loc] | VISITED_FLAG); } -inline void Board::clearVisited(const Loc loc) { +void Board::clearVisited(const Loc loc) { colors[loc] = static_cast(colors[loc] & INVALIDATE_VISITED_MASK); } -inline void Board::clearVisited(const vector& locations) { +void Board::clearVisited(const vector& locations) { for (const Loc& loc : locations) { clearVisited(loc); } @@ -426,7 +426,7 @@ void Board::getUnconnectedLocations(const Loc loc, const Player pla) const { checkAndAddUnconnectedLocation(getColor(xym1), pla, loc + adj_offsets[RIGHT_TOP_INDEX], xp1y); } -inline void Board::checkAndAddUnconnectedLocation(const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { +void Board::checkAndAddUnconnectedLocation(const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { if (checkPla != currentPla) { if (getColor(addLoc1) == currentPla) { unconnectedLocationsBuffer[unconnectedLocationsBufferSize++] = addLoc1; From 57b465bb2690434ea108d73da2187d455633fe64 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:06:38 +0100 Subject: [PATCH 04/42] Split `MAX_LEN` into `MAX_LEN_X` and `MAX_LEN_Y` It allows handling field sizes for accurately that are actual for Dots game because its default size is not a square. --- cpp/CMakeLists.txt | 15 ++++++++++++--- cpp/command/analysis.cpp | 15 ++++++++------- cpp/command/contribute.cpp | 4 ++-- cpp/command/genbook.cpp | 8 ++++---- cpp/command/gtp.cpp | 16 ++++++++++------ cpp/command/match.cpp | 2 +- cpp/command/runtests.cpp | 6 +++--- cpp/command/writetrainingdata.cpp | 4 ++-- cpp/core/config_parser.cpp | 6 +++--- cpp/core/config_parser.h | 2 +- cpp/dataio/sgf.cpp | 9 +++++++-- cpp/game/board.cpp | 18 ++++++++++-------- cpp/game/board.h | 25 +++++++++++++++++++------ cpp/neuralnet/nneval.cpp | 8 ++++---- cpp/neuralnet/nninputs.h | 4 +++- cpp/neuralnet/nninputsdots.cpp | 4 ++-- cpp/neuralnet/openclbackend.cpp | 12 ++++++------ cpp/neuralnet/opencltuner.cpp | 8 ++++---- cpp/neuralnet/opencltuner.h | 4 ++-- cpp/program/play.cpp | 2 +- cpp/program/setup.cpp | 26 +++++++++++++------------- cpp/search/search.cpp | 6 +++--- cpp/tests/testdotsstartposes.cpp | 6 +++--- cpp/tests/testnnevalcanary.cpp | 4 ++-- cpp/tests/tests.h | 2 +- cpp/tests/testsearch.cpp | 2 +- cpp/tests/testsearchnonn.cpp | 14 +++++++------- cpp/tests/testsearchv3.cpp | 4 ++-- cpp/tests/testsearchv9.cpp | 2 +- cpp/tests/testsgf.cpp | 2 +- cpp/tests/testtrainingwrite.cpp | 4 ++-- cpp/tests/tinymodel.cpp | 6 +++--- 32 files changed, 143 insertions(+), 107 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bd2eb4520..4e0d41c2f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -38,6 +38,7 @@ set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc") set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe") set(USE_AVX2 0 CACHE BOOL "Compile with AVX2") set(USE_BIGGER_BOARDS_EXPENSIVE 0 CACHE BOOL "Allow boards up to size 50. Compiling with this will use more memory and slow down KataGo, even when playing on boards of size 19.") +set(DOTS_GAME 0 CACHE BOOL "Configure KataGo to compile for Dots Game with possibility to run Go as well. It configures more optimal field sizes") set(USE_CACHE_TENSORRT_PLAN 0 CACHE BOOL "Use TENSORRT plan cache. May use a lot of disk space. Only applies when USE_BACKEND is TENSORRT.") mark_as_advanced(USE_CACHE_TENSORRT_PLAN) @@ -453,9 +454,17 @@ elseif(USE_BACKEND STREQUAL "EIGEN") endif() endif() -if(USE_BIGGER_BOARDS_EXPENSIVE) - target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN=50) -endif() +if(DOTS_GAME) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_X=39) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_Y=36) + if(USE_BIGGER_BOARDS_EXPENSIVE) + message(SEND_ERROR "USE_BIGGER_BOARDS_EXPENSIVE is not yet supported for Dots Game") + endif() +else() + if(USE_BIGGER_BOARDS_EXPENSIVE) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN=50) + endif() +endif(DOTS_GAME) if(NO_GIT_REVISION AND (NOT BUILD_DISTRIBUTED)) target_compile_definitions(katago PRIVATE NO_GIT_REVISION) diff --git a/cpp/command/analysis.cpp b/cpp/command/analysis.cpp index 696d28cad..a6b5b6c30 100644 --- a/cpp/command/analysis.cpp +++ b/cpp/command/analysis.cpp @@ -161,13 +161,13 @@ int MainCmds::analysis(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); if(humanModelFile != "") { humanEval = Setup::initializeNNEvaluator( humanModelFile,humanModelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); if(!humanEval->requiresSGFMetadata()) { @@ -702,19 +702,20 @@ int MainCmds::analysis(const vector& args) { { int64_t xBuf; int64_t yBuf; - static const string boardSizeError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN); + static const string boardSizeXError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN_X); + static const string boardSizeYError = string("Must provide an integer from 2 to ") + Global::intToString(Board::MAX_LEN_Y); if(input.find("boardXSize") == input.end()) { - reportErrorForId(rbase.id, "boardXSize", boardSizeError.c_str()); + reportErrorForId(rbase.id, "boardXSize", boardSizeXError); continue; } if(input.find("boardYSize") == input.end()) { - reportErrorForId(rbase.id, "boardYSize", boardSizeError.c_str()); + reportErrorForId(rbase.id, "boardYSize", boardSizeYError); continue; } - if(!parseInteger(input, "boardXSize", xBuf, 2, Board::MAX_LEN, boardSizeError.c_str())) { + if(!parseInteger(input, "boardXSize", xBuf, 2, Board::MAX_LEN_X, boardSizeXError.c_str())) { continue; } - if(!parseInteger(input, "boardYSize", yBuf, 2, Board::MAX_LEN, boardSizeError.c_str())) { + if(!parseInteger(input, "boardYSize", yBuf, 2, Board::MAX_LEN_Y, boardSizeYError.c_str())) { continue; } boardXSize = (int)xBuf; diff --git a/cpp/command/contribute.cpp b/cpp/command/contribute.cpp index 62fd7bf42..7cd5a0a96 100644 --- a/cpp/command/contribute.cpp +++ b/cpp/command/contribute.cpp @@ -905,7 +905,7 @@ int MainCmds::contribute(const vector& args) { const bool disableFP16 = false; nnEval = Setup::initializeNNEvaluator( modelName,modelFile,modelInfo.sha256,*userCfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); assert(!nnEval->isNeuralNetLess() || modelFile == "/dev/null"); @@ -919,7 +919,7 @@ int MainCmds::contribute(const vector& args) { const bool disableFP16 = true; nnEval32 = Setup::initializeNNEvaluator( modelName,modelFile,modelInfo.sha256,*userCfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); } diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index ff04ffec0..69c10ee10 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -329,8 +329,8 @@ int MainCmds::genbook(const vector& args) { const bool hasHumanModel = humanModelFile != ""; const SearchParams params = Setup::loadSingleParams(cfg,Setup::SETUP_FOR_GTP,hasHumanModel); - const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN); - const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN); + const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN_X); + const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN_Y); const int repBound = cfg.getInt("repBound",3,1000); const double bonusFileScale = cfg.contains("bonusFileScale") ? cfg.getDouble("bonusFileScale",0.0,1000000.0) : 1.0; @@ -1526,8 +1526,8 @@ int MainCmds::writebook(const vector& args) { const bool loadKomiFromCfg = true; Rules rules = Setup::loadSingleRules(cfg,loadKomiFromCfg); - const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN); - const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN); + const int boardSizeX = cfg.getInt("boardSizeX",2,Board::MAX_LEN_X); + const int boardSizeY = cfg.getInt("boardSizeY",2,Board::MAX_LEN_Y); const int repBound = cfg.getInt("repBound",3,1000); std::map bonusByHash; diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 297fb2d89..a11b73128 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -458,8 +458,8 @@ struct GTPEngine { void setOrResetBoardSize(ConfigParser& cfg, Logger& logger, Rand& seedRand, int boardXSize, int boardYSize, bool loggingToStderr) { bool wasDefault = false; if(boardXSize == -1 || boardYSize == -1) { - boardXSize = Board::DEFAULT_LEN; - boardYSize = Board::DEFAULT_LEN; + boardXSize = Board::DEFAULT_LEN_X; + boardYSize = Board::DEFAULT_LEN_Y; wasDefault = true; } @@ -469,8 +469,8 @@ struct GTPEngine { if(cfg.contains("gtpForceMaxNNSize") && cfg.getBool("gtpForceMaxNNSize")) { defaultRequireExactNNLen = false; - nnXLen = Board::MAX_LEN; - nnYLen = Board::MAX_LEN; + nnXLen = Board::MAX_LEN_X; + nnYLen = Board::MAX_LEN_Y; } //If the neural net is wrongly sized, we need to create or recreate it @@ -2301,9 +2301,13 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "unacceptable size"; } - else if(newXSize > Board::MAX_LEN || newYSize > Board::MAX_LEN) { + else if(newXSize > Board::MAX_LEN_X) { responseIsError = true; - response = Global::strprintf("unacceptable size (Board::MAX_LEN is %d, consider increasing and recompiling)",(int)Board::MAX_LEN); + response = Global::strprintf("unacceptable size (Board::MAX_LEN_X is %d, consider increasing and recompiling)",(int)Board::MAX_LEN_X); + } + else if(newYSize > Board::MAX_LEN_Y) { + responseIsError = true; + response = Global::strprintf("unacceptable size (Board::MAX_LEN_Y is %d, consider increasing and recompiling)",(int)Board::MAX_LEN_Y); } else { engine->setOrResetBoardSize(cfg,logger,seedRand,newXSize,newYSize,logger.isLoggingToStderr()); diff --git a/cpp/command/match.cpp b/cpp/command/match.cpp index 9d7fbda90..d235634b4 100644 --- a/cpp/command/match.cpp +++ b/cpp/command/match.cpp @@ -108,7 +108,7 @@ int MainCmds::match(const vector& args) { } if(cfg.contains("extraPairs")) { - std::vector> pairs = cfg.getNonNegativeIntDashedPairs("extraPairs",0,numBots-1); + std::vector> pairs = cfg.getNonNegativeIntDashedPairs("extraPairs", 0, numBots - 1, numBots - 1); for(const std::pair& pair: pairs) { int p0 = pair.first; int p1 = pair.second; diff --git a/cpp/command/runtests.cpp b/cpp/command/runtests.cpp index 9cdeb3516..26d154cd4 100644 --- a/cpp/command/runtests.cpp +++ b/cpp/command/runtests.cpp @@ -445,7 +445,7 @@ int MainCmds::runnnevalcanarytests(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } @@ -504,7 +504,7 @@ int MainCmds::runbeginsearchspeedtest(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - Board::MAX_LEN,Board::MAX_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + Board::MAX_LEN_X,Board::MAX_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } @@ -628,7 +628,7 @@ int MainCmds::runownershipspeedtest(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( modelFile,modelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - Board::MAX_LEN,Board::MAX_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + Board::MAX_LEN_X,Board::MAX_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_GTP ); } diff --git a/cpp/command/writetrainingdata.cpp b/cpp/command/writetrainingdata.cpp index fab3510ac..5892aded3 100644 --- a/cpp/command/writetrainingdata.cpp +++ b/cpp/command/writetrainingdata.cpp @@ -684,7 +684,7 @@ int MainCmds::writetrainingdata(const vector& args) { const int maxApproxRowsPerTrainFile = cfg.getInt("maxApproxRowsPerTrainFile",1,100000000); const std::vector> allowedBoardSizes = - cfg.getNonNegativeIntDashedPairs("allowedBoardSizes", 2, Board::MAX_LEN); + cfg.getNonNegativeIntDashedPairs("allowedBoardSizes", 2, Board::MAX_LEN_X, Board::MAX_LEN_Y); if(dataBoardLen > Board::MAX_LEN) throw StringError("dataBoardLen > maximum board len, must recompile to increase"); @@ -712,7 +712,7 @@ int MainCmds::writetrainingdata(const vector& args) { const string expectedSha256 = ""; nnEval = Setup::initializeNNEvaluator( nnModelFile,nnModelFile,expectedSha256,cfg,logger,seedRand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,defaultMaxBatchSize,defaultRequireExactNNLen,disableFP16, Setup::SETUP_FOR_ANALYSIS ); } diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index e21d274b7..3d8f60cf7 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -649,7 +649,7 @@ vector ConfigParser::getInts(const string& key, int min, int max) { } return ret; } -vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, int min, int max) { +vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, int min, int max1, int max2) { std::vector pairStrs = getStrings(key); std::vector> ret; for(const string& pairStr: pairStrs) { @@ -670,8 +670,8 @@ vector> ConfigParser::getNonNegativeIntDashedPairs(const stri if(!suc) throw IOError("Could not parse '" + pairStr + "' as a pair of integers separated by a dash for key '" + key + "' in config file " + fileName); - if(p0 < min || p0 > max || p1 < min || p1 > max) - throw IOError("Expected key '" + key + "' to have all values range " + Global::intToString(min) + " to " + Global::intToString(max) + " in config file " + fileName); + if(p0 < min || p0 > max1 || p1 < min || p1 > max2) + throw IOError("Expected key '" + key + "' to have all values range " + Global::intToString(min) + " to (" + Global::intToString(max1) + ", " + Global::intToString(max2) + ") in config file " + fileName); ret.push_back(std::make_pair(p0,p1)); } diff --git a/cpp/core/config_parser.h b/cpp/core/config_parser.h index 8d0a9191a..82e68b49c 100644 --- a/cpp/core/config_parser.h +++ b/cpp/core/config_parser.h @@ -90,7 +90,7 @@ class ConfigParser { std::vector getFloats(const std::string& key, float min, float max); std::vector getDoubles(const std::string& key, double min, double max); - std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max); + std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max1, int max2); private: bool initialized; diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 5039f184b..95a442fd9 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -450,9 +450,14 @@ XYSize Sgf::getXYSize() const { if(xSize <= 1 || ySize <= 1) propertyFail("Board size in sgf is <= 1: " + s); - if(xSize > Board::MAX_LEN || ySize > Board::MAX_LEN) + if(xSize > Board::MAX_LEN_X) propertyFail( - "Board size in sgf is > Board::MAX_LEN = " + Global::intToString((int)Board::MAX_LEN) + + "Board size x in sgf is > Board::MAX_LEN_X = " + Global::intToString((int)Board::MAX_LEN_X) + + ", if larger sizes are desired, consider increasing and recompiling: " + s + ); + if(ySize > Board::MAX_LEN_Y) + propertyFail( + "Board size y in sgf is > Board::MAX_LEN_Y = " + Global::intToString((int)Board::MAX_LEN_Y) + ", if larger sizes are desired, consider increasing and recompiling: " + s ); return XYSize(xSize,ySize); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 5216a4976..d2cdbd809 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -18,8 +18,8 @@ using namespace std; //STATIC VARS----------------------------------------------------------------------------- bool Board::IS_ZOBRIST_INITALIZED = false; -Hash128 Board::ZOBRIST_SIZE_X_HASH[MAX_LEN+1]; -Hash128 Board::ZOBRIST_SIZE_Y_HASH[MAX_LEN+1]; +Hash128 Board::ZOBRIST_SIZE_X_HASH[MAX_LEN_X+1]; +Hash128 Board::ZOBRIST_SIZE_Y_HASH[MAX_LEN_Y+1]; Hash128 Board::ZOBRIST_BOARD_HASH[MAX_ARR_SIZE][4]; Hash128 Board::ZOBRIST_PLAYER_HASH[4]; Hash128 Board::ZOBRIST_KO_LOC_HASH[MAX_ARR_SIZE]; @@ -121,7 +121,7 @@ Board::Base::Base(Player newPla, Board::Board() { - init(DEFAULT_LEN, DEFAULT_LEN, Rules()); + init(DEFAULT_LEN_X, DEFAULT_LEN_Y, Rules()); } Board::Board(int x, int y) @@ -158,7 +158,7 @@ Board::Board(const Board& other) { void Board::init(const int xS, const int yS, const Rules& initRules) { assert(IS_ZOBRIST_INITALIZED); - if(xS < 0 || yS < 0 || xS > MAX_LEN || yS > MAX_LEN) + if(xS < 0 || yS < 0 || xS > MAX_LEN_X || yS > MAX_LEN_Y) throw StringError("Board::init - invalid board size"); x_size = xS; @@ -247,9 +247,11 @@ void Board::initHash() //Reseed the random number generator so that these size hashes are also //not affected by the size of the board we compile with rand.init("Board::initHash() for ZOBRIST_SIZE hashes"); - for(int i = 0; i NNPos::MAX_BOARD_LEN) - throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); - if(nnYLen > NNPos::MAX_BOARD_LEN) - throw StringError("Maximum supported nnEval board size is " + Global::intToString(NNPos::MAX_BOARD_LEN)); + if(nnXLen > NNPos::MAX_BOARD_LEN_X) + throw StringError("Maximum supported nnEval board size x is " + Global::intToString(NNPos::MAX_BOARD_LEN_X)); + if(nnYLen > NNPos::MAX_BOARD_LEN_Y) + throw StringError("Maximum supported nnEval board size y is " + Global::intToString(NNPos::MAX_BOARD_LEN_Y)); if(maxBatchSize <= 0) throw StringError("maxBatchSize is negative: " + Global::intToString(maxBatchSize)); if(gpuIdxByServerThread.size() != numThreads) diff --git a/cpp/neuralnet/nninputs.h b/cpp/neuralnet/nninputs.h index 47394dd8b..a92a1edf4 100644 --- a/cpp/neuralnet/nninputs.h +++ b/cpp/neuralnet/nninputs.h @@ -13,8 +13,10 @@ void setRowBin(float* rowBin, int pos, int feature, float value, int posStride, int featureStride); namespace NNPos { + constexpr int MAX_BOARD_LEN_X = Board::MAX_LEN_X; + constexpr int MAX_BOARD_LEN_Y = Board::MAX_LEN_Y; constexpr int MAX_BOARD_LEN = Board::MAX_LEN; - constexpr int MAX_BOARD_AREA = MAX_BOARD_LEN * MAX_BOARD_LEN; + constexpr int MAX_BOARD_AREA = MAX_BOARD_LEN_X * MAX_BOARD_LEN_Y; //Policy output adds +1 for the pass move constexpr int MAX_NN_POLICY_SIZE = MAX_BOARD_AREA + 1; //Extra score distribution radius, used for writing score in data rows and for the neural net score belief output diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index 62b7d3afc..e6b22059f 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -7,8 +7,8 @@ void NNInputs::fillRowVDots( const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal ) { - assert(nnXLen <= NNPos::MAX_BOARD_LEN); - assert(nnYLen <= NNPos::MAX_BOARD_LEN); + assert(nnXLen <= NNPos::MAX_BOARD_LEN_X); + assert(nnYLen <= NNPos::MAX_BOARD_LEN_Y); assert(board.x_size <= nnXLen); assert(board.y_size <= nnYLen); std::fill_n(rowBin, NUM_FEATURES_SPATIAL_V_DOTS * nnXLen * nnYLen,false); diff --git a/cpp/neuralnet/openclbackend.cpp b/cpp/neuralnet/openclbackend.cpp index c4bcf2728..b8922bf2c 100644 --- a/cpp/neuralnet/openclbackend.cpp +++ b/cpp/neuralnet/openclbackend.cpp @@ -2482,13 +2482,13 @@ struct Model { nnXLen = nnX; nnYLen = nnY; - if(nnXLen > NNPos::MAX_BOARD_LEN) - throw StringError(Global::strprintf("nnXLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", - nnXLen, NNPos::MAX_BOARD_LEN + if(nnXLen > NNPos::MAX_BOARD_LEN_X) + throw StringError(Global::strprintf("nnXLen (%d) is greater than NNPos::MAX_BOARD_LEN_X (%d)", + nnXLen, NNPos::MAX_BOARD_LEN_X )); - if(nnYLen > NNPos::MAX_BOARD_LEN) - throw StringError(Global::strprintf("nnYLen (%d) is greater than NNPos::MAX_BOARD_LEN (%d)", - nnYLen, NNPos::MAX_BOARD_LEN + if(nnYLen > NNPos::MAX_BOARD_LEN_Y) + throw StringError(Global::strprintf("nnYLen (%d) is greater than NNPos::MAX_BOARD_LEN_Y (%d)", + nnYLen, NNPos::MAX_BOARD_LEN_Y )); numInputChannels = desc->numInputChannels; diff --git a/cpp/neuralnet/opencltuner.cpp b/cpp/neuralnet/opencltuner.cpp index c8ffc3551..58ae3eb4e 100644 --- a/cpp/neuralnet/opencltuner.cpp +++ b/cpp/neuralnet/opencltuner.cpp @@ -3244,8 +3244,8 @@ OpenCLTuneParams OpenCLTuner::loadOrAutoTune( //If not re-tuning per board size, then check if the tune config for the full size is there //And set the nnXLen and nnYLen we'll use for tuning to the full size if(!openCLReTunePerBoardSize) { - nnXLen = NNPos::MAX_BOARD_LEN; - nnYLen = NNPos::MAX_BOARD_LEN; + nnXLen = NNPos::MAX_BOARD_LEN_X; + nnYLen = NNPos::MAX_BOARD_LEN_Y; openCLTunerFile = dir + "/" + OpenCLTuner::defaultFileName(gpuName, nnXLen, nnYLen, modelInfo); try { OpenCLTuneParams loadedParams = loadFromTunerFile(openCLTunerFile,logger); @@ -3464,8 +3464,8 @@ void OpenCLTuner::autoTuneEverything( } for(ModelInfoForTuning modelInfo : modelInfos) { - int nnXLen = NNPos::MAX_BOARD_LEN; - int nnYLen = NNPos::MAX_BOARD_LEN; + int nnXLen = NNPos::MAX_BOARD_LEN_X; + int nnYLen = NNPos::MAX_BOARD_LEN_Y; string dir = OpenCLTuner::defaultDirectory(true,homeDataDirOverride); string openCLTunerFile = dir + "/" + OpenCLTuner::defaultFileName(gpuName, nnXLen, nnYLen, modelInfo); try { diff --git a/cpp/neuralnet/opencltuner.h b/cpp/neuralnet/opencltuner.h index 8f6071ff1..994b66871 100644 --- a/cpp/neuralnet/opencltuner.h +++ b/cpp/neuralnet/opencltuner.h @@ -180,8 +180,8 @@ struct OpenCLTuneParams { }; namespace OpenCLTuner { - constexpr int DEFAULT_X_SIZE = NNPos::MAX_BOARD_LEN; - constexpr int DEFAULT_Y_SIZE = NNPos::MAX_BOARD_LEN; + constexpr int DEFAULT_X_SIZE = NNPos::MAX_BOARD_LEN_X; + constexpr int DEFAULT_Y_SIZE = NNPos::MAX_BOARD_LEN_Y; constexpr int DEFAULT_BATCH_SIZE = 4; constexpr int DEFAULT_WINOGRAD_3X3_TILE_SIZE = 4; diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 1481a6620..4c3351863 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -194,7 +194,7 @@ void GameInitializer::initShared(ConfigParser& cfg, Logger& logger) { else if(cfg.contains("bSizesXY")) { if(cfg.contains("allowRectangleProb")) throw IOError("Cannot specify allowRectangleProb when specifying bSizesXY, please adjust the relative frequency of rectangles yourself"); - allowedBSizes = cfg.getNonNegativeIntDashedPairs("bSizesXY", 2, Board::MAX_LEN); + allowedBSizes = cfg.getNonNegativeIntDashedPairs("bSizesXY", 2, Board::MAX_LEN_X, Board::MAX_LEN_Y); allowedBSizeRelProbs = cfg.getDoubles("bSizeRelProbs",0.0,1e100); double relProbSum = 0.0; diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 35a5644c0..1438d9ea1 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -115,22 +115,22 @@ vector Setup::initializeNNEvaluators( int nnYLen = std::max(defaultNNYLen,2); if(setupFor != SETUP_FOR_DISTRIBUTED) { if(cfg.contains("maxBoardXSizeForNNBuffer" + idxStr)) - nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardXSizeForNNBuffer")) - nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardXSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardSizeForNNBuffer" + idxStr)) - nnXLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_X); else if(cfg.contains("maxBoardSizeForNNBuffer")) - nnXLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnXLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_X); if(cfg.contains("maxBoardYSizeForNNBuffer" + idxStr)) - nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardYSizeForNNBuffer")) - nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardYSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardSizeForNNBuffer" + idxStr)) - nnYLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardSizeForNNBuffer" + idxStr, 2, NNPos::MAX_BOARD_LEN_Y); else if(cfg.contains("maxBoardSizeForNNBuffer")) - nnYLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN); + nnYLen = cfg.getInt("maxBoardSizeForNNBuffer", 2, NNPos::MAX_BOARD_LEN_Y); } bool requireExactNNLen = defaultRequireExactNNLen; @@ -967,12 +967,12 @@ bool Setup::loadDefaultBoardXYSize( int& defaultBoardYSizeRet ) { const int defaultBoardXSize = - cfg.contains("defaultBoardXSize") ? cfg.getInt("defaultBoardXSize",2,Board::MAX_LEN) : - cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN) : + cfg.contains("defaultBoardXSize") ? cfg.getInt("defaultBoardXSize",2,Board::MAX_LEN_X) : + cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN_X) : -1; const int defaultBoardYSize = - cfg.contains("defaultBoardYSize") ? cfg.getInt("defaultBoardYSize",2,Board::MAX_LEN) : - cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN) : + cfg.contains("defaultBoardYSize") ? cfg.getInt("defaultBoardYSize",2,Board::MAX_LEN_Y) : + cfg.contains("defaultBoardSize") ? cfg.getInt("defaultBoardSize",2,Board::MAX_LEN_Y) : -1; if((defaultBoardXSize == -1) != (defaultBoardYSize == -1)) logger.write("Warning: Config specified only one of defaultBoardXSize or defaultBoardYSize and no other board size parameter, ignoring it"); @@ -1121,7 +1121,7 @@ std::unique_ptr Setup::loadAndPruneAutoPatternBonusTables(Con bool suc = Global::tryStringToInt(pieces[0],boardXSize) && Global::tryStringToInt(pieces[1],boardYSize); if(!suc) continue; - if(boardXSize < 2 || boardXSize > Board::MAX_LEN || boardYSize < 2 || boardYSize > Board::MAX_LEN) + if(boardXSize < 2 || boardXSize > Board::MAX_LEN_X || boardYSize < 2 || boardYSize > Board::MAX_LEN_Y) continue; string dirPath = baseDir + "/" + dirName; diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index 5dff72cb2..d4e238c84 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -117,9 +117,9 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, assert(logger != NULL); nnXLen = nnEval->getNNXLen(); nnYLen = nnEval->getNNYLen(); - assert(nnXLen > 0 && nnXLen <= NNPos::MAX_BOARD_LEN); - assert(nnYLen > 0 && nnYLen <= NNPos::MAX_BOARD_LEN); - policySize = NNPos::getPolicySize(nnXLen,nnYLen); + assert(nnXLen > 0 && nnXLen <= NNPos::MAX_BOARD_LEN_X); + assert(nnYLen > 0 && nnYLen <= NNPos::MAX_BOARD_LEN_Y); + policySize = NNPos::getPolicySize(nnXLen, nnYLen); if(humanEvaluator != NULL) { if(humanEvaluator->getNNXLen() != nnXLen || humanEvaluator->getNNYLen() != nnYLen) diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 2ff9fc68a..b2d9aa774 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -83,7 +83,7 @@ HASH: 3B29F9557D2712A5BC982D218680927D )"); checkStartPos("Cross on standard size", Rules::START_POS_CROSS, 39, 32, R"( -HASH: 7881733B5E132EF52D2A74B8FEF83B01 +HASH: 516E1ABBA0D6B69A0B3D17C9E34E52F7 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -153,7 +153,7 @@ HASH: 249F175819EA8FDE47F8676E655A06DE )"); checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, 39, 32, R"( -HASH: E3384654F950CB67E8EBDE0B86A2DAAD +HASH: CAD72FD407955308CEFCBD7A9B14B35B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -225,7 +225,7 @@ HASH: 445D50D7A61C47CE2730BBB97A2B3C96 )"); checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, 39, 32, R"( -HASH: 03758E799934E32DFFAEA834AA838030 +HASH: 2A9AE7F967F17B42D9B9CB45B735E9C6 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . diff --git a/cpp/tests/testnnevalcanary.cpp b/cpp/tests/testnnevalcanary.cpp index f6c07a0bc..100ecd3a3 100644 --- a/cpp/tests/testnnevalcanary.cpp +++ b/cpp/tests/testnnevalcanary.cpp @@ -470,8 +470,8 @@ static std::shared_ptr nnOutputOfJson(const std::string& s) { nnOutput->policyOptimismUsed = input["policyOptimismUsed"].get(); nnOutput->nnXLen = input["nnXLen"].get(); nnOutput->nnYLen = input["nnYLen"].get(); - testAssert(nnOutput->nnXLen >= 2 && nnOutput->nnXLen <= NNPos::MAX_BOARD_LEN); - testAssert(nnOutput->nnYLen >= 2 && nnOutput->nnYLen <= NNPos::MAX_BOARD_LEN); + testAssert(nnOutput->nnXLen >= 2 && nnOutput->nnXLen <= NNPos::MAX_BOARD_LEN_X); + testAssert(nnOutput->nnYLen >= 2 && nnOutput->nnYLen <= NNPos::MAX_BOARD_LEN_Y); std::vector whiteOwnerMap = input["whiteOwnerMap"].get>(); testAssert(whiteOwnerMap.size() == nnOutput->nnXLen*nnOutput->nnYLen); nnOutput->whiteOwnerMap = new float[nnOutput->nnXLen*nnOutput->nnYLen]; diff --git a/cpp/tests/tests.h b/cpp/tests/tests.h index 4429cbccd..71bc03134 100644 --- a/cpp/tests/tests.h +++ b/cpp/tests/tests.h @@ -129,7 +129,7 @@ namespace TestCommon { constexpr int MIN_BENCHMARK_SGF_DATA_SIZE = 7; constexpr int MAX_BENCHMARK_SGF_DATA_SIZE = 19; - constexpr int DEFAULT_BENCHMARK_SGF_DATA_SIZE = std::min(Board::DEFAULT_LEN,MAX_BENCHMARK_SGF_DATA_SIZE); + constexpr int DEFAULT_BENCHMARK_SGF_DATA_SIZE = std::min(std::max(Board::DEFAULT_LEN_X, Board::DEFAULT_LEN_Y),MAX_BENCHMARK_SGF_DATA_SIZE); std::string getBenchmarkSGFData(int boardSize); std::vector getMultiGameSize9Data(); diff --git a/cpp/tests/testsearch.cpp b/cpp/tests/testsearch.cpp index 1c86d2351..4778dd207 100644 --- a/cpp/tests/testsearch.cpp +++ b/cpp/tests/testsearch.cpp @@ -184,7 +184,7 @@ void Tests::runSearchTests(const string& modelFile, bool inputsNHWC, bool useNHW const bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runBasicPositions(nnEval, logger); delete nnEval; diff --git a/cpp/tests/testsearchnonn.cpp b/cpp/tests/testsearchnonn.cpp index 862bce25c..f5cf38ad4 100644 --- a/cpp/tests/testsearchnonn.cpp +++ b/cpp/tests/testsearchnonn.cpp @@ -36,7 +36,7 @@ void Tests::runNNLessSearchTests() { cout << "Basic search with debugSkipNeuralNet and chosen move randomization" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 100; Search* search = new Search(params, nnEval, &logger, "autoSearchRandSeed"); @@ -119,7 +119,7 @@ void Tests::runNNLessSearchTests() { cout << "Testing preservation of search tree across moves" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 100; params.cpuctExploration *= 2; @@ -251,7 +251,7 @@ o..oo.x { cout << "First with no pruning" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; Search* search = new Search(params, nnEval, &logger, "autoSearchRandSeed3"); @@ -274,7 +274,7 @@ o..oo.x { cout << "Next, with rootPruneUselessMoves" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.rootPruneUselessMoves = true; @@ -310,7 +310,7 @@ o..oo.x { cout << "Searching on the opponent, the move before" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1b",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1b",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.rootPruneUselessMoves = true; @@ -382,7 +382,7 @@ o..o.oo { cout << "First with no pruning" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"seed1",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 400; params.dynamicScoreUtilityFactor = 0.5; @@ -1601,7 +1601,7 @@ xxxxxxxxx cout << "Testing coherence of search tree recursive walking" << endl; cout << "===================================================================" << endl; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,0,true,false,false,true,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,0,true,false,false,true,false); SearchParams params; params.maxVisits = 1000; params.dynamicScoreUtilityFactor = 3.0; diff --git a/cpp/tests/testsearchv3.cpp b/cpp/tests/testsearchv3.cpp index 97f933196..3e24cd788 100644 --- a/cpp/tests/testsearchv3.cpp +++ b/cpp/tests/testsearchv3.cpp @@ -524,9 +524,9 @@ void Tests::runSearchTestsV3(const string& modelFile, bool inputsNHWC, bool useN const bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); NNEvaluator* nnEval11 = startNNEval(modelFile,logger,"",11,11,symmetry,inputsNHWC,useNHWC,useFP16,false,false); - NNEvaluator* nnEvalPTemp = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEvalPTemp = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runOwnershipAndMisc(nnEval,nnEval11,nnEvalPTemp,logger); delete nnEval; delete nnEval11; diff --git a/cpp/tests/testsearchv9.cpp b/cpp/tests/testsearchv9.cpp index 2bc4533c4..9b4d3d687 100644 --- a/cpp/tests/testsearchv9.cpp +++ b/cpp/tests/testsearchv9.cpp @@ -762,7 +762,7 @@ void Tests::runSearchTestsV9(const string& modelFile, bool inputsNHWC, bool useN Logger logger(nullptr, logToStdout, logToStderr, logTime); int symmetry = 4; - NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,symmetry,inputsNHWC,useNHWC,useFP16,false,false); + NNEvaluator* nnEval = startNNEval(modelFile,logger,"",NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,symmetry,inputsNHWC,useNHWC,useFP16,false,false); runV9Positions(nnEval, logger); delete nnEval; diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index ef0c8d3ef..99e8fba40 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -998,7 +998,7 @@ Last moves C3 B4 C4 D4 D3 C2 B3 //============================================================================ - if(Board::MAX_LEN >= 37) + if constexpr(std::min(Board::MAX_LEN_X, Board::MAX_LEN_Y) >= 37) { const char* name = "Giant Sgf parse test"; string sgfStr = "(;GM[1]FF[4]CA[UTF-8]ST[2]RU[Chinese]SZ[37]KM[0.00];B[dd];W[Hd];B[HH];W[dH];B[dG];W[eG];B[eF];W[Gd];B[Ge];W[He];B[ee];W[GG];B[ss])"; diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index edc3b4d32..856a0546a 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -16,8 +16,8 @@ static NNEvaluator* startNNEval( const string& modelName = modelFile; vector gpuIdxByServerThread = {0}; int maxBatchSize = 16; - int nnXLen = NNPos::MAX_BOARD_LEN; - int nnYLen = NNPos::MAX_BOARD_LEN; + int nnXLen = NNPos::MAX_BOARD_LEN_X; + int nnYLen = NNPos::MAX_BOARD_LEN_Y; bool requireExactNNLen = false; int nnCacheSizePowerOfTwo = 16; int nnMutexPoolSizePowerOfTwo = 12; diff --git a/cpp/tests/tinymodel.cpp b/cpp/tests/tinymodel.cpp index 510953652..b34567df3 100644 --- a/cpp/tests/tinymodel.cpp +++ b/cpp/tests/tinymodel.cpp @@ -87,7 +87,7 @@ NNEvaluator* TinyModelTest::runTinyModelTest(const string& baseDir, Logger& logg const string expectedSha256 = ""; NNEvaluator* nnEval = Setup::initializeNNEvaluator( "tinyModel",tmpModelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,maxBatchSize,requireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,maxBatchSize,requireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); nnEval->setDoRandomize(false); @@ -247,7 +247,7 @@ NNEvaluator* TinyModelTest::runTinyModelTest(const string& baseDir, Logger& logg const string expectedSha256 = ""; NNEvaluator* nnEval = Setup::initializeNNEvaluator( "tinyModel",tmpModelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,maxBatchSize,requireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,maxBatchSize,requireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); nnEval->setDoRandomize(false); @@ -407,7 +407,7 @@ NNEvaluator* TinyModelTest::runTinyModelTest(const string& baseDir, Logger& logg const string expectedSha256 = ""; NNEvaluator* nnEval = Setup::initializeNNEvaluator( "tinyModel",tmpModelFile,expectedSha256,cfg,logger,rand,expectedConcurrentEvals, - NNPos::MAX_BOARD_LEN,NNPos::MAX_BOARD_LEN,maxBatchSize,requireExactNNLen,disableFP16, + NNPos::MAX_BOARD_LEN_X,NNPos::MAX_BOARD_LEN_Y,maxBatchSize,requireExactNNLen,disableFP16, Setup::SETUP_FOR_DISTRIBUTED ); nnEval->setDoRandomize(false); From 611b13ec07b1555d146807459c3d70f13c010215 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 29 Aug 2025 22:23:32 +0200 Subject: [PATCH 05/42] Use dedicated `visited_data` for processing visited locs instead of storing them in `State` --- cpp/game/board.cpp | 9 +++++++++ cpp/game/board.h | 15 +++++++++++---- cpp/game/dotsfield.cpp | 33 +++++++++++++++------------------ 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index d2cdbd809..5c15db506 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -153,6 +153,7 @@ Board::Board(const Board& other) { numWhiteCaptures = other.numWhiteCaptures; numLegalMoves = other.numLegalMoves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); + visited_data.resize(other.visited_data.size(), false); } void Board::init(const int xS, const int yS, const Rules& initRules) @@ -188,6 +189,8 @@ void Board::init(const int xS, const int yS, const Rules& initRules) chain_data.resize(MAX_ARR_SIZE); chain_head.resize(MAX_ARR_SIZE); next_in_chain.resize(MAX_ARR_SIZE); + } else { + visited_data.resize(getMaxArrSize(x_size, y_size), false); } Location::getAdjacentOffsets(adj_offsets, x_size, isDots()); @@ -2464,6 +2467,12 @@ void Board::checkConsistency() const { for(int i = 0; i < 8; i++) if(tmpAdjOffsets[i] != adj_offsets[i]) throw StringError(errLabel + "Corrupted adj_offsets array"); + + for (const bool visited : visited_data) { + if (visited) { + throw StringError("Visited data always should be invalidated"); + } + } } bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const { diff --git a/cpp/game/board.h b/cpp/game/board.h index 6b46396da..3956a3ceb 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -106,6 +106,10 @@ namespace Location //Simple ko rule only. //Does not enforce player turn order. +constexpr static int getMaxArrSize(const int x_size, const int y_size) { + return (x_size+1)*(y_size+2)+1; +} + struct Board { //Initialization------------------------------ @@ -131,7 +135,7 @@ struct Board static constexpr int DEFAULT_LEN_X = std::min(MAX_LEN_X,19); //Default x edge length for board if unspecified static constexpr int DEFAULT_LEN_Y = std::min(MAX_LEN_Y,19); //Default y edge length for board if unspecified static constexpr int MAX_PLAY_SIZE = MAX_LEN_X * MAX_LEN_Y; //Maximum number of playable spaces - static constexpr int MAX_ARR_SIZE = (MAX_LEN_X+1)*(MAX_LEN_Y+2)+1; //Maximum size of arrays needed + static constexpr int MAX_ARR_SIZE = getMaxArrSize(MAX_LEN_X, MAX_LEN_Y); //Maximum size of arrays needed //Location used to indicate an invalid spot on the board. static constexpr Loc NULL_LOC = 0; @@ -424,6 +428,7 @@ struct Board mutable std::vector closureOrInvalidateLocsBuffer = std::vector(); mutable std::vector territoryLocationsBuffer = std::vector(); mutable std::vector walkStack = std::vector(); + mutable std::vector visited_data = std::vector(); // Dots game functions [[nodiscard]] bool wouldBeCaptureDots(Loc loc, Player pla) const; @@ -444,11 +449,13 @@ struct Board void invalidateAdjacentEmptyTerritoryIfNeeded(Loc loc); void makeMoveAndCalculateCapturesAndBases(Player pla, Loc loc, bool isSuicideLegal, std::vector& captures, std::vector& bases) const; - void setVisited(Loc loc); - void clearVisited(Loc loc); - void clearVisited(const std::vector& locations); int calculateGroundingWhiteScore(Player pla, std::unordered_set& nonGroundedLocs) const; + bool isVisited(Loc loc) const; + void setVisited(Loc loc) const; + void clearVisited(Loc loc) const; + void clearVisited(const std::vector& locations) const; + void init(int xS, int yS, const Rules& initRules); int countHeuristicConnectionLibertiesX2(Loc loc, Player pla) const; bool isLibertyOf(Loc loc, Loc head) const; diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 4916e8c1c..3dc7cab71 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -8,12 +8,9 @@ using namespace std; static constexpr int PLACED_PLAYER_SHIFT = PLAYER_BITS_COUNT; static constexpr int EMPTY_TERRITORY_SHIFT = PLACED_PLAYER_SHIFT + PLAYER_BITS_COUNT; static constexpr int TERRITORY_FLAG_SHIFT = EMPTY_TERRITORY_SHIFT + PLAYER_BITS_COUNT; -static constexpr int VISITED_FLAG_SHIFT = TERRITORY_FLAG_SHIFT + 1; static constexpr State TERRITORY_FLAG = 1 << TERRITORY_FLAG_SHIFT; -static constexpr State VISITED_FLAG = static_cast(1 << VISITED_FLAG_SHIFT); static constexpr State INVALIDATE_TERRITORY_MASK = ~(ACTIVE_MASK | ACTIVE_MASK << EMPTY_TERRITORY_SHIFT); -static constexpr State INVALIDATE_VISITED_MASK = ~VISITED_FLAG; Loc Location::xm1y(Loc loc) { return loc - 1; @@ -91,10 +88,6 @@ Color getActiveColor(const State state) { return static_cast(state & ACTIVE_MASK); } -bool isVisited(const State s) { - return (s & VISITED_FLAG) == VISITED_FLAG; -} - bool isTerritory(const State s) { return (s & TERRITORY_FLAG) == TERRITORY_FLAG; } @@ -135,15 +128,19 @@ bool Board::isDots() const { return rules.isDots; } -void Board::setVisited(const Loc loc) { - colors[loc] = static_cast(colors[loc] | VISITED_FLAG); +bool Board::isVisited(const Loc loc) const { + return visited_data[loc]; +} + +void Board::setVisited(const Loc loc) const { + visited_data[loc] = true; } -void Board::clearVisited(const Loc loc) { - colors[loc] = static_cast(colors[loc] & INVALIDATE_VISITED_MASK); +void Board::clearVisited(const Loc loc) const { + visited_data[loc] = false; } -void Board::clearVisited(const vector& locations) { +void Board::clearVisited(const vector& locations) const { for (const Loc& loc : locations) { clearVisited(loc); } @@ -382,8 +379,9 @@ vector Board::ground(const Player pla, vector& emptyBaseInvali for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { const Loc loc = Location::getLoc(x, y, x_size); + if (isVisited(loc)) continue; - if (const State state = getState(loc); !isVisited(state) && isActive(state, pla)) { + if (const State state = getState(loc); isActive(state, pla)) { bool createRealBase = false; bool grounded = false; getTerritoryLocations(pla, loc, true, createRealBase, grounded); @@ -471,7 +469,7 @@ void Board::tryGetCounterClockwiseClosure(const Loc initialLoc, const Loc startL break; } - if(isVisited(state)) { + if(isVisited(loc)) { // Remove trailing dots Loc lastLoc; do { @@ -576,8 +574,8 @@ void Board::getTerritoryLocations(const Player pla, const Loc firstLoc, const bo FOREACHADJ( Loc adj = loc + ADJOFFSET; - state = getState(adj); - if (!isVisited(state)) { + if (!isVisited(adj)) { + state = getState(adj); const Color activeColor = getActiveColor(state); if (activeColor == C_WALL) { assert(grounding); @@ -696,8 +694,7 @@ void Board::invalidateAdjacentEmptyTerritoryIfNeeded(const Loc loc) { FOREACHADJ( Loc adj = lastLoc + ADJOFFSET; - State state = getState(adj); - if (getEmptyTerritoryColor(state) != C_EMPTY && !isVisited(state)) { + if (!isVisited(adj) && getEmptyTerritoryColor(getState(adj)) != C_EMPTY) { closureOrInvalidateLocsBuffer.push_back(adj); setState(adj, C_EMPTY); setVisited(adj); From ca35e329894fe38371aa546cd5b25dbc7aabc9f4 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 13 Dec 2025 20:58:25 +0100 Subject: [PATCH 06/42] Implement iterative grounding detection and score calculating in case of grounding --- cpp/game/board.cpp | 28 +- cpp/game/board.h | 34 ++- cpp/game/boardhistory.cpp | 30 +- cpp/game/boardhistory.h | 1 + cpp/game/common.h | 3 + cpp/game/dotsfield.cpp | 315 ++++++++++++-------- cpp/neuralnet/nninputs.cpp | 2 + cpp/neuralnet/nninputsdots.cpp | 3 - cpp/tests/testdotsbasic.cpp | 516 ++++++++++++++++++++++++++++++--- cpp/tests/testdotsutils.h | 8 + 10 files changed, 746 insertions(+), 194 deletions(-) diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 5c15db506..52f04769d 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -111,7 +111,7 @@ bool Location::isNearCentral(Loc loc, int x_size, int y_size) { Board::Base::Base(Player newPla, const std::vector& rollbackLocations, const std::vector& rollbackStates, - bool isReal + const bool isReal ) { pla = newPla; rollback_locations = rollbackLocations; @@ -151,6 +151,8 @@ Board::Board(const Board& other) { pos_hash = other.pos_hash; numBlackCaptures = other.numBlackCaptures; numWhiteCaptures = other.numWhiteCaptures; + blackScoreIfWhiteGrounds = other.blackScoreIfWhiteGrounds; + whiteScoreIfBlackGrounds = other.whiteScoreIfBlackGrounds; numLegalMoves = other.numLegalMoves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); visited_data.resize(other.visited_data.size(), false); @@ -166,8 +168,12 @@ void Board::init(const int xS, const int yS, const Rules& initRules) y_size = yS; rules = initRules; - for(int i = 0; i < MAX_ARR_SIZE; i++) + for(int i = 0; i < MAX_ARR_SIZE; i++) { colors[i] = C_WALL; + if (rules.isDots) { + setGrounded(i); + } + } for(int y = 0; y < y_size; y++) { @@ -183,6 +189,8 @@ void Board::init(const int xS, const int yS, const Rules& initRules) pos_hash = ZOBRIST_SIZE_X_HASH[x_size] ^ ZOBRIST_SIZE_Y_HASH[y_size]; numBlackCaptures = 0; numWhiteCaptures = 0; + blackScoreIfWhiteGrounds = 0; + whiteScoreIfBlackGrounds = 0; numLegalMoves = xS * yS; if (!rules.isDots) { @@ -2488,10 +2496,14 @@ bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool ch return false; if(checkSimpleKo && ko_loc != other.ko_loc) return false; - if(checkNumCaptures && numBlackCaptures != other.numBlackCaptures) - return false; - if(checkNumCaptures && numWhiteCaptures != other.numWhiteCaptures) + if(checkNumCaptures && ( + numBlackCaptures != other.numBlackCaptures || + numWhiteCaptures != other.numWhiteCaptures || + blackScoreIfWhiteGrounds != other.blackScoreIfWhiteGrounds || + whiteScoreIfBlackGrounds != other.whiteScoreIfBlackGrounds + )) { return false; + } if(pos_hash != other.pos_hash) return false; for(int i = 0; i(); board.numWhiteCaptures = data["numWhiteCaptures"].get(); + board.blackScoreIfWhiteGrounds = data.value(BLACK_SCORE_IF_WHITE_GROUNDS_KEY, 0); + board.whiteScoreIfBlackGrounds = data.value(WHITE_SCORE_IF_BLACK_GROUNDS_KEY, 0); return board; } diff --git a/cpp/game/board.h b/cpp/game/board.h index 3956a3ceb..1b386f676 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -47,6 +47,8 @@ Color getPlacedDotColor(State s); Color getEmptyTerritoryColor(State s); +bool isGrounded(State state); + //Conversions for players and colors namespace PlayerIO { char colorToChar(Color c); @@ -182,10 +184,10 @@ struct Board /* }; */ struct Base { - Player pla{}; - bool is_real{}; std::vector rollback_locations; std::vector rollback_states; + Player pla{}; + bool is_real{}; Base() = default; Base(Player newPla, const std::vector& rollbackLocations, const std::vector& rollbackStates, bool isReal); @@ -202,6 +204,7 @@ struct Board State previousState; std::vector bases; std::vector emptyBaseInvalidateLocations; + std::vector groundingLocations; MoveRecord() = default; @@ -215,11 +218,12 @@ struct Board // Constructor for Dots game MoveRecord( - Loc initLoc, - Player initPla, - State initPreviousState, - const std::vector& initBases, - const std::vector& initEmptyBaseInvalidateLocations + Loc newLoc, + Player newPla, + State newPreviousState, + const std::vector& newBases, + const std::vector& newEmptyBaseInvalidateLocations, + const std::vector& newGroundingLocations ); }; @@ -410,6 +414,10 @@ struct Board int numBlackCaptures; //Number of b stones captured, informational and used by board history when clearing pos int numWhiteCaptures; //Number of w stones captured, informational and used by board history when clearing pos + // Useful for fast calculation of the game result and finishing the game + int blackScoreIfWhiteGrounds; + int whiteScoreIfBlackGrounds; + // Offsets to add to get clockwise traverse short adj_offsets[8]; @@ -436,21 +444,23 @@ struct Board MoveRecord playMoveAssumeLegalDots(Loc loc, Player pla); MoveRecord tryPlayMove(Loc loc, Player pla, bool isSuicideLegal); void undoDots(MoveRecord& moveRecord); - Base captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp); - std::vector tryCapture(Loc loc, Player pla, bool emptyBaseCapturing); + std::vector fillGrounding(Loc loc); + Base captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp, bool& isGrounded); + std::vector tryCapture(Loc loc, Player pla, bool emptyBaseCapturing, bool& atLeastOneRealBaseIsGrounded); std::vector ground(Player pla, std::vector& emptyBaseInvalidatePositions); void getUnconnectedLocations(Loc loc, Player pla) const; void checkAndAddUnconnectedLocation(Player checkPla,Player currentPla,Loc addLoc1,Loc addLoc2) const; - void tryGetCounterClockwiseClosure(Loc initialLoc, Loc startLoc, Player pla); + void tryGetCounterClockwiseClosure(Loc initialLoc, Loc startLoc, Player pla) const; Base buildBase(const std::vector& closure, Player pla); - void getTerritoryLocations(Player pla, Loc firstLoc, bool grounding, bool& createRealBase, bool& grounded); + void getTerritoryLocations(Player pla, Loc firstLoc, bool grounding, bool& createRealBase) const; Base createBaseAndUpdateStates(Player basePla, bool isReal); void updateScoreAndHashForTerritory(Loc loc, State state, Player basePla, bool rollback); void invalidateAdjacentEmptyTerritoryIfNeeded(Loc loc); void makeMoveAndCalculateCapturesAndBases(Player pla, Loc loc, bool isSuicideLegal, std::vector& captures, std::vector& bases) const; - int calculateGroundingWhiteScore(Player pla, std::unordered_set& nonGroundedLocs) const; + void setGrounded(Loc loc); + void clearGrounded(Loc loc); bool isVisited(Loc loc) const; void setVisited(Loc loc) const; void clearVisited(Loc loc) const; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 24be1f053..ae6fdd4c4 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -739,7 +739,9 @@ void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ whiteBonusScore += (presumedNextMovePla == P_WHITE ? 0.5f : -0.5f); } - setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + if (!rules.isDots || !isGameFinished) { + setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); + } isScored = true; isNoResult = false; isResignation = false; @@ -752,6 +754,22 @@ void BoardHistory::endAndScoreGameNow(const Board& board) { endAndScoreGameNow(board,area); } +bool BoardHistory::isGroundingWinsGame(const Board& board, const Player pla, float& whiteScore) const { + assert(rules.isDots); + + if (pla == P_BLACK && board.whiteScoreIfBlackGrounds + rules.komi < 0) { + whiteScore = board.whiteScoreIfBlackGrounds + rules.komi; + return true; + } + + if (pla == P_WHITE && board.blackScoreIfWhiteGrounds - rules.komi < 0) { + whiteScore = -board.blackScoreIfWhiteGrounds + rules.komi; + return true; + } + + return false; +} + void BoardHistory::endGameIfAllPassAlive(const Board& board) { assert(rules.isDots == board.isDots()); @@ -764,15 +782,7 @@ void BoardHistory::endGameIfAllPassAlive(const Board& board) { gameOver = true; normalizedWhiteScoreIfGroundingAlive = static_cast(board.numBlackCaptures - board.numWhiteCaptures) + rules.komi; } else { - Board::MoveRecord moveRecord = const_cast(board).playMoveRecorded(Board::PASS_LOC, presumedNextMovePla); - const float whiteScoreAfterNextPlaGrounding = static_cast(board.numBlackCaptures - board.numWhiteCaptures) + rules.komi; - const_cast(board).undo(moveRecord); - - if (presumedNextMovePla == P_BLACK && whiteScoreAfterNextPlaGrounding < 0.0f || - presumedNextMovePla == P_WHITE && whiteScoreAfterNextPlaGrounding > 0.0f) { - gameOver = true; - normalizedWhiteScoreIfGroundingAlive = whiteScoreAfterNextPlaGrounding; - } + gameOver = isGroundingWinsGame(board, presumedNextMovePla, normalizedWhiteScoreIfGroundingAlive); } if (gameOver) { diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index f351f18d8..6dfbf4546 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -172,6 +172,7 @@ struct BoardHistory { void endGameIfAllPassAlive(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); + bool isGroundingWinsGame(const Board& board, Player pla, float& whiteScore) const; void endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]); void getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; diff --git a/cpp/game/common.h b/cpp/game/common.h index 18cc632f2..cde12ddea 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -9,6 +9,9 @@ const std::string DOTS_CAPTURE_EMPTY_BASES_KEY = "dotsCaptureEmptyBases"; const std::string START_POS_KEY = "startPos"; const std::string START_POSES_KEY = "startPoses"; +const std::string BLACK_SCORE_IF_WHITE_GROUNDS_KEY = "blackScoreIfWhiteGrounds"; +const std::string WHITE_SCORE_IF_BLACK_GROUNDS_KEY = "whiteScoreIfBlackGrounds"; + // Player typedef int8_t Player; static constexpr Player P_BLACK = 1; diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 3dc7cab71..d844b5025 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -8,8 +8,10 @@ using namespace std; static constexpr int PLACED_PLAYER_SHIFT = PLAYER_BITS_COUNT; static constexpr int EMPTY_TERRITORY_SHIFT = PLACED_PLAYER_SHIFT + PLAYER_BITS_COUNT; static constexpr int TERRITORY_FLAG_SHIFT = EMPTY_TERRITORY_SHIFT + PLAYER_BITS_COUNT; +static constexpr int GROUNDED_FLAG_SHIFT = TERRITORY_FLAG_SHIFT + 1; static constexpr State TERRITORY_FLAG = 1 << TERRITORY_FLAG_SHIFT; +static constexpr State GROUNDED_FLAG = static_cast(1 << GROUNDED_FLAG_SHIFT); static constexpr State INVALIDATE_TERRITORY_MASK = ~(ACTIVE_MASK | ACTIVE_MASK << EMPTY_TERRITORY_SHIFT); Loc Location::xm1y(Loc loc) { @@ -128,6 +130,24 @@ bool Board::isDots() const { return rules.isDots; } +bool isGrounded(const State state) { + return (state & GROUNDED_FLAG) == GROUNDED_FLAG; +} + +bool isGroundedOrWall(const State state, const Player pla) { + // Use bit tricks for grounding detecting. + // If the active player is C_WALL, then the result is also true. + return (state & GROUNDED_FLAG) == GROUNDED_FLAG && (state & pla) == pla; +} + +void Board::setGrounded(const Loc loc) { + colors[loc] = static_cast(colors[loc] | GROUNDED_FLAG); +} + +void Board::clearGrounded(const Loc loc) { + colors[loc] = static_cast(colors[loc] & ~GROUNDED_FLAG); +} + bool Board::isVisited(const Loc loc) const { return visited_data[loc]; } @@ -147,56 +167,39 @@ void Board::clearVisited(const vector& locations) const { } int Board::calculateGroundingWhiteScore(Color* result) const { - auto nonGroundedLocs = unordered_set(); - - const int whiteScoreAfterBlackGrounding = calculateGroundingWhiteScore(P_BLACK, nonGroundedLocs); - const int whiteScoreAfterWhiteGrounding = calculateGroundingWhiteScore(P_WHITE, nonGroundedLocs); - for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { const Loc loc = Location::getLoc(x, y, x_size); - if (const Color color = getColor(loc); color == C_EMPTY || nonGroundedLocs.count(loc) > 0) { - result[loc] = C_EMPTY; + if (const State state = getState(loc); isGrounded(state)) { + const Color activeColor = getActiveColor(state); + assert(activeColor != C_EMPTY); + result[loc] = activeColor; } else { - result[loc] = color; // Fill only grounded locs + result[loc] = C_EMPTY; } } } - return whiteScoreAfterBlackGrounding + whiteScoreAfterWhiteGrounding; -} - -int Board::calculateGroundingWhiteScore(Player pla, unordered_set& nonGroundedLocs) const { - auto emptyBaseInvalidateLocations = vector(); - const auto bases = const_cast(this)->ground(pla, emptyBaseInvalidateLocations); - auto moveRecord = MoveRecord(PASS_LOC, pla, getState(PASS_LOC), bases, emptyBaseInvalidateLocations); - for (Base& base : moveRecord.bases) { - for (Loc& loc : base.rollback_locations) { - nonGroundedLocs.insert(loc); - } - } - - const int whiteScore = numBlackCaptures - numWhiteCaptures; - - const_cast(this)->undoDots(moveRecord); - return whiteScore; + return whiteScoreIfBlackGrounds - blackScoreIfWhiteGrounds; } Board::MoveRecord::MoveRecord( - const Loc initLoc, - const Player initPla, - const State initPreviousState, - const vector& initBases, - const vector& initEmptyBaseInvalidateLocations + const Loc newLoc, + const Player newPla, + const State newPreviousState, + const vector& newBases, + const vector& newEmptyBaseInvalidateLocations, + const vector& newGroundingLocations ) { ko_loc = NULL_LOC; capDirs = 0; - loc = initLoc; - pla = initPla; - previousState = initPreviousState; - bases = initBases; - emptyBaseInvalidateLocations = initEmptyBaseInvalidateLocations; + loc = newLoc; + pla = newPla; + previousState = newPreviousState; + bases = newBases; + emptyBaseInvalidateLocations = newEmptyBaseInvalidateLocations; + groundingLocations = newGroundingLocations; } bool Board::isSuicideDots(const Loc loc, const Player pla) const { @@ -239,6 +242,7 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe vector bases; vector initEmptyBaseInvalidateLocations; + vector newGroundingLocations; if (loc == PASS_LOC) { initEmptyBaseInvalidateLocations = vector(); @@ -249,13 +253,14 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe pos_hash ^= hashValue; numLegalMoves--; - bases = tryCapture(loc, pla, false); + bool atLeastOneRealBaseIsGrounded = false; + bases = tryCapture(loc, pla, false, atLeastOneRealBaseIsGrounded); const Color opp = getOpp(pla); if (bases.empty()) { if (getEmptyTerritoryColor(originalState) == opp) { if (isSuicideLegal) { - bases.push_back(captureWhenEmptyTerritoryBecomesRealBase(loc, opp)); + bases.push_back(captureWhenEmptyTerritoryBecomesRealBase(loc, opp, atLeastOneRealBaseIsGrounded)); } else { colors[loc] = originalState; pos_hash ^= hashValue; @@ -269,12 +274,59 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe initEmptyBaseInvalidateLocations = vector(closureOrInvalidateLocsBuffer); } } + + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else if (pla == P_WHITE) { + blackScoreIfWhiteGrounds++; + } + + if (atLeastOneRealBaseIsGrounded) { + newGroundingLocations = fillGrounding(loc); + } else if( + const Player locActivePlayer = getColor(loc); // Can't use pla because of a possible suicidal move + isGroundedOrWall(getState(Location::xm1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xym1(loc, x_size)), locActivePlayer) || + isGroundedOrWall(getState(Location::xp1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xyp1(loc, x_size)), locActivePlayer) + ) { + newGroundingLocations = fillGrounding(loc); + } } - return {loc, pla, originalState, bases, initEmptyBaseInvalidateLocations}; + return {loc, pla, originalState, bases, initEmptyBaseInvalidateLocations, newGroundingLocations}; } void Board::undoDots(MoveRecord& moveRecord) { + const bool isGroundingMove = moveRecord.loc == PASS_LOC; + + for (const Loc& loc : moveRecord.groundingLocations) { + const State state = getState(loc); + const Player mainPla = getActiveColor(state); + if (getPlacedDotColor(state) != C_EMPTY) { + if (mainPla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else { + blackScoreIfWhiteGrounds++; + } + } + clearGrounded(loc); + } + + if (!isGroundingMove) { + if (moveRecord.pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else if (moveRecord.pla == P_WHITE) { + blackScoreIfWhiteGrounds--; + } + } + + const Player emptyTerritoryPlayer = isGroundingMove ? moveRecord.pla : getOpp(moveRecord.pla); + for (const Loc& loc : moveRecord.emptyBaseInvalidateLocations) { + assert(0 == getState(loc)); + setState(loc, static_cast(emptyTerritoryPlayer << EMPTY_TERRITORY_SHIFT)); + } + for (auto it = moveRecord.bases.rbegin(); it != moveRecord.bases.rend(); ++it) { for (size_t index = 0; index < it->rollback_locations.size(); index++) { const State rollbackState = it->rollback_states[index]; @@ -286,21 +338,52 @@ void Board::undoDots(MoveRecord& moveRecord) { } } - const bool isGrounding = moveRecord.loc == PASS_LOC; - - const Player emptyTerritoryPlayer = isGrounding ? moveRecord.pla : getOpp(moveRecord.pla); - for (const Loc& loc : moveRecord.emptyBaseInvalidateLocations) { - setState(loc, static_cast(emptyTerritoryPlayer << EMPTY_TERRITORY_SHIFT)); - } - - if (!isGrounding) { + if (!isGroundingMove) { setState(moveRecord.loc, moveRecord.previousState); pos_hash ^= ZOBRIST_BOARD_HASH[moveRecord.loc][moveRecord.pla]; numLegalMoves++; } } -Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, const Player opp) { +vector Board::fillGrounding(const Loc loc) { + vector groundedLocs; + + walkStack.clear(); + walkStack.push_back(loc); + const Player pla = getColor(loc); + assert(pla != C_EMPTY && pla != C_WALL); + setGrounded(loc); + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else { + blackScoreIfWhiteGrounds--; + } + groundedLocs.push_back(loc); + + while (!walkStack.empty()) { + const Loc currentLoc = walkStack.back(); + walkStack.pop_back(); + + forEachAdjacent(currentLoc, [&](const Loc adj) { + if (const State state = getState(adj); !isGrounded(state) && isActive(state, pla)) { + setGrounded(adj); + if (const Player placedColor = getPlacedDotColor(state); placedColor != C_EMPTY) { + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds--; + } else { + blackScoreIfWhiteGrounds--; + } + } + groundedLocs.push_back(adj); + walkStack.push_back(adj); + } + }); + } + + return groundedLocs; +} + +Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, const Player opp, bool& isGrounded) { Loc loc = initLoc; // Searching for an opponent dot that makes a closure that contains the `initialPosition`. @@ -311,7 +394,7 @@ Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, c // Try to peek an active opposite player dot if (getColor(loc) != opp) continue; - vector oppBases = tryCapture(loc, opp, true); + vector oppBases = tryCapture(loc, opp, true, isGrounded); // The found base always should be real and include the `iniLoc` for (const Base& oppBase : oppBases) { if (oppBase.is_real) { @@ -324,7 +407,7 @@ Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, c return {}; } -vector Board::tryCapture(const Loc loc, const Player pla, const bool emptyBaseCapturing) { +vector Board::tryCapture(const Loc loc, const Player pla, const bool emptyBaseCapturing, bool& atLeastOneRealBaseIsGrounded) { getUnconnectedLocations(loc, pla); auto currentClosures = vector>(); @@ -332,7 +415,7 @@ vector Board::tryCapture(const Loc loc, const Player pla, const boo unconnectedLocationsBufferSize < minNumberOfConnections) return {}; for (int index = 0; index < unconnectedLocationsBufferSize; index++) { - Loc unconnectedLoc = unconnectedLocationsBuffer[index]; + const Loc unconnectedLoc = unconnectedLocationsBuffer[index]; // Optimization: it doesn't make sense to check the latest unconnected dot // when all previous connections form minimal bases @@ -365,49 +448,50 @@ vector Board::tryCapture(const Loc loc, const Player pla, const boo } auto resultBases = vector(); + atLeastOneRealBaseIsGrounded = false; for (const vector& currentClosure: currentClosures) { - resultBases.push_back(buildBase(currentClosure, pla)); + Base base = buildBase(currentClosure, pla); + resultBases.push_back(base); + + if (!atLeastOneRealBaseIsGrounded && base.is_real) { + for (const Loc& closureLoc : currentClosure) { + forEachAdjacent(closureLoc, [&](const Loc adj) { + atLeastOneRealBaseIsGrounded = atLeastOneRealBaseIsGrounded || isGroundedOrWall(getState(adj), base.pla); + }); + if (atLeastOneRealBaseIsGrounded) { + break; + } + } + } } + return std::move(resultBases); } vector Board::ground(const Player pla, vector& emptyBaseInvalidatePositions) { - auto processedLocs = vector(); const Color opp = getOpp(pla); auto resultBases = vector(); for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { const Loc loc = Location::getLoc(x, y, x_size); - if (isVisited(loc)) continue; - - if (const State state = getState(loc); isActive(state, pla)) { + if (const State state = getState(loc); !isGrounded(state) && isActive(state, pla)) { bool createRealBase = false; - bool grounded = false; - getTerritoryLocations(pla, loc, true, createRealBase, grounded); + getTerritoryLocations(pla, loc, true, createRealBase); assert(createRealBase); - if (!grounded) { - for (const Loc& territoryLoc : territoryLocationsBuffer) { - invalidateAdjacentEmptyTerritoryIfNeeded(territoryLoc); - for (const Loc& invalidateLoc : closureOrInvalidateLocsBuffer) { - emptyBaseInvalidatePositions.push_back(invalidateLoc); - } + for (const Loc& territoryLoc : territoryLocationsBuffer) { + invalidateAdjacentEmptyTerritoryIfNeeded(territoryLoc); + for (const Loc& invalidateLoc : closureOrInvalidateLocsBuffer) { + emptyBaseInvalidatePositions.push_back(invalidateLoc); } - - resultBases.push_back(createBaseAndUpdateStates(opp, createRealBase)); } - for (const Loc& territoryLoc : territoryLocationsBuffer) { - processedLocs.push_back(territoryLoc); - setVisited(territoryLoc); - } + resultBases.push_back(createBaseAndUpdateStates(opp, true)); } } } - clearVisited(processedLocs); - return resultBases; } @@ -434,7 +518,7 @@ void Board::checkAndAddUnconnectedLocation(const Player checkPla, const Player c } } -void Board::tryGetCounterClockwiseClosure(const Loc initialLoc, const Loc startLoc, const Player pla) { +void Board::tryGetCounterClockwiseClosure(const Loc initialLoc, const Loc startLoc, const Player pla) const { closureOrInvalidateLocsBuffer.clear(); closureOrInvalidateLocsBuffer.push_back(initialLoc); setVisited(initialLoc); @@ -530,74 +614,69 @@ Board::Base Board::buildBase(const vector& closure, const Player pla) { const Loc territoryFirstLoc = Location::getNextLocCW(closure.at(1), closure.at(0), x_size); bool createRealBase; - bool grounded = false; - getTerritoryLocations(pla, territoryFirstLoc, false, createRealBase, grounded); - assert(!grounded); + getTerritoryLocations(pla, territoryFirstLoc, false, createRealBase); clearVisited(closure); return createBaseAndUpdateStates(pla, createRealBase); } -void Board::getTerritoryLocations(const Player pla, const Loc firstLoc, const bool grounding, bool& createRealBase, bool& grounded) { +void Board::getTerritoryLocations(const Player pla, const Loc firstLoc, const bool grounding, bool& createRealBase) const { walkStack.clear(); territoryLocationsBuffer.clear(); createRealBase = grounding ? false : rules.dotsCaptureEmptyBases; - grounded = false; const Player opp = getOpp(pla); State state = getState(firstLoc); - if (const Color activeColor = getActiveColor(state); activeColor == C_WALL) { - assert(grounding); - grounded = true; - } else { - bool legalLoc = false; - if (grounding) { - createRealBase = true; - legalLoc = activeColor == pla; - } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory - createRealBase = createRealBase || isPlaced(state, opp); - legalLoc = true; // If no grounding, empty locations can be handled as well - } + Color activeColor = getActiveColor(state); + assert(activeColor != C_WALL); + + bool legalLoc = false; + if (grounding) { + createRealBase = true; + // In a rare case it's possible to encounter an empty ungrounded loc that should be de-facto grounded. + // However, currently it's to set up its grounding due to limitations of the grounding algorithm that doesn't traverse diagonals. + // That's why we have to check adj locs on grounding to prevent adding incorrect locs and causing out of bounds exception. + legalLoc = activeColor == pla && !isGroundedOrWall(state, pla); + } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + legalLoc = true; // If no grounding, empty locations can be handled as well + } - if (legalLoc) { - territoryLocationsBuffer.push_back(firstLoc); - setVisited(firstLoc); - walkStack.push_back(firstLoc); - } + if (legalLoc) { + territoryLocationsBuffer.push_back(firstLoc); + setVisited(firstLoc); + walkStack.push_back(firstLoc); } while (!walkStack.empty()) { const Loc loc = walkStack.back(); walkStack.pop_back(); - FOREACHADJ( - Loc adj = loc + ADJOFFSET; + forEachAdjacent(loc, [&](const Loc adj) { + if (isVisited(adj)) return; - if (!isVisited(adj)) { - state = getState(adj); - const Color activeColor = getActiveColor(state); - if (activeColor == C_WALL) { - assert(grounding); - grounded = true; - } else { - bool isAdjLegal = false; - if (grounding) { - createRealBase = true; - isAdjLegal = activeColor == pla; - } else if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory - createRealBase = createRealBase || isPlaced(state, opp); - isAdjLegal = true; // If no grounding, empty locations can be handled as well - } + state = getState(adj); + activeColor = getActiveColor(state); - if (isAdjLegal) { - territoryLocationsBuffer.push_back(adj); - setVisited(adj); - walkStack.push_back(adj); - } + bool isAdjLegal = false; + if (grounding) { + createRealBase = true; + isAdjLegal = activeColor == pla && !isGroundedOrWall(state, pla); + } else { + assert(activeColor != C_WALL); + if (activeColor != pla || !isTerritory(state)) { // Ignore already captured territory + createRealBase = createRealBase || isPlaced(state, opp); + isAdjLegal = true; // If no grounding, empty locations can be handled as well } } - ) + + if (isAdjLegal) { + territoryLocationsBuffer.push_back(adj); + setVisited(adj); + walkStack.push_back(adj); + } + }); } clearVisited(territoryLocationsBuffer); @@ -650,7 +729,7 @@ void Board::updateScoreAndHashForTerritory(const Loc loc, const State state, con numBlackCaptures--; } } - } else if (isPlaced(state, basePla) && isActive(state, baseOppPla) && rules.dotsFreeCapturedDots) { + } else if (isPlaced(state, basePla) && isActive(state, baseOppPla)) { // No diff for the territory of the current player if (basePla == P_BLACK) { if (!rollback) { diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 9f284b6ed..e9a278d88 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -720,6 +720,8 @@ Board SymmetryHelpers::getSymBoard(const Board& board, int symmetry) { symBoard.numBlackCaptures = board.numBlackCaptures; symBoard.numWhiteCaptures = board.numWhiteCaptures; symBoard.numLegalMoves = board.numLegalMoves; + symBoard.blackScoreIfWhiteGrounds = board.blackScoreIfWhiteGrounds; + symBoard.whiteScoreIfBlackGrounds = board.whiteScoreIfBlackGrounds; } return symBoard; } diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index e6b22059f..2fc7b2474 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -37,9 +37,6 @@ void NNInputs::fillRowVDots( Color grounding[Board::MAX_ARR_SIZE]; board.calculateGroundingWhiteScore(grounding); - auto boardString = board.toString(); - (void)boardString; - for(int y = 0; y& moveRecords) { + Board boardCopy(initialBoard); + BoardHistory boardHistory(boardCopy, P_BLACK, boardCopy.rules, 0); + for (const Board::MoveRecord& moveRecord : moveRecords) { + boardHistory.makeBoardMoveAssumeLegal(boardCopy, moveRecord.loc, moveRecord.pla, nullptr); + } + std::ostringstream sgfStringStream; + WriteSgf::writeSgf(sgfStringStream, "blue", "red", boardHistory, {}); + return sgfStringStream.str(); +} +/** + * Calculates the grounding and result captures without using the grounding flag and incremental calculations. + * It's used for testing to verify incremental grounding algorithms. + */ +void validateGrounding( + const Board& boardBeforeGrounding, + const Board& boardAfterGrounding, + const Player pla, + const vector& moveRecords) { + unordered_set visited_locs; + assert(pla == P_BLACK || pla == P_WHITE); + + int expectedNumBlackCaptures = 0; + int expectedNumWhiteCaptures = 0; + const Player opp = getOpp(pla); + for (int y = 0; y < boardBeforeGrounding.y_size; y++) { + for (int x = 0; x < boardBeforeGrounding.x_size; x++) { + Loc loc = Location::getLoc(x, y, boardBeforeGrounding.x_size); + const State state = boardBeforeGrounding.getState(Location::getLoc(x, y, boardBeforeGrounding.x_size)); + + if (const Color activeColor = getActiveColor(state); activeColor == pla) { + if (visited_locs.count(loc) > 0) + continue; + + bool grounded = false; + + vector walkStack; + vector baseLocs; + walkStack.push_back(loc); + + // Find active territory and calculate its grounding state. + while (!walkStack.empty()) { + Loc curLoc = walkStack.back(); + walkStack.pop_back(); + + if (const Color curActiveColor = getActiveColor(boardBeforeGrounding.getState(curLoc)); curActiveColor == pla) { + if (visited_locs.count(curLoc) == 0) { + visited_locs.insert(curLoc); + baseLocs.push_back(curLoc); + boardBeforeGrounding.forEachAdjacent(curLoc, [&](const Loc& adjLoc) { + walkStack.push_back(adjLoc); + }); + } + } else if (curActiveColor == C_WALL) { + grounded = true; + } + } + + for (const Loc& baseLoc : baseLocs) { + const Color placedDotColor = getPlacedDotColor(boardBeforeGrounding.getState(baseLoc)); + + if (!grounded) { + // If the territory is not grounded, it becomes dead. + // Freed dots don't count because they don't add a value to the opp score (assume they just become be placed). + if (placedDotColor == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } else { + State baseLocState = boardAfterGrounding.getState(baseLoc); + // This check on placed dot color is redundant. + // However, currently it's not possible to always ground empty locs in some rare cases due to limitations of incremental grounding algorithm. + // Fortunately, they don't affect the resulting score. + if (!isGrounded(baseLocState) && getPlacedDotColor(baseLocState) != C_EMPTY) { + Global::fatalError("Loc (" + to_string(Location::getX(baseLoc, boardBeforeGrounding.x_size)) + "; " + + to_string(Location::getY(baseLoc, boardBeforeGrounding.x_size)) + ") " + + " should be grounded. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } + + // If the territory is grounded, count dead dots of the opp player. + if (placedDotColor == opp) { + if (pla == P_BLACK) { + expectedNumWhiteCaptures++; + } else { + expectedNumBlackCaptures++; + } + } + } + } + } else if (activeColor == opp) { // In the case of opp active color, counts only captured dots + if (getPlacedDotColor(state) == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } + } + } + + if (expectedNumBlackCaptures != boardAfterGrounding.numBlackCaptures || expectedNumWhiteCaptures != boardAfterGrounding.numWhiteCaptures) { + Global::fatalError("expectedNumBlackCaptures (" + to_string(expectedNumBlackCaptures) + ")" + + " == board.numBlackCaptures (" + to_string(boardAfterGrounding.numBlackCaptures) + ")" + + " && expectedNumWhiteCaptures (" + to_string(expectedNumWhiteCaptures) + ")" + + " == board.numWhiteCaptures (" + to_string(boardAfterGrounding.numWhiteCaptures) + ")" + + " check is failed. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } +} + +void runDotsStressTestsInternal( + int x_size, + int y_size, + int gamesCount, + bool dotsGame, + int startPos, + bool dotsCaptureEmptyBase, + float komi, + bool suicideAllowed, + float groundingStartCoef, + float groundingEndCoef, + bool performExtraChecks + ) { + assert(groundingStartCoef >= 0 && groundingStartCoef <= 1); + assert(groundingEndCoef >= 0 && groundingEndCoef <= 1); + assert(groundingEndCoef >= groundingStartCoef); -void runDotsStressTestsInternal(int x_size, int y_size, int gamesCount, float groundingAfterCoef, float groundingProb, float komi, bool suicideAllowed, bool checkRollback) { - // TODO: add tests with grounding cout << " Random games" << endl; - cout << " Check rollback: " << boolalpha << checkRollback << endl; + cout << " Game type: " << (dotsGame ? "Dots" : "Go") << endl; + cout << " Start position: " << Rules::writeStartPosRule(startPos) << endl; + if (dotsGame) { + cout << " Capture empty bases: " << boolalpha << dotsCaptureEmptyBase << endl; + } + cout << " Extra checks: " << boolalpha << performExtraChecks << endl; #ifdef NDEBUG cout << " Build: Release" << endl; #else @@ -491,47 +880,79 @@ void runDotsStressTestsInternal(int x_size, int y_size, int gamesCount, float gr Rand rand("runDotsStressTests"); - Rules rules = Rules(false); - Board initialBoard = Board(x_size, y_size, rules); - - int tryGroundingAfterMove = groundingAfterCoef * initialBoard.numLegalMoves; + Rules rules = dotsGame ? Rules(dotsGame, startPos, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); + auto initialBoard = Board(x_size, y_size, rules); vector randomMoves = vector(); randomMoves.reserve(initialBoard.numLegalMoves); for(int y = 0; y < initialBoard.y_size; y++) { for(int x = 0; x < initialBoard.x_size; x++) { - randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); + Loc loc = Location::getLoc(x, y, initialBoard.x_size); + if (initialBoard.getColor(loc) == C_EMPTY) { // Filter out initial poses + randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); + } } } + assert(randomMoves.size() == initialBoard.numLegalMoves); + int movesCount = 0; int blackWinsCount = 0; int whiteWinsCount = 0; int drawsCount = 0; + int groundingCount = 0; auto moveRecords = vector(); for (int n = 0; n < gamesCount; n++) { rand.shuffle(randomMoves); + moveRecords.clear(); auto board = Board(initialBoard.x_size, initialBoard.y_size, rules); + Loc lastLoc = Board::NULL_LOC; + + int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * initialBoard.numLegalMoves; Player pla = P_BLACK; - for (Loc loc : randomMoves) { - if (board.isLegal(loc, pla, suicideAllowed, false)) { - Board::MoveRecord moveRecord = board.playMoveRecorded(loc, pla); + for(size_t index = 0; index < randomMoves.size(); index++) { + lastLoc = moveRecords.size() >= tryGroundingAfterMove ? Board::PASS_LOC : randomMoves[index]; + + if (board.isLegal(lastLoc, pla, suicideAllowed, false)) { + Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); movesCount++; - if (checkRollback) { - moveRecords.push_back(moveRecord); - } + moveRecords.push_back(moveRecord); pla = getOpp(pla); } + + if (lastLoc == Board::PASS_LOC) { + groundingCount++; + int scoreDiff; + int oppScoreIfGrounding; + Player lastPla = moveRecords.back().pla; + if (lastPla == P_BLACK) { + scoreDiff = board.numBlackCaptures - board.numWhiteCaptures; + oppScoreIfGrounding = board.whiteScoreIfBlackGrounds; + } else { + scoreDiff = board.numWhiteCaptures - board.numBlackCaptures; + oppScoreIfGrounding = board.blackScoreIfWhiteGrounds; + } + if (scoreDiff != oppScoreIfGrounding) { + Global::fatalError("scoreDiff (" + to_string(scoreDiff) + ") == oppScoreIfGrounding (" + to_string(oppScoreIfGrounding) + ") check is failed. " + + "Sgf: " + moveRecordsToSgf(initialBoard, moveRecords)); + } + if (performExtraChecks) { + Board boardBeforeGrounding(board); + boardBeforeGrounding.undo(moveRecords.back()); + validateGrounding(boardBeforeGrounding, board, lastPla, moveRecords); + } + break; + } } - /*if (suicideAllowed) { + if (dotsGame && suicideAllowed && lastLoc != Board::PASS_LOC) { testAssert(0 == board.numLegalMoves); - }*/ + } if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { whiteWinsCount++; @@ -541,7 +962,7 @@ void runDotsStressTestsInternal(int x_size, int y_size, int gamesCount, float gr drawsCount++; } - if (checkRollback) { + if (performExtraChecks) { while (!moveRecords.empty()) { board.undo(moveRecords.back()); moveRecords.pop_back(); @@ -563,6 +984,7 @@ void runDotsStressTestsInternal(int x_size, int y_size, int gamesCount, float gr cout << " Black wins: " << blackWinsCount << " (" << static_cast(blackWinsCount) / gamesCount << ")" << endl; cout << " White wins: " << whiteWinsCount << " (" << static_cast(whiteWinsCount) / gamesCount << ")" << endl; cout << " Draws: " << drawsCount << " (" << static_cast(drawsCount) / gamesCount << ")" << endl; + cout << " Groundings: " << groundingCount << " (" << static_cast(groundingCount) / gamesCount << ")" << endl; } void Tests::runDotsStressTests() { @@ -572,13 +994,15 @@ void Tests::runDotsStressTests() { Board board = Board(39, 32, Rules::DEFAULT_DOTS); for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { - Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; + const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; board.playMoveAssumeLegal(Location::getLoc(x, y, board.x_size), pla); } } testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); testAssert(0 == board.numLegalMoves); - //runDotsStressTestsInternal(39, 32, 100000, 0.8f, 0.01f, 0.0f, true, false); - runDotsStressTestsInternal(39, 32, 10000, 0.8f, 0.01f, 0.0f, true, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, 0.5f, false, 0.8f, 1.0f, true); + + runDotsStressTestsInternal(39, 32, 100000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, false); } \ No newline at end of file diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index d147e7b0b..ec247dcd2 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -46,6 +46,14 @@ struct BoardWithMoveRecords { return board.wouldBeCapture(Location::getLoc(x, y, board.x_size), player); } + int getWhiteScore() const { + return board.numBlackCaptures - board.numWhiteCaptures; + } + + int getBlackScore() const { + return -getWhiteScore(); + } + void undo() const { board.undo(moveRecords.back()); moveRecords.pop_back(); From 91bd12dbf85114872aba1933f57b0fd6f8f1fbf9 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:39:42 +0100 Subject: [PATCH 07/42] Refine `Board` and `BoardHistory` constructors It makes them consistent to current rules (Go or Go Game) and allows getting rid of unnecessary allocations --- cpp/book/book.cpp | 10 ++-- cpp/command/analysis.cpp | 6 +-- cpp/command/benchmark.cpp | 5 +- cpp/command/demoplay.cpp | 2 +- cpp/command/evalsgf.cpp | 9 ++-- cpp/command/genbook.cpp | 6 +-- cpp/command/gtp.cpp | 25 ++++----- cpp/command/misc.cpp | 11 ++-- cpp/command/startposes.cpp | 12 ++--- cpp/command/writetrainingdata.cpp | 8 +-- cpp/dataio/sgf.cpp | 54 +++++++++++-------- cpp/dataio/sgf.h | 7 +-- cpp/dataio/trainingwrite.cpp | 12 ++--- cpp/dataio/trainingwrite.h | 4 +- cpp/game/board.cpp | 12 ++--- cpp/game/board.h | 4 +- cpp/game/boardhistory.cpp | 88 +++++++++++++++++-------------- cpp/game/boardhistory.h | 5 +- cpp/program/play.cpp | 38 +++++++------ cpp/program/play.h | 4 +- cpp/program/playutils.cpp | 3 +- cpp/search/asyncbot.cpp | 5 +- cpp/search/asyncbot.h | 3 +- cpp/search/search.cpp | 16 +++--- cpp/search/search.h | 4 +- cpp/tests/testboardbasic.cpp | 59 +++++++++++---------- cpp/tests/testbook.cpp | 2 +- cpp/tests/testnnevalcanary.cpp | 24 +++------ cpp/tests/testnninputs.cpp | 46 +++++++--------- cpp/tests/testrules.cpp | 55 ++++++++++++++----- cpp/tests/testscore.cpp | 2 +- cpp/tests/testsearchcommon.cpp | 4 +- cpp/tests/testsearchmisc.cpp | 8 +-- cpp/tests/testsearchnonn.cpp | 7 +-- cpp/tests/testsearchv3.cpp | 8 +-- cpp/tests/testsearchv8.cpp | 40 ++++---------- cpp/tests/testsearchv9.cpp | 2 +- cpp/tests/testsgf.cpp | 24 ++++----- cpp/tests/testtrainingwrite.cpp | 35 ++++++------ 39 files changed, 337 insertions(+), 332 deletions(-) diff --git a/cpp/book/book.cpp b/cpp/book/book.cpp index 72e3fd5c6..a51ad8116 100644 --- a/cpp/book/book.cpp +++ b/cpp/book/book.cpp @@ -131,8 +131,8 @@ static Hash128 getExtraPosHash(const Board& board) { } void BookHash::getHashAndSymmetry(const BoardHistory& hist, int repBound, BookHash& hashRet, int& symmetryToAlignRet, vector& symmetriesRet, int bookVersion) { - Board boardsBySym[SymmetryHelpers::NUM_SYMMETRIES]; - BoardHistory histsBySym[SymmetryHelpers::NUM_SYMMETRIES]; + vector boardsBySym; + vector histsBySym; Hash128 accums[SymmetryHelpers::NUM_SYMMETRIES]; // Make sure the book all matches orientation for rectangular boards. @@ -143,8 +143,8 @@ void BookHash::getHashAndSymmetry(const BoardHistory& hist, int repBound, BookHa SymmetryHelpers::NUM_SYMMETRIES_WITHOUT_TRANSPOSE : SymmetryHelpers::NUM_SYMMETRIES; for(int symmetry = 0; symmetry < numSymmetries; symmetry++) { - boardsBySym[symmetry] = SymmetryHelpers::getSymBoard(hist.initialBoard,symmetry); - histsBySym[symmetry] = BoardHistory(boardsBySym[symmetry], hist.initialPla, hist.rules, hist.initialEncorePhase); + boardsBySym.emplace_back(SymmetryHelpers::getSymBoard(hist.initialBoard,symmetry)); + histsBySym.emplace_back(boardsBySym[symmetry], hist.initialPla, hist.rules, hist.initialEncorePhase); accums[symmetry] = Hash128(); } @@ -2621,7 +2621,7 @@ int64_t Book::exportToHtmlDir( const int symmetry = 0; SymBookNode symNode(node, symmetry); - BoardHistory hist; + BoardHistory hist(Rules::DEFAULT_GO); vector moveHistory; bool suc = symNode.getBoardHistoryReachingHere(hist,moveHistory); if(!suc) { diff --git a/cpp/command/analysis.cpp b/cpp/command/analysis.cpp index a6b5b6c30..5fa40e97a 100644 --- a/cpp/command/analysis.cpp +++ b/cpp/command/analysis.cpp @@ -430,7 +430,8 @@ int MainCmds::analysis(const vector& args) { vector bots; for(int threadIdx = 0; threadIdxsetCopyOfExternalPatternBonusTable(patternBonusTable); bot->setExternalEvalCache(evalCache); threads.push_back(std::thread(analysisLoopProtected,bot,threadIdx)); @@ -1111,8 +1112,7 @@ int MainCmds::analysis(const vector& args) { continue; } - - Board board(boardXSize,boardYSize); + Board board(boardXSize,boardYSize,rules); for(int i = 0; i& args) { } static void warmStartNNEval(const CompactSgf& sgf, Logger& logger, const SearchParams& params, NNEvaluator* nnEval, Rand& seedRand) { - Board board(sgf.xSize,sgf.ySize); + const Rules rules = Rules(sgf.isDots); + Board board(sgf.xSize,sgf.ySize,rules); Player nextPla = P_BLACK; - BoardHistory hist(board,nextPla,Rules(),0); + BoardHistory hist(board,nextPla,rules,0); SearchParams thisParams = params; thisParams.numThreads = 1; thisParams.maxVisits = 5; diff --git a/cpp/command/demoplay.cpp b/cpp/command/demoplay.cpp index 09b826836..987035d53 100644 --- a/cpp/command/demoplay.cpp +++ b/cpp/command/demoplay.cpp @@ -110,7 +110,7 @@ static void initializeDemoGame(Board& board, BoardHistory& hist, Player& pla, Ra const int size = sizes[rand.nextUInt(sizeFreqs,numSizes)]; - board = Board(size,size); + board = Board(size,size, Rules::DEFAULT_GO); pla = P_BLACK; hist.clear(board,pla,Rules::getTrompTaylorish(),0); bot->setPosition(pla,board,hist); diff --git a/cpp/command/evalsgf.cpp b/cpp/command/evalsgf.cpp index 29c534ba4..9de38853c 100644 --- a/cpp/command/evalsgf.cpp +++ b/cpp/command/evalsgf.cpp @@ -180,12 +180,13 @@ int MainCmds::evalsgf(const vector& args) { Rules defaultRules = Rules::getDefaultOrTrompTaylorish(sgf->isDots); Player perspective = Setup::parseReportAnalysisWinrates(cfg,P_BLACK); - Board board; + Board board(defaultRules); Player nextPla; - BoardHistory hist; + BoardHistory hist(defaultRules); auto setUpBoardUsingRules = [&board,&nextPla,&hist,overrideKomi,&sgf,&extraMoves](const Rules& initialRules, int moveNum) { - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + board = hist.initialBoard; vector& moves = sgf->moves; if(!isnan(overrideKomi)) { @@ -379,7 +380,7 @@ int MainCmds::evalsgf(const vector& args) { continue; } - AsyncBot* bot = new AsyncBot(params, nnEval, humanEval, &logger, searchRandSeed); + AsyncBot* bot = new AsyncBot(params, nnEval, humanEval, &logger, searchRandSeed, initialRules); bot->setPosition(nextPla,board,hist); if(hintLoc != "") { diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index 69c10ee10..079e65f98 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -357,9 +357,9 @@ int MainCmds::genbook(const vector& args) { std::map expandBonusByHash; std::map visitsRequiredByHash; std::map branchRequiredByHash; - Board bonusInitialBoard; + Board bonusInitialBoard(rules); Player bonusInitialPla; - bonusInitialBoard = Board(boardSizeX,boardSizeY); + bonusInitialBoard = Board(boardSizeX,boardSizeY,rules); bonusInitialPla = P_BLACK; for(const std::string& bonusFile: bonusFiles) { @@ -1534,7 +1534,7 @@ int MainCmds::writebook(const vector& args) { std::map expandBonusByHash; std::map visitsRequiredByHash; std::map branchRequiredByHash; - Board bonusInitialBoard; + Board bonusInitialBoard(rules); Player bonusInitialPla; maybeParseBonusFile( diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index a11b73128..138e17473 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -415,7 +415,7 @@ struct GTPEngine { isGenmoveParams(true), bTimeControls(), wTimeControls(), - initialBoard(), + initialBoard(initialRules), initialPla(P_BLACK), moveHistory(), recentWinLossValues(), @@ -550,11 +550,11 @@ struct GTPEngine { else searchRandSeed = Global::uint64ToString(seedRand.nextUInt64()); - bot = new AsyncBot(genmoveParams, nnEval, humanEval, &logger, searchRandSeed); + bot = new AsyncBot(genmoveParams, nnEval, humanEval, &logger, searchRandSeed, currentRules); bot->setCopyOfExternalPatternBonusTable(patternBonusTable); isGenmoveParams = true; - Board board(boardXSize,boardYSize); + Board board(boardXSize,boardYSize,currentRules); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); vector newMoveHistory; @@ -587,7 +587,7 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); int newXSize = bot->getRootBoard().x_size; int newYSize = bot->getRootBoard().y_size; - Board board(newXSize,newYSize); + Board board(newXSize,newYSize,currentRules); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); vector newMoveHistory; @@ -599,7 +599,7 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); int newXSize = bot->getRootBoard().x_size; int newYSize = bot->getRootBoard().y_size; - Board board(newXSize,newYSize); + Board board(newXSize,newYSize,currentRules); bool suc = board.setStonesFailIfNoLibs(initialStones); if(!suc) return false; @@ -1304,7 +1304,7 @@ struct GTPEngine { void placeFixedHandicap(int n, string& response, bool& responseIsError) { int xSize = bot->getRootBoard().x_size; int ySize = bot->getRootBoard().y_size; - Board board(xSize,ySize); + Board board(xSize,ySize,currentRules); try { PlayUtils::placeFixedHandicap(board,n); } @@ -1355,7 +1355,7 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); - Board board(xSize,ySize); + Board board(xSize,ySize,currentRules); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); double extraBlackTemperature = 0.25; @@ -3175,9 +3175,10 @@ int MainCmds::gtp(const vector& args) { } else { vector locs; - int xSize = engine->bot->getRootBoard().x_size; - int ySize = engine->bot->getRootBoard().y_size; - Board board(xSize,ySize); + const Board* rootBoard = &engine->bot->getRootBoard(); + int xSize = rootBoard->x_size; + int ySize = rootBoard->y_size; + Board board(xSize,ySize,rootBoard->rules); for(int i = 0; i& args) { else { Board sgfInitialBoard; Player sgfInitialNextPla; - BoardHistory sgfInitialHist; Rules sgfRules; Board sgfBoard; Player sgfNextPla; @@ -3350,7 +3350,8 @@ int MainCmds::gtp(const vector& args) { } } - sgf->setupInitialBoardAndHist(sgfRules, sgfInitialBoard, sgfInitialNextPla, sgfInitialHist); + BoardHistory sgfInitialHist = sgf->setupInitialBoardAndHist(sgfRules, sgfInitialNextPla); + sgfInitialBoard = sgfInitialHist.initialBoard; sgfInitialHist.setInitialTurnNumber(sgfInitialBoard.numStonesOnBoard()); //Should give more accurate temperaure and time control behavior sgfBoard = sgfInitialBoard; sgfNextPla = sgfInitialNextPla; diff --git a/cpp/command/misc.cpp b/cpp/command/misc.cpp index 001bc5b3b..e118b80e1 100644 --- a/cpp/command/misc.cpp +++ b/cpp/command/misc.cpp @@ -209,9 +209,9 @@ int MainCmds::evalrandominits(const vector& args) { Rand gameRand; while(true) { - Board board(19,19); - Player pla = P_BLACK; Rules rules = Rules::parseRules("japanese"); + Board board(19,19,rules); + Player pla = P_BLACK; BoardHistory hist(board,pla,rules,0); int numInitialMovesToPlay = (int)gameRand.nextUInt(200); double temperature = 1.0; @@ -322,11 +322,10 @@ int MainCmds::searchentropyanalysis(const vector& args) { std::unique_ptr sgfObj = CompactSgf::parse(sgf); for(int turnIdx = 0; turnIdx < sgfObj->moves.size(); turnIdx++) { - Board board; Player pla; - BoardHistory hist; - Rules initialRules; - sgfObj->setupInitialBoardAndHist(initialRules, board, pla, hist); + Rules rules = sgfObj->getRulesOrFailAllowUnspecified(Rules::getSimpleTerritory()); + BoardHistory hist = sgfObj->setupInitialBoardAndHist(rules, pla); + Board& board = hist.initialBoard; for(int i = 0; i < turnIdx; i++) { Loc moveLoc = sgfObj->moves[i].loc; diff --git a/cpp/command/startposes.cpp b/cpp/command/startposes.cpp index 6ee18b225..1af2904f0 100644 --- a/cpp/command/startposes.cpp +++ b/cpp/command/startposes.cpp @@ -436,11 +436,11 @@ int MainCmds::samplesgfs(const vector& args) { else { string fileName = sgf.fileName; CompactSgf compactSgf(sgf); - Board board; + Player nextPla; - BoardHistory hist; Rules rules = compactSgf.getRulesOrFailAllowUnspecified(Rules::getSimpleTerritory()); - compactSgf.setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = compactSgf.setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; if(valueFluctuationMakeKomiFair) { Rand rand; @@ -1386,10 +1386,10 @@ int MainCmds::dataminesgfs(const vector& args) { //Don't use the SGF rules - randomize them for a bit more entropy Rules rules = gameInit->createRules(); - Board board; Player nextPla; - BoardHistory hist; - sgf.setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = sgf.setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; + if(!gameInit->isAllowedBSize(board.x_size,board.y_size)) { numFilteredSgfs.fetch_add(1); return; diff --git a/cpp/command/writetrainingdata.cpp b/cpp/command/writetrainingdata.cpp index 5892aded3..e397a8976 100644 --- a/cpp/command/writetrainingdata.cpp +++ b/cpp/command/writetrainingdata.cpp @@ -1372,11 +1372,13 @@ int MainCmds::writetrainingdata(const vector& args) { // No friendly pass since we want to complete consistent with strict rules rules.friendlyPassOk = false; - Board board; + // TODO: Fix construction of board and hist + Board board(Rules::DEFAULT_GO); Player nextPla; - BoardHistory hist; + BoardHistory hist(Rules::DEFAULT_GO); try { - sgf->setupInitialBoardAndHist(rules, board, nextPla, hist); + hist = sgf->setupInitialBoardAndHist(rules, nextPla); + board = hist.initialBoard; } catch(const StringError& e) { logger.write("Bad initial setup in sgf " + fileName + " " + e.what()); diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 95a442fd9..715352a28 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -136,8 +136,9 @@ static void writeSgfLoc(ostream& out, Loc loc, int xSize, int ySize) { out << chars[y]; } -static Rules getRulesFromSgf(const bool dotsGame, const SgfNode& rootNode, const int xSize, const int ySize, const Rules* defaultRules) { +static Rules getRulesFromSgf(const SgfNode& rootNode, const int xSize, const int ySize, const Rules* defaultRules) { Rules rules; + const bool dotsGame = rootNode.getIsDotsGame(); if (defaultRules == nullptr || rootNode.hasProperty("RU")) { rules = rootNode.getRulesFromRUTagOrFail(dotsGame); } else { @@ -419,10 +420,7 @@ static void checkNonEmpty(const vector>& nodes) { } bool Sgf::isDotsGame() const { - if(!nodes[0]->hasProperty("GM")) - return false; - const string& s = nodes[0]->getSingleProperty("GM"); - return s == "40"; + return nodes[0]->getIsDotsGame(); } XYSize Sgf::getXYSize() const { @@ -478,6 +476,11 @@ float SgfNode::getKomiOrFail() const { return getKomiOrDefault(0.0f); } +bool SgfNode::getIsDotsGame() const { + if(!hasProperty("GM")) return false; + return getSingleProperty("GM") == "40"; +} + float SgfNode::getKomiOrDefault(float defaultKomi) const { //Default, if SGF doesn't specify if(!hasProperty("KM")) @@ -548,7 +551,7 @@ Rules Sgf::getRulesOrFail() const { checkNonEmpty(nodes); const XYSize size = getXYSize(); - return getRulesFromSgf(isDotsGame(), *nodes[0], size.x, size.y, nullptr); + return getRulesFromSgf(*nodes[0], size.x, size.y, nullptr); } Player Sgf::getSgfWinner() const { @@ -1497,7 +1500,7 @@ static std::unique_ptr maybeParseSgf(const string& str, size_t& pos) { && handicap >= 2 && handicap <= 9 ) { - Board board(19,19); + Board board(rootSgf->getRulesOrFail()); PlayUtils::placeFixedHandicap(board, handicap); // Older fox sgfs used handicaps with side stones on the north and south rather than east and west if(handicap == 6 || handicap == 7) { @@ -1712,11 +1715,11 @@ bool CompactSgf::hasRules() const { } Rules CompactSgf::getRulesOrFail() const { - return getRulesFromSgf(isDots, rootNode, xSize, ySize, nullptr); + return getRulesFromSgf(rootNode, xSize, ySize, nullptr); } Rules CompactSgf::getRulesOrFailAllowUnspecified(const Rules& defaultRules) const { - return getRulesFromSgf(isDots, rootNode, xSize, ySize, &defaultRules); + return getRulesFromSgf(rootNode, xSize, ySize, &defaultRules); } Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function f) const { @@ -1771,8 +1774,7 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function 0) nextPla = moves[0].pla; - board = Board(xSize,ySize,initialRules); + auto board = Board(xSize,ySize,initialRules); if (initialRules.startPos == Rules::START_POS_EMPTY) { bool suc = board.setStonesFailIfNoLibs(placements); if(!suc) throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); } - hist = BoardHistory(board,nextPla,initialRules,0); + BoardHistory hist = BoardHistory(board,nextPla,initialRules,0); if (int numStonesOnBoard = board.numStonesOnBoard(); hist.initialTurnNumber < numStonesOnBoard) hist.initialTurnNumber = numStonesOnBoard; + return hist; } void CompactSgf::playMovesAssumeLegal(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const { @@ -1838,14 +1841,23 @@ void CompactSgf::playMovesTolerant(Board& board, Player& nextPla, BoardHistory& } } -void CompactSgf::setupBoardAndHistAssumeLegal(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const { - setupInitialBoardAndHist(initialRules, board, nextPla, hist); - playMovesAssumeLegal(board, nextPla, hist, turnIdx); +std::pair CompactSgf::setupBoardAndHistAssumeLegal(const Rules& initialRules, Player& nextPla, int64_t turnIdx) + const { + BoardHistory hist = setupInitialBoardAndHist(initialRules, nextPla); + Board boardWithMoves(hist.initialBoard); + playMovesAssumeLegal(boardWithMoves, nextPla, hist, turnIdx); + return std::make_pair(hist, boardWithMoves); } -void CompactSgf::setupBoardAndHistTolerant(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const { - setupInitialBoardAndHist(initialRules, board, nextPla, hist); - playMovesTolerant(board, nextPla, hist, turnIdx, preventEncore); +std::pair CompactSgf::setupBoardAndHistTolerant( + const Rules& initialRules, + Player& nextPla, + int64_t turnIdx, + bool preventEncore) const { + BoardHistory hist = setupInitialBoardAndHist(initialRules, nextPla); + Board boardWithMoves(hist.initialBoard); + playMovesTolerant(boardWithMoves, nextPla, hist, turnIdx, preventEncore); + return std::make_pair(hist, boardWithMoves); } @@ -1968,7 +1980,7 @@ void WriteSgf::writeSgf( int xSize = initialBoard.x_size; int ySize = initialBoard.y_size; out << "(;FF[4]"; - out << "GM[" << (initialBoard.rules.isDots ? "40" : "1") << "]"; + out << "GM[" << (rules.isDots ? "40" : "1") << "]"; if(xSize == ySize) out << "SZ[" << xSize << "]"; else @@ -1976,7 +1988,7 @@ void WriteSgf::writeSgf( out << "PB[" << bName << "]"; out << "PW[" << wName << "]"; - if (!initialBoard.rules.isDots) { + if (!rules.isDots) { if(gameData != NULL) { out << "HA[" << gameData->handicapForSgf << "]"; } diff --git a/cpp/dataio/sgf.h b/cpp/dataio/sgf.h index 567791154..53b855653 100644 --- a/cpp/dataio/sgf.h +++ b/cpp/dataio/sgf.h @@ -42,6 +42,7 @@ struct SgfNode { Rules getRulesFromRUTagOrFail(bool isDots) const; Player getSgfWinner() const; float getKomiOrFail() const; + bool getIsDotsGame() const; float getKomiOrDefault(float defaultKomi) const; std::string getPlayerName(Player pla) const; @@ -243,12 +244,12 @@ struct CompactSgf { Rules getRulesOrFailAllowUnspecified(const Rules& defaultRules) const; Rules getRulesOrWarn(const Rules& defaultRules, std::function f) const; - void setupInitialBoardAndHist(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist) const; + BoardHistory setupInitialBoardAndHist(const Rules& initialRules, Player& nextPla) const; void playMovesAssumeLegal(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const; - void setupBoardAndHistAssumeLegal(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx) const; + std::pair setupBoardAndHistAssumeLegal(const Rules& initialRules, Player& nextPla, int64_t turnIdx) const; //These throw a StringError upon illegal move. void playMovesTolerant(Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const; - void setupBoardAndHistTolerant(const Rules& initialRules, Board& board, Player& nextPla, BoardHistory& hist, int64_t turnIdx, bool preventEncore) const; + std::pair setupBoardAndHistTolerant(const Rules& initialRules, Player& nextPla, int64_t turnIdx, bool preventEncore) const; }; namespace WriteSgf { diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index 34d9e778c..5114177ac 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -18,9 +18,9 @@ ValueTargets::~ValueTargets() //------------------------------------------------------------------------------------- -SidePosition::SidePosition() +SidePosition::SidePosition(const Rules& rules) :board(), - hist(), + hist(rules), pla(P_BLACK), unreducedNumVisits(), policyTarget(), @@ -59,15 +59,15 @@ SidePosition::~SidePosition() //------------------------------------------------------------------------------------- -FinishedGameData::FinishedGameData() +FinishedGameData::FinishedGameData(const Rules& rules) :bName(), wName(), bIdx(0), wIdx(0), - startBoard(), - startHist(), - endHist(), + startBoard(rules), + startHist(rules), + endHist(rules), startPla(P_BLACK), gameHash(), diff --git a/cpp/dataio/trainingwrite.h b/cpp/dataio/trainingwrite.h index 253c13258..23d043603 100644 --- a/cpp/dataio/trainingwrite.h +++ b/cpp/dataio/trainingwrite.h @@ -57,7 +57,7 @@ struct SidePosition { Player playoutDoublingAdvantagePla; double playoutDoublingAdvantage; - SidePosition(); + explicit SidePosition(const Rules& rules); SidePosition(const Board& board, const BoardHistory& hist, Player pla, int numNeuralNetChangesSoFar); ~SidePosition(); }; @@ -128,7 +128,7 @@ struct FinishedGameData { static constexpr int MODE_HINTFORK = 6; static constexpr int MODE_ASYM = 7; - FinishedGameData(); + explicit FinishedGameData(const Rules& rules); ~FinishedGameData(); void printDebug(std::ostream& out) const; diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 52f04769d..44e0b839c 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -119,17 +119,13 @@ Board::Base::Base(Player newPla, is_real = isReal; } -Board::Board() -{ - init(DEFAULT_LEN_X, DEFAULT_LEN_Y, Rules()); -} +Board::Board() : Board(Rules::DEFAULT_GO) {} -Board::Board(int x, int y) -{ - init(x, y, Rules()); +Board::Board(const Rules& rules) { + init(DEFAULT_LEN_X, DEFAULT_LEN_Y, rules); } -Board::Board(int x, int y, const Rules& rules) { +Board::Board(const int x, const int y, const Rules& rules) { init(x, y, rules); } diff --git a/cpp/game/board.h b/cpp/game/board.h index 1b386f676..a2e029ed9 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -229,8 +229,8 @@ struct Board //Constructors--------------------------------- Board(); //Create Board of size (DEFAULT_LEN,DEFAULT_LEN) - Board(int x, int y); // Create Board of size (x,y) - Board(int x, int y, const Rules& rules); + explicit Board(const Rules& rules); + Board(int x, int y, const Rules& rules); // Create Board of size (x,y) with the specified Rules Board(const Board& other); Board& operator=(const Board&) = default; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index ae6fdd4c4..7a10a8889 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -23,39 +23,46 @@ static Hash128 getKoHashAfterMoveNonEncore(const Rules& rules, Hash128 posHashAf // return posHashAfterMove ^ koRecapBlockHashAfterMove; // } - -BoardHistory::BoardHistory() - :rules(), - moveHistory(), - preventEncoreHistory(), - koHashHistory(), - firstTurnIdxWithKoHistory(0), - initialBoard(), - initialPla(P_BLACK), - initialEncorePhase(0), - initialTurnNumber(0), - assumeMultipleStartingBlackMovesAreHandicap(false), - whiteHasMoved(false), - overrideNumHandicapStones(-1), - recentBoards(), - currentRecentBoardIdx(0), - presumedNextMovePla(P_BLACK), - consecutiveEndingPasses(0), - hashesBeforeBlackPass(),hashesBeforeWhitePass(), - encorePhase(0), - numTurnsThisPhase(0), - numApproxValidTurnsThisPhase(0), - numConsecValidTurnsThisGame(0), - koRecapBlockHash(), - koCapturesInEncore(), - whiteBonusScore(0.0f), - whiteHandicapBonusScore(0.0f), - hasButton(false), - isPastNormalPhaseEnd(false), - isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), - isScored(false),isNoResult(false),isResignation(false) -{ - if (!rules.isDots) { +BoardHistory::BoardHistory() : BoardHistory(Rules::DEFAULT_GO) {} + +BoardHistory::BoardHistory(const Rules& rules) + : rules(rules), + moveHistory(), + preventEncoreHistory(), + koHashHistory(), + firstTurnIdxWithKoHistory(0), + initialBoard(rules), + initialPla(P_BLACK), + initialEncorePhase(0), + initialTurnNumber(0), + assumeMultipleStartingBlackMovesAreHandicap(false), + whiteHasMoved(false), + overrideNumHandicapStones(-1), + currentRecentBoardIdx(0), + presumedNextMovePla(P_BLACK), + consecutiveEndingPasses(0), + hashesBeforeBlackPass(), + hashesBeforeWhitePass(), + encorePhase(0), + numTurnsThisPhase(0), + numApproxValidTurnsThisPhase(0), + numConsecValidTurnsThisGame(0), + koRecapBlockHash(), + koCapturesInEncore(), + whiteBonusScore(0.0f), + whiteHandicapBonusScore(0.0f), + hasButton(false), + isPastNormalPhaseEnd(false), + isGameFinished(false), + winner(C_EMPTY), + finalWhiteMinusBlackScore(0.0f), + isScored(false), + isNoResult(false), + isResignation(false) { + for(int i = 0; i < NUM_RECENT_BOARDS; i++) { + recentBoards.emplace_back(rules); + } + if(!rules.isDots) { wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); superKoBanned.resize(Board::MAX_ARR_SIZE, false); koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); @@ -72,7 +79,7 @@ BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int e preventEncoreHistory(), koHashHistory(), firstTurnIdxWithKoHistory(0), - initialBoard(), + initialBoard(rules), initialPla(), initialEncorePhase(0), initialTurnNumber(0), @@ -97,6 +104,9 @@ BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int e isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), isScored(false),isNoResult(false),isResignation(false) { + for(int i = 0; i < NUM_RECENT_BOARDS; i++) { + recentBoards.emplace_back(rules); + } if (!rules.isDots) { wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); superKoBanned.resize(Board::MAX_ARR_SIZE, false); @@ -120,7 +130,6 @@ BoardHistory::BoardHistory(const BoardHistory& other) assumeMultipleStartingBlackMovesAreHandicap(other.assumeMultipleStartingBlackMovesAreHandicap), whiteHasMoved(other.whiteHasMoved), overrideNumHandicapStones(other.overrideNumHandicapStones), - recentBoards(), currentRecentBoardIdx(other.currentRecentBoardIdx), presumedNextMovePla(other.presumedNextMovePla), consecutiveEndingPasses(other.consecutiveEndingPasses), @@ -138,7 +147,7 @@ BoardHistory::BoardHistory(const BoardHistory& other) isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) { - std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; superKoBanned = other.superKoBanned; koRecapBlocked = other.koRecapBlocked; @@ -162,7 +171,7 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; @@ -205,7 +214,6 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept assumeMultipleStartingBlackMovesAreHandicap(other.assumeMultipleStartingBlackMovesAreHandicap), whiteHasMoved(other.whiteHasMoved), overrideNumHandicapStones(other.overrideNumHandicapStones), - recentBoards(), currentRecentBoardIdx(other.currentRecentBoardIdx), presumedNextMovePla(other.presumedNextMovePla), consecutiveEndingPasses(other.consecutiveEndingPasses), @@ -223,7 +231,7 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) { - std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; superKoBanned = other.superKoBanned; koRecapBlocked = other.koRecapBlocked; @@ -244,7 +252,7 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept assumeMultipleStartingBlackMovesAreHandicap = other.assumeMultipleStartingBlackMovesAreHandicap; whiteHasMoved = other.whiteHasMoved; overrideNumHandicapStones = other.overrideNumHandicapStones; - std::copy_n(other.recentBoards, NUM_RECENT_BOARDS, recentBoards); + recentBoards = other.recentBoards; currentRecentBoardIdx = other.currentRecentBoardIdx; presumedNextMovePla = other.presumedNextMovePla; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 6dfbf4546..2123e1410 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -38,8 +38,8 @@ struct BoardHistory { bool whiteHasMoved; int overrideNumHandicapStones; - static const int NUM_RECENT_BOARDS = 6; - Board recentBoards[NUM_RECENT_BOARDS]; + static constexpr int NUM_RECENT_BOARDS = 6; + std::vector recentBoards; int currentRecentBoardIdx; Player presumedNextMovePla; @@ -102,6 +102,7 @@ struct BoardHistory { bool isResignation; BoardHistory(); + explicit BoardHistory(const Rules& rules); ~BoardHistory(); BoardHistory(const Board& board, Player pla, const Rules& rules, int encorePhase); diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 4c3351863..88b787448 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -15,8 +15,8 @@ using namespace std; //---------------------------------------------------------------------------------------------------------- -InitialPosition::InitialPosition() - :board(),hist(),pla(C_EMPTY) +InitialPosition::InitialPosition(const Rules& rules) + :board(rules),hist(rules),pla(C_EMPTY) {} InitialPosition::InitialPosition(const Board& b, const BoardHistory& h, Player p, bool plainFork, bool sekiFork, bool hintFork, double tw) :board(b),hist(h),pla(p),isPlainFork(plainFork),isSekiFork(sekiFork),isHintFork(hintFork),trainingWeight(tw) @@ -482,6 +482,9 @@ int GameInitializer::getMaxBoardXSize() const { int GameInitializer::getMaxBoardYSize() const { return maxBoardYSize; } +bool GameInitializer::isDotsGame() const { + return dotsGame; +} Rules GameInitializer::createRules() { lock_guard lock(createGameMutex); @@ -1337,7 +1340,7 @@ FinishedGameData* Play::runGame( std::function checkForNewNNEval, std::function&, const std::vector&, const std::vector&, const Search*)> onEachMove ) { - FinishedGameData* gameData = new FinishedGameData(); + FinishedGameData* gameData = new FinishedGameData(startHist.rules); Board board(startBoard); BoardHistory hist(startHist); @@ -2197,10 +2200,11 @@ void Play::maybeForkGame( ASSERT_UNREACHABLE; } - Board board; + const Rules& rules = finishedGameData->startHist.rules; + Board board(rules); Player pla; - BoardHistory hist; - replayGameUpToMove(finishedGameData, moveIdx, finishedGameData->startHist.rules, board, hist, pla); + BoardHistory hist(rules); + replayGameUpToMove(finishedGameData, moveIdx, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) return; @@ -2285,9 +2289,9 @@ void Play::maybeSekiForkGame( Rules rules = finishedGameData->startHist.rules; rules = gameInit->randomizeScoringAndTaxRules(rules,gameRand); - Board board; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); replayGameUpToMove(finishedGameData, moveIdx, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) @@ -2317,12 +2321,13 @@ void Play::maybeHintForkGame( if(!hintFork) return; - Board board; + const Rules& rules = finishedGameData->startHist.rules; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); testAssert(finishedGameData->startHist.moveHistory.size() < 0x1FFFffff); int moveIdxToReplayTo = (int)finishedGameData->startHist.moveHistory.size(); - replayGameUpToMove(finishedGameData, moveIdxToReplayTo, finishedGameData->startHist.rules, board, hist, pla); + replayGameUpToMove(finishedGameData, moveIdxToReplayTo, rules, board, hist, pla); //Just in case if somehow the game is over now, don't actually do anything if(hist.isGameFinished) return; @@ -2406,9 +2411,10 @@ FinishedGameData* GameRunner::runGame( } } - Board board; + const Rules& rules = Rules::getDefault(gameInit->isDotsGame()); + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); ExtraBlackAndKomi extraBlackAndKomi; OtherGameProperties otherGameProps; if(playSettings.forSelfPlay) { @@ -2449,12 +2455,12 @@ FinishedGameData* GameRunner::runGame( Search* botB; Search* botW; if(botSpecB.botIdx == botSpecW.botIdx) { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, hist.rules.isDots); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, hist.rules); botW = botB; } else { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", hist.rules.isDots); - botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", hist.rules.isDots); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", hist.rules); + botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", hist.rules); } if(afterInitialization != nullptr) { if(botSpecB.botIdx == botSpecW.botIdx) { diff --git a/cpp/program/play.h b/cpp/program/play.h index d3935ac37..9b7fde7e9 100644 --- a/cpp/program/play.h +++ b/cpp/program/play.h @@ -24,7 +24,7 @@ struct InitialPosition { bool isHintFork; double trainingWeight; - InitialPosition(); + explicit InitialPosition(const Rules& rules); InitialPosition(const Board& board, const BoardHistory& hist, Player pla, bool isPlainFork, bool isSekiFork, bool isHintFork, double trainingWeight); ~InitialPosition(); }; @@ -119,6 +119,8 @@ class GameInitializer { int getMaxBoardXSize() const; int getMaxBoardYSize() const; + bool isDotsGame() const; + private: void initShared(ConfigParser& cfg, Logger& logger); void createGameSharedUnsynchronized( diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index 323e76dc6..8a03ff1ea 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -955,8 +955,7 @@ PlayUtils::BenchmarkResults PlayUtils::benchmarkSearchOnPositionsAndPrint( Board board; Player nextPla; - BoardHistory hist; - sgf.setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf.setupInitialBoardAndHist(initialRules, nextPla); int moveNum = 0; diff --git a/cpp/search/asyncbot.cpp b/cpp/search/asyncbot.cpp index 1a274c5f3..3e383f4aa 100644 --- a/cpp/search/asyncbot.cpp +++ b/cpp/search/asyncbot.cpp @@ -43,7 +43,8 @@ AsyncBot::AsyncBot( NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* l, - const string& randSeed + const string& randSeed, + const Rules& rules ) :search(NULL), controlMutex(),threadWaitingToSearch(),userWaitingForStop(),searchThread(), @@ -54,7 +55,7 @@ AsyncBot::AsyncBot( analyzeCallback(), searchBegunCallback() { - search = new Search(params,nnEval,humanEval,l,randSeed,false); + search = new Search(params,nnEval,humanEval,l,randSeed,rules); searchThread = std::thread(searchThreadLoop,this,l); } diff --git a/cpp/search/asyncbot.h b/cpp/search/asyncbot.h index ae94873ce..706574811 100644 --- a/cpp/search/asyncbot.h +++ b/cpp/search/asyncbot.h @@ -16,7 +16,8 @@ class AsyncBot { NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* logger, - const std::string& randSeed + const std::string& randSeed, + const Rules& rules ); ~AsyncBot(); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index d4e238c84..afe26c5c4 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -66,15 +66,15 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed) - :Search(params,nnEval,NULL,lg,rSeed,false) + :Search(params,nnEval,NULL,lg,rSeed,Rules::DEFAULT_GO) {} -Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed, const bool isDots) - :Search(params,nnEval,NULL,lg,rSeed,isDots) +Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed, const Rules& rules) + :Search(params,nnEval,NULL,lg,rSeed,rules) {} -Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed, const bool isDots) +Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed, const Rules& rules) :rootPla(P_BLACK), - rootBoard(), - rootHistory(), + rootBoard(rules), + rootHistory(rules), rootGraphHash(), rootHintLoc(Board::NULL_LOC), avoidMoveUntilByLocBlack(),avoidMoveUntilByLocWhite(),avoidMoveUntilRescaleRoot(false), @@ -127,8 +127,8 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, } assert(rootHistory.rules.isDots == rootBoard.isDots()); - rootHistory.clear(rootBoard,rootPla,Rules::getDefault(isDots),0); - if (!isDots) { + rootHistory.clear(rootBoard,rootPla,rules,0); + if (!rules.isDots) { rootKoHashTable = new KoHashTable(); rootKoHashTable->recompute(rootHistory); } diff --git a/cpp/search/search.h b/cpp/search/search.h index 5655be0cf..901a19506 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -192,14 +192,14 @@ struct Search { NNEvaluator* nnEval, Logger* logger, const std::string& randSeed, - bool isDots); + const Rules& rules); Search( SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const std::string& rSeed, - bool isDots); + const Rules& rules); ~Search(); Search(const Search&) = delete; diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index a88a682de..172c19299 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -739,7 +739,7 @@ After white //============================================================================ { const char* name = "Distance"; - Board board(17,12); + Board board(17,12,Rules::DEFAULT_GO); auto testDistance = [&](int x0, int y0, int x1, int y1) { out << "distance (" << x0 << "," << y0 << ") (" << x1 << "," << y1 << ") = " << @@ -1915,12 +1915,12 @@ void Tests::runBoardUndoTest() { int regularMoveCount = 0; auto run = [&](const Board& startBoard, bool multiStoneSuicideLegal) { static const int steps = 1000; - Board* boards = new Board[steps+1]; + vector boards; Board::MoveRecord records[steps]; - boards[0] = startBoard; + boards.push_back(startBoard); for(int n = 1; n <= steps; n++) { - boards[n] = boards[n-1]; + boards.push_back(boards[n-1]); Loc loc; Player pla; while(true) { @@ -1951,12 +1951,11 @@ void Tests::runBoardUndoTest() { testAssert(boardsSeemEqual(boards[n],board)); board.checkConsistency(); } - delete[] boards; }; - run(Board(19,19),true); - run(Board(4,4),true); - run(Board(4,4),false); + run(Board(19,19,Rules::DEFAULT_GO),true); + run(Board(4,4,Rules::DEFAULT_GO),true); + run(Board(4,4,Rules::DEFAULT_GO),false); ostringstream out; out << endl; @@ -1979,7 +1978,7 @@ suicideCount 79 void Tests::runBoardHandicapTest() { cout << "Running board handicap test" << endl; { - Board board = Board(19,19); + Board board = Board(19,19,Rules::DEFAULT_GO); Player nextPla = P_BLACK; Rules rules = Rules::parseRules("chinese"); BoardHistory hist(board,nextPla,rules,0); @@ -2001,9 +2000,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("chinese"); + const Rules rules = Rules::parseRules("chinese"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2020,9 +2019,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("aga"); + const Rules rules = Rules::parseRules("aga"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2039,9 +2038,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("aga"); + const Rules rules = Rules::parseRules("aga"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2070,9 +2069,9 @@ void Tests::runBoardHandicapTest() { } { - Board board = Board(19,19); Player nextPla = P_BLACK; - Rules rules = Rules::parseRules("chinese"); + const Rules rules = Rules::parseRules("chinese"); + Board board = Board(19,19,rules); BoardHistory hist(board,nextPla,rules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(true); @@ -2105,13 +2104,14 @@ void Tests::runBoardStressTest() { { Rand rand("runBoardStressTests"); static const int numBoards = 4; - Board boards[numBoards]; - boards[0] = Board(); - boards[1] = Board(9,16); - boards[2] = Board(13,7); - boards[3] = Board(4,4); + vector boards; + Rules rules = Rules::DEFAULT_GO; + boards.push_back(Board(Board::DEFAULT_LEN_X, Board::DEFAULT_LEN_Y, rules)); + boards.push_back(Board(9,16,rules)); + boards.push_back(Board(13,7,rules)); + boards.push_back(Board(4,4,rules)); bool multiStoneSuicideLegal[4] = {false,false,true,false}; - Board copies[numBoards]; + vector copies; Player pla = C_BLACK; int suicideCount = 0; int koBanCount = 0; @@ -2143,8 +2143,9 @@ void Tests::runBoardStressTest() { } } + copies.clear(); for(int i = 0; i placements; for(int i = 0; i<1000; i++) { Loc loc = Location::getLoc(rand.nextUInt(board.x_size),rand.nextUInt(board.y_size),board.x_size); @@ -2365,7 +2366,7 @@ Caps 4420 4335 placements.push_back(Move(loc,pla)); bool anyCaps = board.wouldBeCapture(loc,pla) || board.isSuicide(loc,pla); board.playMoveAssumeLegal(loc,pla); - Board copy(board.x_size,board.y_size); + Board copy(board.x_size,board.y_size,board.rules); bool suc = copy.setStonesFailIfNoLibs(placements); testAssert(suc == !anyCaps); copy.checkConsistency(); @@ -2379,7 +2380,7 @@ Caps 4420 4335 { Rand rand("runBoardSetStoneTests3"); for(int rep = 0; rep<1000; rep++) { - Board board(1 + rand.nextUInt(18), 1 + rand.nextUInt(18)); + Board board(1 + rand.nextUInt(18), 1 + rand.nextUInt(18),Rules::DEFAULT_GO); for(int i = 0; i<300; i++) { Loc loc = Location::getLoc(rand.nextUInt(board.x_size),rand.nextUInt(board.y_size),board.x_size); Color color = rand.nextBool(0.25) ? C_EMPTY : rand.nextBool(0.5) ? P_BLACK : P_WHITE; @@ -2395,7 +2396,7 @@ Caps 4420 4335 placements.push_back(Move(loc,color)); if(prevPlacedLocs.find(loc) != prevPlacedLocs.end()) { - Board copy(board.x_size,board.y_size); + Board copy(board.x_size,board.y_size,board.rules); bool suc = copy.setStonesFailIfNoLibs(placements); testAssert(!suc); placements.pop_back(); diff --git a/cpp/tests/testbook.cpp b/cpp/tests/testbook.cpp index d3b2bf6b4..deaa60b09 100644 --- a/cpp/tests/testbook.cpp +++ b/cpp/tests/testbook.cpp @@ -27,8 +27,8 @@ static void corruptNodes(Book* book, std::vector& newAndChangedNode void Tests::runBookTests() { cout << "Running book tests" << endl; - Board initialBoard(4,4); Rules rules = Rules::parseRules("japanese"); + Board initialBoard(4,4,rules); Player initialPla = P_BLACK; int repBound = 9; diff --git a/cpp/tests/testnnevalcanary.cpp b/cpp/tests/testnnevalcanary.cpp index 100ecd3a3..d207f6876 100644 --- a/cpp/tests/testnnevalcanary.cpp +++ b/cpp/tests/testnnevalcanary.cpp @@ -44,12 +44,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[pd];W[pp];B[dd];W[dp];B[qn];W[nq];B[cq];W[dq];B[cp];W[do];B[bn];W[cc];B[cd];W[dc];B[ec];W[eb];B[fb];W[fc];B[ed];W[gb];B[db];W[fa];B[cb];W[qo];B[pn];W[nc];B[qj];W[qc];B[qd];W[pc];B[od];W[nd];B[ne];W[me];B[mf];W[nf])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 18; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -79,12 +77,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[pd];W[pp];B[dd];W[dp];B[qn];W[nq];B[cq];W[dq];B[cp];W[do];B[bn];W[cc];B[cd];W[dc];B[ec];W[eb];B[fb];W[fc];B[ed];W[gb];B[db];W[fa];B[cb];W[qo];B[pn];W[nc];B[qj];W[qc];B[qd];W[pc];B[od];W[nd];B[ne];W[me];B[mf];W[nf])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 36; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -113,12 +109,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -148,12 +142,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); hist.setKomi(-7); MiscNNInputParams nnInputParams; @@ -180,12 +172,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;GM[1]FF[4]CA[UTF-8]AP[CGoban:3]ST[2]RU[Chinese]SZ[19]KM[7]PW[White]PB[Black];B[qd];W[dd];B[pp];W[dp];B[cf];W[fc];B[nd];W[nq];B[cq];W[dq];B[cp];W[cn];B[co];W[do];B[bn];W[cm];B[bm];W[cl];B[qn];W[pq];B[qq];W[qr];B[oq])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 23; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); hist.setKomi(21); MiscNNInputParams nnInputParams; @@ -212,12 +202,10 @@ void Tests::runCanaryTests(NNEvaluator* nnEval, int symmetry, bool print) { string sgfStr = "(;FF[4]GM[1]CA[UTF-8]RU[Japanese]KM[6]SZ[16:11];B[md];W[nh];B[dh];W[cd];B[lh];W[li];B[ki])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); int turnIdx = 7; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); MiscNNInputParams nnInputParams; NNResultBuf buf; diff --git a/cpp/tests/testnninputs.cpp b/cpp/tests/testnninputs.cpp index 47685f3f7..118cec1e3 100644 --- a/cpp/tests/testnninputs.cpp +++ b/cpp/tests/testnninputs.cpp @@ -128,12 +128,11 @@ void Tests::runNNInputsV3V4Tests() { for(int version = minVersion; version <= maxVersion; version++) { cout << "VERSION " << version << endl; - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = Rules::getTrompTaylorish(); initialRules = sgf->getRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; igetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; for(size_t i = 0; i= 6; size--) { - Board board = Board(size,size); Player nextPla = P_BLACK; vector rules = { @@ -525,7 +520,9 @@ xxx..xx for(int c = 0; cgetRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; int nnXLen = 6; int nnYLen = 6; @@ -637,12 +633,11 @@ xxx..xx if(version == 5) continue; cout << "VERSION " << version << endl; - Board board; Player nextPla; - BoardHistory hist; Rules initialRules; initialRules = sgf->getRulesOrFailAllowUnspecified(initialRules); - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; vector& moves = sgf->moves; int nnXLen = 13; @@ -813,9 +808,9 @@ o.xoo.x float* rowGlobal; allocateRows(version,nnXLen,nnYLen,numFeaturesBin,numFeaturesGlobal,rowBin,rowGlobal); - Board board = Board(9,1); Player nextPla = P_BLACK; Rules initialRules = Rules::getSimpleTerritory(); + Board board = Board(9,1,initialRules); BoardHistory hist(board,nextPla,initialRules,0); auto run = [&](bool inputsUseNHWC) { @@ -895,9 +890,9 @@ o.xoo.x }; for(int i = 0; igetRulesOrFailAllowUnspecified(rulesToUse); - sgf->setupInitialBoardAndHist(rules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(rules, nextPla); + Board& board = hist.initialBoard; int nnXLen = 9; int nnYLen = 9; diff --git a/cpp/tests/testrules.cpp b/cpp/tests/testrules.cpp index 7b1fc49c8..58ae09228 100644 --- a/cpp/tests/testrules.cpp +++ b/cpp/tests/testrules.cpp @@ -3463,13 +3463,11 @@ HASH: 5C26A060FA78FD93FFF559C72BD7C6A4 std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; - BoardHistory hist; Player nextPla = P_BLACK; int turnIdxToSetup = (int)sgf->moves.size(); Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdxToSetup); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdxToSetup); string expected = R"%%( HASH: EB867913318513FD9DE98EDE86AE8CE0 A B C D E F G H J K L M @@ -5272,20 +5270,49 @@ Last moves pass pass pass pass H7 G9 F9 H7 Rand baseRand(name); constexpr int numBoards = 6; - Board boards[numBoards] = { - Board(2,2), - Board(5,1), - Board(6,1), - Board(2,3), - Board(4,2), - }; for(int i = 0; i sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(defaultRules); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx);; hist.setKomi(overrideKomi); runBotOnPosition(bot,board,nextPla,hist,opts); } diff --git a/cpp/tests/testsearchmisc.cpp b/cpp/tests/testsearchmisc.cpp index 584cf26f2..fbcd98106 100644 --- a/cpp/tests/testsearchmisc.cpp +++ b/cpp/tests/testsearchmisc.cpp @@ -128,11 +128,9 @@ void Tests::runNNOnManyPoses(const string& modelFile, bool inputsNHWC, bool useN vector policyProbs; for(int turnIdx = 0; turnIdxmoves.size(); turnIdx++) { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); nnEval->evaluate(board,hist,nextPla,nnInputParams,buf,skipCache,includeOwnerMap); winProbs.push_back(buf.result->whiteWinProb); @@ -204,9 +202,7 @@ void Tests::runNNBatchingTest(const string& modelFile, bool inputsNHWC, bool use Rand rand("runNNBatchingTest"); std::unique_ptr sgf = CompactSgf::parse(sgfStr); for(int turnIdx = 0; turnIdxmoves.size(); turnIdx++) { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules; initialRules.koRule = rand.nextBool(0.5) ? Rules::KO_SIMPLE : rand.nextBool(0.5) ? Rules::KO_POSITIONAL : Rules::KO_SITUATIONAL; initialRules.scoringRule = rand.nextBool(0.5) ? Rules::SCORING_AREA : Rules::SCORING_TERRITORY; @@ -215,7 +211,7 @@ void Tests::runNNBatchingTest(const string& modelFile, bool inputsNHWC, bool use initialRules.hasButton = initialRules.scoringRule == Rules::SCORING_AREA && rand.nextBool(0.5); initialRules.whiteHandicapBonusRule = rand.nextBool(0.5) ? Rules::WHB_ZERO : rand.nextBool(0.5) ? Rules::WHB_N : Rules::WHB_N_MINUS_ONE; initialRules.komi = 7.5f + rand.nextInt(-10,10) * 0.5f; - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, turnIdx); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, turnIdx); items.push_back(NNBatchingTestItem(board,hist,nextPla)); } }; diff --git a/cpp/tests/testsearchnonn.cpp b/cpp/tests/testsearchnonn.cpp index f5cf38ad4..1aee45db5 100644 --- a/cpp/tests/testsearchnonn.cpp +++ b/cpp/tests/testsearchnonn.cpp @@ -898,7 +898,7 @@ xx......x Rand rand("noiseVisualize"); auto run = [&](int xSize, int ySize) { - Board board(xSize,ySize); + const Board board(xSize,ySize,Rules::DEFAULT_GO); int nnXLen = 19; int nnYLen = 19; float sum = 0.0; @@ -2534,9 +2534,10 @@ x.x.x std::map,int> boardSizeDistribution; for(int i = 0; i<100000; i++) { - Board board; + Rules rules = Rules::DEFAULT_GO; + Board board(rules); Player pla; - BoardHistory hist; + BoardHistory hist(rules); ExtraBlackAndKomi extraBlackAndKomi; OtherGameProperties otherGameProps; gameInit.createGame(board,pla,hist,extraBlackAndKomi,NULL,PlaySettings(),otherGameProps,NULL); diff --git a/cpp/tests/testsearchv3.cpp b/cpp/tests/testsearchv3.cpp index 3e24cd788..5817e8aa6 100644 --- a/cpp/tests/testsearchv3.cpp +++ b/cpp/tests/testsearchv3.cpp @@ -24,11 +24,9 @@ static void runOwnershipAndMisc(NNEvaluator* nnEval, NNEvaluator* nnEval11, NNEv string sgfStr = "(;FF[4]CA[UTF-8]KM[7.5];B[pp];W[pc];B[cd];W[dq];B[ed];W[pe];B[co];W[cp];B[do];W[fq];B[ck];W[qn];B[qo];W[pn];B[np];W[qj];B[jc];W[lc];B[je];W[lq];B[mq];W[lp];B[ek];W[qq];B[pq];W[ro];B[rp];W[qp];B[po];W[rq];B[rn];W[sp];B[rm];W[ql];B[on];W[om];B[nn];W[nm];B[mn];W[ip];B[mm])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 40); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 40); MiscNNInputParams nnInputParams; NNResultBuf buf; @@ -68,11 +66,9 @@ static void runOwnershipAndMisc(NNEvaluator* nnEval, NNEvaluator* nnEval11, NNEv string sgfStr = "(;FF[4]CA[UTF-8]SZ[11]KM[7.5];B[ci];W[ic];B[ih];W[hi];B[ii];W[ij];B[jj];W[gj];B[ik];W[di];B[hh];W[ch];B[dc];W[cc];B[cb];W[cd];B[eb];W[dd];B[ed];W[ee];B[fd];W[bb];B[ba];W[ab];B[gb];W[je];B[ib];W[jb];B[jc];W[jd];B[hc];W[id];B[dh];W[cg];B[dj];W[ei];B[bi];W[ia];B[hb];W[fg];B[hj];W[eh];B[ej])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 43); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 43); MiscNNInputParams nnInputParams; NNResultBuf buf; diff --git a/cpp/tests/testsearchv8.cpp b/cpp/tests/testsearchv8.cpp index 431ba103c..3f51ce991 100644 --- a/cpp/tests/testsearchv8.cpp +++ b/cpp/tests/testsearchv8.cpp @@ -22,11 +22,9 @@ static void runV8TestsSize9(NNEvaluator* nnEval, NNEvaluator* nnEval9, NNEvaluat string sgfStr = "(;FF[4]GM[1]SZ[9]HA[0]KM[7]RU[stonescoring];B[ef];W[ed];B[ge])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 3); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 3); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; @@ -58,11 +56,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[19]KM[6.5];B[dd];W[qd];B[pq];W[dp];B[oc];W[pe];B[fq];W[jp];B[ph];W[cf];B[ck])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 11); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 11); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; @@ -92,11 +88,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[19]KM[6.5];B[dd];W[qd];B[od];W[pq];B[dq];W[do];B[eo];W[oe])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams params = SearchParams::forTestsV1(); params.rootNumSymmetriesToSample = 8; @@ -127,11 +121,9 @@ static void runV8TestsRandomSym(NNEvaluator* nnEval, NNEvaluator* nnEval19Exact, string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[AGA]SZ[19]KM[7.0];B[dd];W[pd];B[dp];W[pp];B[qc];W[qd];B[pc];W[nc];B[nb])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams paramsA = SearchParams::forTestsV1(); SearchParams paramsB = SearchParams::forTestsV1(); @@ -528,11 +520,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) std::unique_ptr sgf = CompactSgf::parse(sgfStr); { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 24); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 24); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -543,11 +533,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 32); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 32); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -558,11 +546,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 124); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 124); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -583,11 +569,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) std::unique_ptr sgf = CompactSgf::parse(sgfStr); { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 29); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 29); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -598,11 +582,9 @@ static void runV8Tests(NNEvaluator* nnEval, Logger& logger) delete bot; } { - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFailAllowUnspecified(Rules::getTrompTaylorish()); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 83); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 83); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 200; params.antiMirror = true; @@ -626,11 +608,9 @@ static void runMoreV8Tests(NNEvaluator* nnEval, Logger& logger) string sgfStr = "(;GM[1]FF[4]CA[UTF-8]RU[Japanese]SZ[9]KM[0];B[dc];W[ef];B[df];W[de];B[dg];W[eg];B[eh];W[fh];B[ee])"; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board board; Player nextPla; - BoardHistory hist; Rules initialRules = sgf->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(initialRules, board, nextPla, hist, 8); + auto [hist, board] = sgf->setupBoardAndHistAssumeLegal(initialRules, nextPla, 8); SearchParams params = SearchParams::forTestsV1(); params.maxVisits = 20; diff --git a/cpp/tests/testsearchv9.cpp b/cpp/tests/testsearchv9.cpp index 9b4d3d687..57bb7cc12 100644 --- a/cpp/tests/testsearchv9.cpp +++ b/cpp/tests/testsearchv9.cpp @@ -554,9 +554,9 @@ oo...ooo { cout << "Single symmetry and full symmetry raw nets" << endl; - Board board(19,19); Player nextPla = P_BLACK; Rules rules = Rules::parseRules("Japanese"); + Board board(19,19,rules); rules.komi = -4; BoardHistory hist(board,nextPla,rules,0); hist.makeBoardMoveAssumeLegal(board,Location::ofString("A1",board),P_BLACK,NULL); diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index 99e8fba40..2c3c4271f 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -21,14 +21,13 @@ void Tests::runSgfTests() { out << "xSize " << sgf->xSize << endl; out << "ySize " << sgf->ySize << endl; out << "depth " << sgf->depth << endl; - Rules rules = sgf->getRulesOrFailAllowUnspecified(Rules::getDefaultOrTrompTaylorish(sgf->isDots)); + const Rules rules = sgf->getRulesOrFailAllowUnspecified(Rules::getDefaultOrTrompTaylorish(sgf->isDots)); out << "komi " << rules.komi << endl; - Board board; - BoardHistory hist; Player pla; - sgf->setupInitialBoardAndHist(rules,board,pla,hist); + const BoardHistory hist = sgf->setupInitialBoardAndHist(rules, pla); + const Board& board = hist.initialBoard; if (rules.startPos == Rules::START_POS_EMPTY) { out << "placements" << endl; @@ -49,26 +48,23 @@ void Tests::runSgfTests() { out << "pla " << PlayerIO::playerToString(pla) << endl; hist.printDebugInfo(out,board); - sgf->setupBoardAndHistAssumeLegal(rules,board,pla,hist,sgf->moves.size()); + auto [finalHist, finalBoard] = sgf->setupBoardAndHistAssumeLegal(rules, pla, sgf->moves.size()); out << "Final board hist " << endl; out << "pla " << PlayerIO::playerToString(pla) << endl; - hist.printDebugInfo(out,board); + finalHist.printDebugInfo(out,finalBoard); { //Test SGF writing roundtrip. //This is not exactly holding if there is pass for ko, but should be good in all other cases ostringstream out2; - WriteSgf::writeSgf(out2,"foo","bar",hist,NULL,false,false); + WriteSgf::writeSgf(out2,"foo","bar",finalHist,NULL,false,false); std::unique_ptr sgf2 = CompactSgf::parse(out2.str()); - Board board2; - BoardHistory hist2; - Rules rules2; Player pla2; - rules2 = sgf2->getRulesOrFail(); - sgf->setupBoardAndHistAssumeLegal(rules2,board2,pla2,hist2,sgf2->moves.size()); + const Rules rules2 = sgf2->getRulesOrFail(); + auto [hist2, board2] = sgf->setupBoardAndHistAssumeLegal(rules2, pla2, sgf2->moves.size()); testAssert(rules2 == rules); - testAssert(board2.pos_hash == board.pos_hash); - testAssert(hist2.moveHistory.size() == hist.moveHistory.size()); + testAssert(board2.pos_hash == finalBoard.pos_hash); + testAssert(hist2.moveHistory.size() == finalHist.moveHistory.size()); } }; diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index 856a0546a..d73e5b7ba 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -97,7 +97,7 @@ void Tests::runTrainingWriteTests() { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(boardXLen,boardYLen); + Board initialBoard(boardXLen,boardYLen,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -226,7 +226,7 @@ void Tests::runSelfplayInitTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -407,7 +407,7 @@ void Tests::runMoreSelfplayTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; if(testHint) { @@ -587,7 +587,7 @@ void Tests::runMoreSelfplayTestsWithNN(const string& modelFile) { botSpec.nnEval = nnEval; botSpec.baseParams = params; - Board initialBoard(11,11); + Board initialBoard(11,11,rules); Player initialPla = P_BLACK; int initialEncorePhase = 0; BoardHistory initialHist(initialBoard,initialPla,rules,initialEncorePhase); @@ -1009,10 +1009,9 @@ xxxxxxxx. vector moves = sgf->moves; Rules initialRules = Rules::parseRules("chinese"); - Board board; Player nextPla; - BoardHistory hist; - sgf->setupInitialBoardAndHist(initialRules, board, nextPla, hist); + BoardHistory hist = sgf->setupInitialBoardAndHist(initialRules, nextPla); + Board& board = hist.initialBoard; for(size_t i = 0; i(); startPosSample.initialTurnNumber = 0; @@ -1177,7 +1176,7 @@ xxxxxxxx. TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); Sgf::PositionSample startPosSample; - startPosSample.board = Board(9,9); + startPosSample.board = Board(9,9,rules); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 40; @@ -2249,7 +2248,7 @@ void Tests::runSelfplayStatTestsWithNN(const string& modelFile) { playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2301,7 +2300,7 @@ void Tests::runSelfplayStatTestsWithNN(const string& modelFile) { playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2455,7 +2454,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2510,7 +2509,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2564,7 +2563,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2619,7 +2618,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector({ Move(Location::getLoc(3,3,19),P_BLACK), @@ -2677,7 +2676,7 @@ oox.x.... playSettings.forSelfPlay = true; Sgf::PositionSample startPosSample; - startPosSample.board = Board(19,19); + startPosSample.board = Board(19,19,Rules::DEFAULT_GO); startPosSample.nextPla = P_BLACK; startPosSample.moves = std::vector(); startPosSample.initialTurnNumber = 0; @@ -2764,16 +2763,14 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { botSpec.baseParams = params; std::unique_ptr sgf = CompactSgf::parse(sgfStr); - Board initialBoard; Player initialPla; - BoardHistory initialHist; ExtraBlackAndKomi extraBlackAndKomi; extraBlackAndKomi.extraBlack = 0; extraBlackAndKomi.komiMean = rules.komi; extraBlackAndKomi.komiStdev = 0; int turnIdx = (int)sgf->moves.size(); - sgf->setupBoardAndHistAssumeLegal(rules,initialBoard,initialPla,initialHist,turnIdx); + auto [initialHist, initialBoard] = sgf->setupBoardAndHistAssumeLegal(rules, initialPla, turnIdx); bool doEndGameIfAllPassAlive = true; bool clearBotAfterSearch = true; From a54fb85d6bf70751289ef2894c7b2120b2a1ff4b Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 7 Sep 2025 13:26:19 +0200 Subject: [PATCH 08/42] Refine grounding utils in `BoardHistory` Introduce `whiteScoreIfGroundingAlive`, `isGroundingWinsGame` Remove `getAreaNow` because it's unused --- cpp/CMakeLists.txt | 1 + cpp/command/runtests.cpp | 3 +- cpp/game/board.cpp | 2 +- cpp/game/board.h | 2 +- cpp/game/boardhistory.cpp | 54 ++----------- cpp/game/boardhistory.h | 8 +- cpp/game/dotsboardhistory.cpp | 48 ++++++++++++ cpp/game/dotsfield.cpp | 52 ++++++++++-- cpp/neuralnet/nninputs.h | 36 +++++---- cpp/neuralnet/nninputsdots.cpp | 31 ++++---- cpp/search/searchhelpers.cpp | 5 ++ cpp/tests/testdotsbasic.cpp | 69 ++++++++++++++++ cpp/tests/testdotsextra.cpp | 139 +++++++++++++++------------------ cpp/tests/testdotsutils.cpp | 4 - cpp/tests/testdotsutils.h | 4 +- cpp/tests/tests.h | 3 +- 16 files changed, 285 insertions(+), 176 deletions(-) create mode 100644 cpp/game/dotsboardhistory.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4e0d41c2f..e242b702d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -235,6 +235,7 @@ add_executable(katago game/dotsfield.cpp game/rules.cpp game/boardhistory.cpp + game/dotsboardhistory.cpp game/graphhash.cpp dataio/sgf.cpp dataio/numpywrite.cpp diff --git a/cpp/command/runtests.cpp b/cpp/command/runtests.cpp index 26d154cd4..a4c17eb2c 100644 --- a/cpp/command/runtests.cpp +++ b/cpp/command/runtests.cpp @@ -32,13 +32,14 @@ int MainCmds::runtests(const vector& args) { Tests::runDotsFieldTests(); Tests::runDotsGroundingTests(); + Tests::runDotsBoardHistoryGroundingTests(); Tests::runDotsPosHashTests(); Tests::runDotsStartPosTests(); Tests::runDotsStressTests(); Tests::runDotsSymmetryTests(); - Tests::runDotsTerritoryTests(); + Tests::runDotsOwnershipTests(); Tests::runDotsCapturingTests(); BSearch::runTests(); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 44e0b839c..72b921c64 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -1892,7 +1892,7 @@ void Board::calculateArea( bool isMultiStoneSuicideLegal ) const { if (rules.isDots) { - calculateGroundingWhiteScore(result); + calculateOwnershipAndWhiteScore(result, C_EMPTY); return; } diff --git a/cpp/game/board.h b/cpp/game/board.h index a2e029ed9..5b251a3a8 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -363,7 +363,7 @@ struct Board bool isMultiStoneSuicideLegal ) const; - int calculateGroundingWhiteScore(Color* result) const; + int calculateOwnershipAndWhiteScore(Color* result, Color groundingPlayer) const; // Calculates the area (including non pass alive stones, safe and unsafe big territories) //However, strips out any "seki" regions. diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 7a10a8889..c260abc85 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -73,6 +73,8 @@ BoardHistory::BoardHistory(const Rules& rules) BoardHistory::~BoardHistory() {} +BoardHistory::BoardHistory(const Board& board) : BoardHistory(board, P_BLACK, board.rules, 0) {} + BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int ePhase) :rules(r), moveHistory(), @@ -608,10 +610,6 @@ float BoardHistory::currentSelfKomi(Player pla, double drawEquivalentWinsForWhit } } -int BoardHistory::countGroundingScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) { - return board.calculateGroundingWhiteScore(area); -} - int BoardHistory::countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { assert(rules.isDots == board.isDots() && !rules.isDots); @@ -717,23 +715,12 @@ void BoardHistory::setFinalScoreAndWinner(float score) { winner = C_EMPTY; } -void BoardHistory::getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { - if(rules.isDots) - countGroundingScoreWhiteMinusBlack(board,area); - else if(rules.scoringRule == Rules::SCORING_AREA) - countAreaScoreWhiteMinusBlack(board,area); - else if(rules.scoringRule == Rules::SCORING_TERRITORY) - countTerritoryAreaScoreWhiteMinusBlack(board,area); - else - ASSERT_UNREACHABLE; -} - void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) { assert(rules.isDots == board.isDots()); int boardScore = 0; if(rules.isDots) - boardScore = countGroundingScoreWhiteMinusBlack(board,area); + boardScore = countDotsScoreWhiteMinusBlack(board,area); else if(rules.scoringRule == Rules::SCORING_AREA) boardScore = countAreaScoreWhiteMinusBlack(board,area); else if(rules.scoringRule == Rules::SCORING_TERRITORY) @@ -747,9 +734,7 @@ void BoardHistory::endAndScoreGameNow(const Board& board, Color area[Board::MAX_ whiteBonusScore += (presumedNextMovePla == P_WHITE ? 0.5f : -0.5f); } - if (!rules.isDots || !isGameFinished) { - setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); - } + setFinalScoreAndWinner(static_cast(boardScore) + whiteBonusScore + whiteHandicapBonusScore + rules.komi); isScored = true; isNoResult = false; isResignation = false; @@ -762,39 +747,12 @@ void BoardHistory::endAndScoreGameNow(const Board& board) { endAndScoreGameNow(board,area); } -bool BoardHistory::isGroundingWinsGame(const Board& board, const Player pla, float& whiteScore) const { - assert(rules.isDots); - - if (pla == P_BLACK && board.whiteScoreIfBlackGrounds + rules.komi < 0) { - whiteScore = board.whiteScoreIfBlackGrounds + rules.komi; - return true; - } - - if (pla == P_WHITE && board.blackScoreIfWhiteGrounds - rules.komi < 0) { - whiteScore = -board.blackScoreIfWhiteGrounds + rules.komi; - return true; - } - - return false; -} - void BoardHistory::endGameIfAllPassAlive(const Board& board) { assert(rules.isDots == board.isDots()); if (rules.isDots) { - bool gameOver = false; - float normalizedWhiteScoreIfGroundingAlive = 0.0f; - - if (board.numLegalMoves == 0) { - // No legal locs to place a dot -> game is over. - gameOver = true; - normalizedWhiteScoreIfGroundingAlive = static_cast(board.numBlackCaptures - board.numWhiteCaptures) + rules.komi; - } else { - gameOver = isGroundingWinsGame(board, presumedNextMovePla, normalizedWhiteScoreIfGroundingAlive); - } - - if (gameOver) { - setFinalScoreAndWinner(normalizedWhiteScoreIfGroundingAlive); + if (const float whiteScoreAfterGrounding = whiteScoreIfGroundingAlive(board); whiteScoreAfterGrounding != 0.0) { + setFinalScoreAndWinner(whiteScoreAfterGrounding); isScored = true; isNoResult = false; isResignation = false; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 2123e1410..9d30cb134 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -104,6 +104,7 @@ struct BoardHistory { BoardHistory(); explicit BoardHistory(const Rules& rules); ~BoardHistory(); + BoardHistory(const Board& board); BoardHistory(const Board& board, Player pla, const Rules& rules, int encorePhase); @@ -173,9 +174,10 @@ struct BoardHistory { void endGameIfAllPassAlive(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); - bool isGroundingWinsGame(const Board& board, Player pla, float& whiteScore) const; + bool doesGroundingWinGame(const Board& board, Player pla) const; + bool doesGroundingWinGame(const Board& board, Player pla, float& whiteScore) const; + float whiteScoreIfGroundingAlive(const Board& board) const; void endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]); - void getAreaNow(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; void setWinnerByResignation(Player pla); @@ -203,7 +205,7 @@ struct BoardHistory { private: bool koHashOccursInHistory(Hash128 koHash, const KoHashTable* rootKoHashTable) const; void setKoRecapBlocked(Loc loc, bool b); - static int countGroundingScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]); + int countDotsScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; int countAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; int countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const; void setFinalScoreAndWinner(float score); diff --git a/cpp/game/dotsboardhistory.cpp b/cpp/game/dotsboardhistory.cpp new file mode 100644 index 000000000..c6a77f364 --- /dev/null +++ b/cpp/game/dotsboardhistory.cpp @@ -0,0 +1,48 @@ +#include "../game/boardhistory.h" + +using namespace std; + +int BoardHistory::countDotsScoreWhiteMinusBlack(const Board& board, Color area[Board::MAX_ARR_SIZE]) const { + assert(rules.isDots); + + const float whiteScore = whiteScoreIfGroundingAlive(board); + Color groundingPlayer = C_EMPTY; + if (whiteScore > 0.0f) { + groundingPlayer = C_WHITE; + } else if (whiteScore < 0.0f) { + groundingPlayer = C_BLACK; + } + return board.calculateOwnershipAndWhiteScore(area, groundingPlayer); +} + +bool BoardHistory::doesGroundingWinGame(const Board& board, const Player pla) const { + float whiteScore; + return doesGroundingWinGame(board, pla, whiteScore); +} + +bool BoardHistory::doesGroundingWinGame(const Board& board, const Player pla, float& whiteScore) const { + assert(rules.isDots); + + whiteScore = whiteScoreIfGroundingAlive(board); + return pla == P_WHITE && whiteScore > 0.0f || pla == P_BLACK && whiteScore < 0.0f; +} + +float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { + assert(rules.isDots); + + if (const float fullWhiteScoreIfBlackGrounds = + static_cast(board.whiteScoreIfBlackGrounds) + whiteBonusScore + whiteHandicapBonusScore + rules.komi; + fullWhiteScoreIfBlackGrounds < 0.0f) { + // Black already won the game + return fullWhiteScoreIfBlackGrounds; + } + + if (const float fullBlackScoreIfWhiteGrounds = + static_cast(board.blackScoreIfWhiteGrounds) - whiteBonusScore - whiteHandicapBonusScore - rules.komi; + fullBlackScoreIfWhiteGrounds < 0.0f) { + // White already won the game + return -fullBlackScoreIfWhiteGrounds; + } + + return 0.0f; +} \ No newline at end of file diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index d844b5025..5055798ae 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -166,21 +166,57 @@ void Board::clearVisited(const vector& locations) const { } } -int Board::calculateGroundingWhiteScore(Color* result) const { +int Board::calculateOwnershipAndWhiteScore(Color* result, const Color groundingPlayer) const { + int whiteCaptures = 0; + int blackCaptures = 0; + for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { const Loc loc = Location::getLoc(x, y, x_size); - if (const State state = getState(loc); isGrounded(state)) { - const Color activeColor = getActiveColor(state); - assert(activeColor != C_EMPTY); - result[loc] = activeColor; - } else { - result[loc] = C_EMPTY; + const State state = getState(loc); + const Color activeColor = getActiveColor(state); + const Color placedDotColor = getPlacedDotColor(state); + Color ownershipColor = C_EMPTY; + if (activeColor != C_EMPTY) { + if (isGrounded(state) || groundingPlayer == C_EMPTY) { + if (placedDotColor != C_EMPTY && activeColor != placedDotColor) { + ownershipColor = activeColor; + if (placedDotColor == P_BLACK) { + blackCaptures++; + } else { + whiteCaptures++; + } + } + } else { + // If the game is finished by grounding by a player, + // Remove its ungrounded dots to get a more refined ownership and score. + if (groundingPlayer == C_WHITE && placedDotColor == P_WHITE) { + ownershipColor = P_BLACK; + whiteCaptures++; + } else if (groundingPlayer == C_BLACK && placedDotColor == P_BLACK) { + ownershipColor = P_WHITE; + blackCaptures++; + } + } } + result[loc] = ownershipColor; } } - return whiteScoreIfBlackGrounds - blackScoreIfWhiteGrounds; + if (groundingPlayer == C_WHITE) { + // White wins by grounding + assert(blackScoreIfWhiteGrounds == whiteCaptures - blackCaptures); + return -blackScoreIfWhiteGrounds; + } + + if (groundingPlayer == C_BLACK) { + // Black wins by grounding + assert(whiteScoreIfBlackGrounds == blackCaptures - whiteCaptures); + return whiteScoreIfBlackGrounds; + } + + assert(numBlackCaptures == blackCaptures && numWhiteCaptures == whiteCaptures); + return numBlackCaptures - numWhiteCaptures; } Board::MoveRecord::MoveRecord( diff --git a/cpp/neuralnet/nninputs.h b/cpp/neuralnet/nninputs.h index a92a1edf4..f7b21d59e 100644 --- a/cpp/neuralnet/nninputs.h +++ b/cpp/neuralnet/nninputs.h @@ -79,23 +79,25 @@ namespace NNInputs { constexpr int NUM_FEATURES_GLOBAL_V_DOTS = 22; constexpr int DOTS_FEATURE_SPATIAL_ON_BOARD = 0; // 0 - constexpr int DOTS_FEATURE_SPATIAL_PLAYER = 1; // 1 - constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP = 2; // 2 - constexpr int DOTS_FEATURE_SPATIAL_PLAYER_CAPTURES = 3; // (3,4,5) - constexpr int DOTS_FEATURE_SPATIAL_PLAYER_SURROUNDINGS = 4; // (3,4,5) - constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_CAPTURES = 5; // (3,4,5) - constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_SURROUNDINGS = 6; // (3,4,5) - constexpr int DOTS_FEATURE_SPATIAL_DEAD_DOTS = 7; // Needed for more correct score calculation - constexpr int DOTS_FEATURE_SPATIAL_GROUNDED = 8; // Analogue of territory (18,19) - constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_0 = 9; // 9 - constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_1 = 10; // 10 - constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_2 = 11; // 11 - constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_3 = 12; // 12 - constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_4 = 13; // 13 - constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED = 14; // 14, TODO: implement later - constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED_PREVIOUS = 15; // 15, TODO: implement later - constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED_PREVIOUS_2 = 16; // 16, TODO: implement later - constexpr int DOTS_FEATURE_SPATIAL_LADDER_WORKING_MOVES = 17; // 17, TODO: implement later + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_ACTIVE = 1; // 1 + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_ACTIVE = 2; // 2 + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_PLACED = 3; // 1 + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_PLACED = 4; // 2 + constexpr int DOTS_FEATURE_SPATIAL_DEAD_DOTS = 5; // Actually scoring + constexpr int DOTS_FEATURE_SPATIAL_GROUNDED = 6; // Analogue of territory (18,19) + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_CAPTURES = 7; // (3,4,5) + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_SURROUNDINGS = 8; // (3,4,5) + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_CAPTURES = 9; // (3,4,5) + constexpr int DOTS_FEATURE_SPATIAL_PLAYER_OPP_SURROUNDINGS = 10; // (3,4,5) + constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_0 = 10; // 9 + constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_1 = 11; // 10 + constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_2 = 12; // 11 + constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_3 = 13; // 12 + constexpr int DOTS_FEATURE_SPATIAL_LAST_MOVE_4 = 14; // 13 + constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED = 15; // 14, TODO: implement later + constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED_PREVIOUS = 16; // 15, TODO: implement later + constexpr int DOTS_FEATURE_SPATIAL_LADDER_CAPTURED_PREVIOUS_2 = 17; // 16, TODO: implement later + constexpr int DOTS_FEATURE_SPATIAL_LADDER_WORKING_MOVES = 18; // 17, TODO: implement later constexpr int DOTS_FEATURE_GLOBAL_KOMI = 0; // 5 constexpr int DOTS_FEATURE_GLOBAL_SUICIDE = 1; // 8 diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index 2fc7b2474..b99dc6859 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -34,9 +34,6 @@ void NNInputs::fillRowVDots( vector bases; board.calculateOneMoveCaptureAndBasePositionsForDots(hist.rules.multiStoneSuicideLegal, captures, bases); - Color grounding[Board::MAX_ARR_SIZE]; - board.calculateGroundingWhiteScore(grounding); - for(int y = 0; y& extraMoves) { - Board board = parseDotsFieldDefault(boardData, extraMoves); +string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { + const Board board = parseDotsFieldDefault(boardData, extraMoves); - Board copy(board); Color result[Board::MAX_ARR_SIZE]; - int whiteScore = copy.calculateGroundingWhiteScore(result); - testAssert(expectedGroundingWhiteScore == whiteScore); + const int whiteScore = board.calculateOwnershipAndWhiteScore(result, groundingPlayer); + testAssert(expectedWhiteScore == whiteScore); std::ostringstream oss; - for (int y = 0; y < copy.y_size; y++) { - for (int x = 0; x < copy.x_size; x++) { - Loc loc = Location::getLoc(x, y, copy.x_size); + for (int y = 0; y < board.y_size; y++) { + for (int x = 0; x < board.x_size; x++) { + const Loc loc = Location::getLoc(x, y, board.x_size); oss << PlayerIO::colorToChar(result[loc]); } oss << endl; } - testAssert(board.isEqualForTesting(copy, true, true)); return oss.str(); } -string getGroundedTerritory(const string& boardData, const int expectedGroundingWhiteScore) { - return getGroundedTerritory(boardData, expectedGroundingWhiteScore, vector()); +void expect( + const char* name, + const Color groundingPlayer, + const std::string& actualField, + const std::string& expectedOwnership, + const int expectedWhiteScore, + const vector& extraMoves = {} +) { + cout << " " << name << ", Grounding Player: " << PlayerIO::colorToChar(groundingPlayer) << endl; + expect(name, getOwnership(actualField, groundingPlayer, expectedWhiteScore, extraMoves), expectedOwnership); } -void Tests::runDotsTerritoryTests() { - expect("Cross", getGroundedTerritory(R"( +void Tests::runDotsOwnershipTests() { + expect("Start Cross", C_EMPTY, R"( ...... ...... ..ox.. ..xo.. ...... ...... -)", 0), +)", R"( ...... ...... @@ -148,92 +154,75 @@ void Tests::runDotsTerritoryTests() { ...... ...... ...... -)"); +)", 0); + + expect("Wins by a base", C_EMPTY, R"( +...... +...... +..ox.. +.oxo.. +...... +...... +)", +R"( +...... +...... +...... +..O... +...... +...... +)", 1, {XYMove(2, 4, P_WHITE)}); - expect("Grounded white", getGroundedTerritory(R"( + expect("Loss by grounding", C_BLACK, R"( ..o... ..o... ..ox.. ..xo.. ...o.. ...o.. -)", 2), - R"( -..O... -..O... -..O... -...O.. -...O.. +)", +R"( +...... +...... ...O.. -)"); +..O... +...... +...... +)", 2); - expect("Grounded black", getGroundedTerritory(R"( + expect("Loss by grounding", C_WHITE, R"( ...x.. ...x.. ..ox.. ..xo.. ..x... ..x... -)", -2), - R"( -...X.. -...X.. -...X.. -..X... -..X... -..X... -)"); - - expect("Grounded white and black", getGroundedTerritory(R"( -..ox.. -..ox.. -..ox.. -..xo.. -..xo.. -..xo.. -)", 0), -R"( -..OX.. -..OX.. -..OX.. -..XO.. -..XO.. -..XO.. -)"); - - expect("Ungrounded white base", getGroundedTerritory(R"( -...... -...... -..ox.. -.oxo.. -...... -...... -)", -2, {XYMove(2, 4, P_WHITE)}), +)", R"( ...... ...... +..X... +...X.. ...... ...... -...... -...... -)"); +)", -2); - expect("Grounded white base", getGroundedTerritory(R"( + expect("Wins by grounding with an ungrounded dot", C_WHITE, R"( ...... +.oox.. +.xxo.. +.oo... +....o. ...... -..ox.. -.oxo.. +)", +R"( ...... ...... -)", 3, {XYMove(2, 4, P_WHITE), XYMove(2, 5, P_WHITE)}), -R"( +.OO... ...... +....X. ...... -..O... -.OOO.. -..O... -..O... -)"); +)", 1, {XYMove(0, 2, P_WHITE)}); } std::pair getCapturingAndBases( @@ -243,7 +232,7 @@ std::pair getCapturingAndBases( ) { Board board = parseDotsFieldDefault(boardData, extraMoves); - Board copy(board); + const Board& copy(board); vector captures; vector bases; @@ -281,7 +270,7 @@ std::pair getCapturingAndBases( // Make sure we didn't change an internal state during calculating testAssert(board.isEqualForTesting(copy, true, true)); - return std::pair(capturesStringStream.str(), basesStringStream.str()); + return {capturesStringStream.str(), basesStringStream.str()}; } void checkCapturingAndBase( diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp index 2f5948a3a..9dada066d 100644 --- a/cpp/tests/testdotsutils.cpp +++ b/cpp/tests/testdotsutils.cpp @@ -2,10 +2,6 @@ using namespace std; -Board parseDotsFieldDefault(const string& input) { - return parseDotsField(input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, vector()); -} - Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { return parseDotsField(input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); } diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index ec247dcd2..108bf824a 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -60,9 +60,7 @@ struct BoardWithMoveRecords { } }; -Board parseDotsFieldDefault(const string& input); - -Board parseDotsFieldDefault(const string& input, const vector& extraMoves); +Board parseDotsFieldDefault(const string& input, const vector& extraMoves = {}); Board parseDotsField(const string& input, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); diff --git a/cpp/tests/tests.h b/cpp/tests/tests.h index 71bc03134..67099578d 100644 --- a/cpp/tests/tests.h +++ b/cpp/tests/tests.h @@ -20,11 +20,12 @@ namespace Tests { void runDotsFieldTests(); void runDotsGroundingTests(); + void runDotsBoardHistoryGroundingTests(); void runDotsPosHashTests(); void runDotsStartPosTests(); void runDotsStressTests(); - void runDotsTerritoryTests(); + void runDotsOwnershipTests(); void runDotsSymmetryTests(); void runDotsCapturingTests(); From 40ac07e3d34170a9b1b1ff359155d7dd7a68f196 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 7 Sep 2025 13:46:49 +0200 Subject: [PATCH 09/42] Extract Dots stress tests to a separated file --- cpp/CMakeLists.txt | 1 + cpp/dataio/sgf.cpp | 2 +- cpp/game/board.cpp | 14 +- cpp/game/board.h | 6 +- cpp/tests/testboardarea.cpp | 2 +- cpp/tests/testdotsbasic.cpp | 282 ------------------------------ cpp/tests/testdotsstress.cpp | 330 +++++++++++++++++++++++++++++++++++ 7 files changed, 338 insertions(+), 299 deletions(-) create mode 100644 cpp/tests/testdotsstress.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index e242b702d..bfc885219 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -287,6 +287,7 @@ add_executable(katago tests/testdotsutils.cpp tests/testdotsbasic.cpp tests/testdotsstartposes.cpp + tests/testdotsstress.cpp tests/testdotsextra.cpp tests/testbook.cpp tests/testcommon.cpp diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 715352a28..38aff7cf3 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1162,7 +1162,7 @@ Sgf::PositionSample Sgf::PositionSample::ofJsonLine(const string& s) { bool isDots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - sample.board = Board::parseBoard(xSize, ySize, data["board"].get(), '/', Rules(isDots)); + sample.board = Board::parseBoard(xSize, ySize, data["board"].get(), Rules(isDots), '/'); sample.nextPla = PlayerIO::parsePlayer(data["nextPla"].get()); vector moveLocs = data["moveLocs"].get>(); vector movePlas = data["movePlas"].get>(); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 72b921c64..5a397358b 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -2816,19 +2816,11 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { return s; } -Board Board::parseBoard(int xSize, int ySize, const string& s) { - return parseBoard(xSize, ySize, s, '\n', Rules()); -} - Board Board::parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter) { - return parseBoard(xSize, ySize, s, lineDelimiter, Rules()); -} - -Board Board::parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules) { - return parseBoard(xSize, ySize, s, '\n', rules); + return parseBoard(xSize, ySize, s, Rules::DEFAULT_GO, lineDelimiter); } -Board Board::parseBoard(int xSize, int ySize, const string& s, char lineDelimiter, const Rules& rules) { +Board Board::parseBoard(int xSize, int ySize, const string& s, const Rules& rules, char lineDelimiter) { Board board(xSize,ySize,rules); vector lines = Global::split(Global::trim(s),lineDelimiter); @@ -2908,7 +2900,7 @@ Board Board::ofJson(const nlohmann::json& data) { bool dots = data.value(DOTS_KEY, false); int xSize = data["xSize"].get(); int ySize = data["ySize"].get(); - Board board = parseBoard(xSize, ySize, data["stones"].get(), '|', Rules(dots)); + Board board = parseBoard(xSize, ySize, data["stones"].get(), Rules(dots), '|'); board.setSimpleKoLoc(Location::ofStringAllowNull(data.value("koLoc", "null"),board)); board.numBlackCaptures = data["numBlackCaptures"].get(); board.numWhiteCaptures = data["numWhiteCaptures"].get(); diff --git a/cpp/game/board.h b/cpp/game/board.h index 5b251a3a8..e1159f5a3 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -388,10 +388,8 @@ struct Board bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const; bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo, bool checkRules) const; - static Board parseBoard(int xSize, int ySize, const std::string& s); - static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter); - static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules); - static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter, const Rules& rules); + static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter = '\n'); + static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules, char lineDelimiter = '\n'); std::string toString() const; static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); static std::string toStringSimple(const Board& board, char lineDelimiter); diff --git a/cpp/tests/testboardarea.cpp b/cpp/tests/testboardarea.cpp index 7248b3f29..ad32835b0 100644 --- a/cpp/tests/testboardarea.cpp +++ b/cpp/tests/testboardarea.cpp @@ -19,7 +19,7 @@ void Tests::runBoardAreaTests() { bool safeBigTerritories = safeBigTerritoriesBuf[mode/2]; bool unsafeBigTerritories = unsafeBigTerritoriesBuf[mode/2]; bool nonPassAliveStones = nonPassAliveStonesBuf[mode/2]; - Board copy(board); + const Board& copy(board); copy.calculateArea(result,nonPassAliveStones,safeBigTerritories,unsafeBigTerritories,multiStoneSuicideLegal); out << "Safe big territories " << safeBigTerritories << " " << "Unsafe big territories " << unsafeBigTerritories << " " diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index 4f6ab901b..caa297b2b 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -1,5 +1,3 @@ -#include - #include "../tests/tests.h" #include "../tests/testdotsutils.h" @@ -7,7 +5,6 @@ #include "../program/playutils.h" using namespace std; -using namespace std::chrono; using namespace TestCommon; void checkDotsField(const string& description, const string& input, bool captureEmptyBases, bool freeCapturedDots, const std::function& check) { @@ -795,283 +792,4 @@ xxxxxxxxxx testAssert(dotsFieldWithErasedTerritory.pos_hash == dotsFieldWithSurrounding.pos_hash); } -} - -string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { - Board boardCopy(initialBoard); - BoardHistory boardHistory(boardCopy, P_BLACK, boardCopy.rules, 0); - for (const Board::MoveRecord& moveRecord : moveRecords) { - boardHistory.makeBoardMoveAssumeLegal(boardCopy, moveRecord.loc, moveRecord.pla, nullptr); - } - std::ostringstream sgfStringStream; - WriteSgf::writeSgf(sgfStringStream, "blue", "red", boardHistory, {}); - return sgfStringStream.str(); -} - -/** - * Calculates the grounding and result captures without using the grounding flag and incremental calculations. - * It's used for testing to verify incremental grounding algorithms. - */ -void validateGrounding( - const Board& boardBeforeGrounding, - const Board& boardAfterGrounding, - const Player pla, - const vector& moveRecords) { - unordered_set visited_locs; - assert(pla == P_BLACK || pla == P_WHITE); - - int expectedNumBlackCaptures = 0; - int expectedNumWhiteCaptures = 0; - const Player opp = getOpp(pla); - for (int y = 0; y < boardBeforeGrounding.y_size; y++) { - for (int x = 0; x < boardBeforeGrounding.x_size; x++) { - Loc loc = Location::getLoc(x, y, boardBeforeGrounding.x_size); - const State state = boardBeforeGrounding.getState(Location::getLoc(x, y, boardBeforeGrounding.x_size)); - - if (const Color activeColor = getActiveColor(state); activeColor == pla) { - if (visited_locs.count(loc) > 0) - continue; - - bool grounded = false; - - vector walkStack; - vector baseLocs; - walkStack.push_back(loc); - - // Find active territory and calculate its grounding state. - while (!walkStack.empty()) { - Loc curLoc = walkStack.back(); - walkStack.pop_back(); - - if (const Color curActiveColor = getActiveColor(boardBeforeGrounding.getState(curLoc)); curActiveColor == pla) { - if (visited_locs.count(curLoc) == 0) { - visited_locs.insert(curLoc); - baseLocs.push_back(curLoc); - boardBeforeGrounding.forEachAdjacent(curLoc, [&](const Loc& adjLoc) { - walkStack.push_back(adjLoc); - }); - } - } else if (curActiveColor == C_WALL) { - grounded = true; - } - } - - for (const Loc& baseLoc : baseLocs) { - const Color placedDotColor = getPlacedDotColor(boardBeforeGrounding.getState(baseLoc)); - - if (!grounded) { - // If the territory is not grounded, it becomes dead. - // Freed dots don't count because they don't add a value to the opp score (assume they just become be placed). - if (placedDotColor == pla) { - if (pla == P_BLACK) { - expectedNumBlackCaptures++; - } else { - expectedNumWhiteCaptures++; - } - } - } else { - State baseLocState = boardAfterGrounding.getState(baseLoc); - // This check on placed dot color is redundant. - // However, currently it's not possible to always ground empty locs in some rare cases due to limitations of incremental grounding algorithm. - // Fortunately, they don't affect the resulting score. - if (!isGrounded(baseLocState) && getPlacedDotColor(baseLocState) != C_EMPTY) { - Global::fatalError("Loc (" + to_string(Location::getX(baseLoc, boardBeforeGrounding.x_size)) + "; " + - to_string(Location::getY(baseLoc, boardBeforeGrounding.x_size)) + ") " + - " should be grounded. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); - } - - // If the territory is grounded, count dead dots of the opp player. - if (placedDotColor == opp) { - if (pla == P_BLACK) { - expectedNumWhiteCaptures++; - } else { - expectedNumBlackCaptures++; - } - } - } - } - } else if (activeColor == opp) { // In the case of opp active color, counts only captured dots - if (getPlacedDotColor(state) == pla) { - if (pla == P_BLACK) { - expectedNumBlackCaptures++; - } else { - expectedNumWhiteCaptures++; - } - } - } - } - } - - if (expectedNumBlackCaptures != boardAfterGrounding.numBlackCaptures || expectedNumWhiteCaptures != boardAfterGrounding.numWhiteCaptures) { - Global::fatalError("expectedNumBlackCaptures (" + to_string(expectedNumBlackCaptures) + ")" + - " == board.numBlackCaptures (" + to_string(boardAfterGrounding.numBlackCaptures) + ")" + - " && expectedNumWhiteCaptures (" + to_string(expectedNumWhiteCaptures) + ")" + - " == board.numWhiteCaptures (" + to_string(boardAfterGrounding.numWhiteCaptures) + ")" + - " check is failed. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); - } -} - -void runDotsStressTestsInternal( - int x_size, - int y_size, - int gamesCount, - bool dotsGame, - int startPos, - bool dotsCaptureEmptyBase, - float komi, - bool suicideAllowed, - float groundingStartCoef, - float groundingEndCoef, - bool performExtraChecks - ) { - assert(groundingStartCoef >= 0 && groundingStartCoef <= 1); - assert(groundingEndCoef >= 0 && groundingEndCoef <= 1); - assert(groundingEndCoef >= groundingStartCoef); - - cout << " Random games" << endl; - cout << " Game type: " << (dotsGame ? "Dots" : "Go") << endl; - cout << " Start position: " << Rules::writeStartPosRule(startPos) << endl; - if (dotsGame) { - cout << " Capture empty bases: " << boolalpha << dotsCaptureEmptyBase << endl; - } - cout << " Extra checks: " << boolalpha << performExtraChecks << endl; -#ifdef NDEBUG - cout << " Build: Release" << endl; -#else - cout << " Build: Debug" << endl; -#endif - cout << " Size: " << x_size << ":" << y_size << endl; - cout << " Komi: " << komi << endl; - cout << " Suicide: " << boolalpha << suicideAllowed << endl; - cout << " Games count: " << gamesCount << endl; - - const auto start = high_resolution_clock::now(); - - Rand rand("runDotsStressTests"); - - Rules rules = dotsGame ? Rules(dotsGame, startPos, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); - auto initialBoard = Board(x_size, y_size, rules); - - vector randomMoves = vector(); - randomMoves.reserve(initialBoard.numLegalMoves); - - for(int y = 0; y < initialBoard.y_size; y++) { - for(int x = 0; x < initialBoard.x_size; x++) { - Loc loc = Location::getLoc(x, y, initialBoard.x_size); - if (initialBoard.getColor(loc) == C_EMPTY) { // Filter out initial poses - randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); - } - } - } - - assert(randomMoves.size() == initialBoard.numLegalMoves); - - int movesCount = 0; - int blackWinsCount = 0; - int whiteWinsCount = 0; - int drawsCount = 0; - int groundingCount = 0; - - auto moveRecords = vector(); - - for (int n = 0; n < gamesCount; n++) { - rand.shuffle(randomMoves); - moveRecords.clear(); - - auto board = Board(initialBoard.x_size, initialBoard.y_size, rules); - - Loc lastLoc = Board::NULL_LOC; - - int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * initialBoard.numLegalMoves; - Player pla = P_BLACK; - for(size_t index = 0; index < randomMoves.size(); index++) { - lastLoc = moveRecords.size() >= tryGroundingAfterMove ? Board::PASS_LOC : randomMoves[index]; - - if (board.isLegal(lastLoc, pla, suicideAllowed, false)) { - Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); - movesCount++; - moveRecords.push_back(moveRecord); - pla = getOpp(pla); - } - - if (lastLoc == Board::PASS_LOC) { - groundingCount++; - int scoreDiff; - int oppScoreIfGrounding; - Player lastPla = moveRecords.back().pla; - if (lastPla == P_BLACK) { - scoreDiff = board.numBlackCaptures - board.numWhiteCaptures; - oppScoreIfGrounding = board.whiteScoreIfBlackGrounds; - } else { - scoreDiff = board.numWhiteCaptures - board.numBlackCaptures; - oppScoreIfGrounding = board.blackScoreIfWhiteGrounds; - } - if (scoreDiff != oppScoreIfGrounding) { - Global::fatalError("scoreDiff (" + to_string(scoreDiff) + ") == oppScoreIfGrounding (" + to_string(oppScoreIfGrounding) + ") check is failed. " + - "Sgf: " + moveRecordsToSgf(initialBoard, moveRecords)); - } - if (performExtraChecks) { - Board boardBeforeGrounding(board); - boardBeforeGrounding.undo(moveRecords.back()); - validateGrounding(boardBeforeGrounding, board, lastPla, moveRecords); - } - break; - } - } - - if (dotsGame && suicideAllowed && lastLoc != Board::PASS_LOC) { - testAssert(0 == board.numLegalMoves); - } - - if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { - whiteWinsCount++; - } else if (whiteScore < 0) { - blackWinsCount++; - } else { - drawsCount++; - } - - if (performExtraChecks) { - while (!moveRecords.empty()) { - board.undo(moveRecords.back()); - moveRecords.pop_back(); - } - - testAssert(initialBoard.isEqualForTesting(board, true, false)); - } - } - - const auto end = high_resolution_clock::now(); - auto durationNs = duration_cast(end - start); - - cout.precision(4); - cout << " Elapsed time: " << duration_cast(durationNs).count() << " ms" << endl; - cout << " Number of games per second: " << static_cast(static_cast(gamesCount) / durationNs.count() * 1000000000) << endl; - cout << " Number of moves per second: " << static_cast(static_cast(movesCount) / durationNs.count() * 1000000000) << endl; - cout << " Number of moves per game: " << static_cast(static_cast(movesCount) / gamesCount) << endl; - cout << " Time per game: " << static_cast(durationNs.count()) / gamesCount / 1000000 << " ms" << endl; - cout << " Black wins: " << blackWinsCount << " (" << static_cast(blackWinsCount) / gamesCount << ")" << endl; - cout << " White wins: " << whiteWinsCount << " (" << static_cast(whiteWinsCount) / gamesCount << ")" << endl; - cout << " Draws: " << drawsCount << " (" << static_cast(drawsCount) / gamesCount << ")" << endl; - cout << " Groundings: " << groundingCount << " (" << static_cast(groundingCount) / gamesCount << ")" << endl; -} - -void Tests::runDotsStressTests() { - cout << "Running dots stress tests" << endl; - - cout << " Max territory" << endl; - Board board = Board(39, 32, Rules::DEFAULT_DOTS); - for(int y = 0; y < board.y_size; y++) { - for(int x = 0; x < board.x_size; x++) { - const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; - board.playMoveAssumeLegal(Location::getLoc(x, y, board.x_size), pla); - } - } - testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); - testAssert(0 == board.numLegalMoves); - - runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, true); - runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, 0.5f, false, 0.8f, 1.0f, true); - - runDotsStressTestsInternal(39, 32, 100000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, false); } \ No newline at end of file diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp new file mode 100644 index 000000000..6c296334e --- /dev/null +++ b/cpp/tests/testdotsstress.cpp @@ -0,0 +1,330 @@ +#include + +#include "../tests/tests.h" +#include "../tests/testdotsutils.h" + +using namespace std; +using namespace std::chrono; +using namespace TestCommon; + +string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { + Board boardCopy(initialBoard); + BoardHistory boardHistory(boardCopy, P_BLACK, boardCopy.rules, 0); + for (const Board::MoveRecord& moveRecord : moveRecords) { + boardHistory.makeBoardMoveAssumeLegal(boardCopy, moveRecord.loc, moveRecord.pla, nullptr); + } + std::ostringstream sgfStringStream; + WriteSgf::writeSgf(sgfStringStream, "blue", "red", boardHistory, {}); + return sgfStringStream.str(); +} + +/** + * Calculates the grounding and result captures without using the grounding flag and incremental calculations. + * It's used for testing to verify incremental grounding algorithms. + */ +void validateGrounding( + const Board& boardBeforeGrounding, + const Board& boardAfterGrounding, + const Player pla, + const vector& moveRecords) { + unordered_set visited_locs; + assert(pla == P_BLACK || pla == P_WHITE); + + int expectedNumBlackCaptures = 0; + int expectedNumWhiteCaptures = 0; + const Player opp = getOpp(pla); + for (int y = 0; y < boardBeforeGrounding.y_size; y++) { + for (int x = 0; x < boardBeforeGrounding.x_size; x++) { + Loc loc = Location::getLoc(x, y, boardBeforeGrounding.x_size); + const State state = boardBeforeGrounding.getState(Location::getLoc(x, y, boardBeforeGrounding.x_size)); + + if (const Color activeColor = getActiveColor(state); activeColor == pla) { + if (visited_locs.count(loc) > 0) + continue; + + bool grounded = false; + + vector walkStack; + vector baseLocs; + walkStack.push_back(loc); + + // Find active territory and calculate its grounding state. + while (!walkStack.empty()) { + Loc curLoc = walkStack.back(); + walkStack.pop_back(); + + if (const Color curActiveColor = getActiveColor(boardBeforeGrounding.getState(curLoc)); curActiveColor == pla) { + if (visited_locs.count(curLoc) == 0) { + visited_locs.insert(curLoc); + baseLocs.push_back(curLoc); + boardBeforeGrounding.forEachAdjacent(curLoc, [&](const Loc& adjLoc) { + walkStack.push_back(adjLoc); + }); + } + } else if (curActiveColor == C_WALL) { + grounded = true; + } + } + + for (const Loc& baseLoc : baseLocs) { + const Color placedDotColor = getPlacedDotColor(boardBeforeGrounding.getState(baseLoc)); + + if (!grounded) { + // If the territory is not grounded, it becomes dead. + // Freed dots don't count because they don't add a value to the opp score (assume they just become be placed). + if (placedDotColor == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } else { + State baseLocState = boardAfterGrounding.getState(baseLoc); + // This check on placed dot color is redundant. + // However, currently it's not possible to always ground empty locs in some rare cases due to limitations of incremental grounding algorithm. + // Fortunately, they don't affect the resulting score. + if (!isGrounded(baseLocState) && getPlacedDotColor(baseLocState) != C_EMPTY) { + Global::fatalError("Loc (" + to_string(Location::getX(baseLoc, boardBeforeGrounding.x_size)) + "; " + + to_string(Location::getY(baseLoc, boardBeforeGrounding.x_size)) + ") " + + " should be grounded. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } + + // If the territory is grounded, count dead dots of the opp player. + if (placedDotColor == opp) { + if (pla == P_BLACK) { + expectedNumWhiteCaptures++; + } else { + expectedNumBlackCaptures++; + } + } + } + } + } else if (activeColor == opp) { // In the case of opp active color, counts only captured dots + if (getPlacedDotColor(state) == pla) { + if (pla == P_BLACK) { + expectedNumBlackCaptures++; + } else { + expectedNumWhiteCaptures++; + } + } + } + } + } + + if (expectedNumBlackCaptures != boardAfterGrounding.numBlackCaptures || expectedNumWhiteCaptures != boardAfterGrounding.numWhiteCaptures) { + Global::fatalError("expectedNumBlackCaptures (" + to_string(expectedNumBlackCaptures) + ")" + + " == board.numBlackCaptures (" + to_string(boardAfterGrounding.numBlackCaptures) + ")" + + " && expectedNumWhiteCaptures (" + to_string(expectedNumWhiteCaptures) + ")" + + " == board.numWhiteCaptures (" + to_string(boardAfterGrounding.numWhiteCaptures) + ")" + + " check is failed. Sgf: " + moveRecordsToSgf(boardBeforeGrounding, moveRecords)); + } +} + +void validateStatesAndCaptures(const Board& board, const vector& moveRecords) { + int expectedNumBlackCaptures = 0; + int expectedNumWhiteCaptures = 0; + int expectedPlacedDotsCount = -board.rules.getNumOfStartPosStones(); + + for (int y = 0; y < board.y_size; y++) { + for (int x = 0; x < board.x_size; x++) { + const State state = board.getState(Location::getLoc(x, y, board.x_size)); + const Color activeColor = getActiveColor(state); + const Color placedDotColor = getPlacedDotColor(state); + const Color emptyTerritoryColor = getEmptyTerritoryColor(state); + + if (placedDotColor != C_EMPTY) { + expectedPlacedDotsCount++; + } + + if (activeColor == C_BLACK) { + assert(C_EMPTY == emptyTerritoryColor); + if (placedDotColor == C_WHITE) { + expectedNumWhiteCaptures++; + } + } else if (activeColor == C_WHITE) { + assert(C_EMPTY == emptyTerritoryColor); + if (placedDotColor == C_BLACK) { + expectedNumBlackCaptures++; + } + } else { + assert(placedDotColor == C_EMPTY); + //assert(!isTerritory(state)); + } + } + } + + const int actualPlacedDotsCount = moveRecords.size() - (moveRecords.back().loc == Board::PASS_LOC ? 1 : 0); + assert(expectedPlacedDotsCount == actualPlacedDotsCount); + assert(expectedNumBlackCaptures == board.numBlackCaptures); + assert(expectedNumWhiteCaptures == board.numWhiteCaptures); +} + +void runDotsStressTestsInternal( + int x_size, + int y_size, + int gamesCount, + bool dotsGame, + int startPos, + bool dotsCaptureEmptyBase, + float komi, + bool suicideAllowed, + float groundingStartCoef, + float groundingEndCoef, + bool performExtraChecks + ) { + assert(groundingStartCoef >= 0 && groundingStartCoef <= 1); + assert(groundingEndCoef >= 0 && groundingEndCoef <= 1); + assert(groundingEndCoef >= groundingStartCoef); + + cout << " Random games" << endl; + cout << " Game type: " << (dotsGame ? "Dots" : "Go") << endl; + cout << " Start position: " << Rules::writeStartPosRule(startPos) << endl; + if (dotsGame) { + cout << " Capture empty bases: " << boolalpha << dotsCaptureEmptyBase << endl; + } + cout << " Extra checks: " << boolalpha << performExtraChecks << endl; +#ifdef NDEBUG + cout << " Build: Release" << endl; +#else + cout << " Build: Debug" << endl; +#endif + cout << " Size: " << x_size << ":" << y_size << endl; + cout << " Komi: " << komi << endl; + cout << " Suicide: " << boolalpha << suicideAllowed << endl; + cout << " Games count: " << gamesCount << endl; + + const auto start = high_resolution_clock::now(); + + Rand rand("runDotsStressTests"); + + Rules rules = dotsGame ? Rules(dotsGame, startPos, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); + auto initialBoard = Board(x_size, y_size, rules); + + vector randomMoves = vector(); + randomMoves.reserve(initialBoard.numLegalMoves); + + for(int y = 0; y < initialBoard.y_size; y++) { + for(int x = 0; x < initialBoard.x_size; x++) { + Loc loc = Location::getLoc(x, y, initialBoard.x_size); + if (initialBoard.getColor(loc) == C_EMPTY) { // Filter out initial poses + randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); + } + } + } + + assert(randomMoves.size() == initialBoard.numLegalMoves); + + int movesCount = 0; + int blackWinsCount = 0; + int whiteWinsCount = 0; + int drawsCount = 0; + int groundingCount = 0; + + auto moveRecords = vector(); + + for (int n = 0; n < gamesCount; n++) { + rand.shuffle(randomMoves); + moveRecords.clear(); + + auto board = Board(initialBoard.x_size, initialBoard.y_size, rules); + + Loc lastLoc = Board::NULL_LOC; + + int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * initialBoard.numLegalMoves; + Player pla = P_BLACK; + for(short randomMove : randomMoves) { + lastLoc = moveRecords.size() >= tryGroundingAfterMove ? Board::PASS_LOC : randomMove; + + if (board.isLegal(lastLoc, pla, suicideAllowed, false)) { + Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); + movesCount++; + moveRecords.push_back(moveRecord); + pla = getOpp(pla); + } + + if (lastLoc == Board::PASS_LOC) { + groundingCount++; + int scoreDiff; + int oppScoreIfGrounding; + Player lastPla = moveRecords.back().pla; + if (lastPla == P_BLACK) { + scoreDiff = board.numBlackCaptures - board.numWhiteCaptures; + oppScoreIfGrounding = board.whiteScoreIfBlackGrounds; + } else { + scoreDiff = board.numWhiteCaptures - board.numBlackCaptures; + oppScoreIfGrounding = board.blackScoreIfWhiteGrounds; + } + if (scoreDiff != oppScoreIfGrounding) { + Global::fatalError("scoreDiff (" + to_string(scoreDiff) + ") == oppScoreIfGrounding (" + to_string(oppScoreIfGrounding) + ") check is failed. " + + "Sgf: " + moveRecordsToSgf(initialBoard, moveRecords)); + } + break; + } + } + + if (performExtraChecks) { + if (lastLoc == Board::PASS_LOC) { + Board boardBeforeGrounding(board); + boardBeforeGrounding.undo(moveRecords.back()); + validateGrounding(boardBeforeGrounding, board, moveRecords.back().pla, moveRecords); + } + validateStatesAndCaptures(board, moveRecords); + } + + if (dotsGame && suicideAllowed && lastLoc != Board::PASS_LOC) { + testAssert(0 == board.numLegalMoves); + } + + if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { + whiteWinsCount++; + } else if (whiteScore < 0) { + blackWinsCount++; + } else { + drawsCount++; + } + + if (performExtraChecks) { + while (!moveRecords.empty()) { + board.undo(moveRecords.back()); + moveRecords.pop_back(); + } + + testAssert(initialBoard.isEqualForTesting(board, true, false)); + } + } + + const auto end = high_resolution_clock::now(); + auto durationNs = duration_cast(end - start); + + cout.precision(4); + cout << " Elapsed time: " << duration_cast(durationNs).count() << " ms" << endl; + cout << " Number of games per second: " << static_cast(static_cast(gamesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per second: " << static_cast(static_cast(movesCount) / durationNs.count() * 1000000000) << endl; + cout << " Number of moves per game: " << static_cast(static_cast(movesCount) / gamesCount) << endl; + cout << " Time per game: " << static_cast(durationNs.count()) / gamesCount / 1000000 << " ms" << endl; + cout << " Black wins: " << blackWinsCount << " (" << static_cast(blackWinsCount) / gamesCount << ")" << endl; + cout << " White wins: " << whiteWinsCount << " (" << static_cast(whiteWinsCount) / gamesCount << ")" << endl; + cout << " Draws: " << drawsCount << " (" << static_cast(drawsCount) / gamesCount << ")" << endl; + cout << " Groundings: " << groundingCount << " (" << static_cast(groundingCount) / gamesCount << ")" << endl; +} + +void Tests::runDotsStressTests() { + cout << "Running dots stress tests" << endl; + + cout << " Max territory" << endl; + Board board = Board(39, 32, Rules::DEFAULT_DOTS); + for(int y = 0; y < board.y_size; y++) { + for(int x = 0; x < board.x_size; x++) { + const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; + board.playMoveAssumeLegal(Location::getLoc(x, y, board.x_size), pla); + } + } + testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); + testAssert(0 == board.numLegalMoves); + + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, 0.5f, false, 0.8f, 1.0f, true); + + runDotsStressTestsInternal(39, 32, 100000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, false); +} \ No newline at end of file From 5e2a4e73f464cb8c9f8dc4d8b66fba6f7ea9717a Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 7 Sep 2025 15:01:12 +0200 Subject: [PATCH 10/42] Don't consider PASS (Grounding) move in case it doesn't win the game --- cpp/neuralnet/nneval.cpp | 13 ++++++++++++- cpp/program/play.cpp | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index 341ab8d3f..da472a224 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -858,7 +858,17 @@ void NNEvaluator::evaluate( assert(nextPlayer == history.presumedNextMovePla); for(int i = 0; i= 13 && ySize >= 13) { @@ -872,6 +882,7 @@ void NNEvaluator::evaluate( } } + legalCount = 0; for(int i = 0; i Date: Sun, 7 Sep 2025 21:36:25 +0200 Subject: [PATCH 11/42] Split `playMoveAssumeLegalDots` that is used much more often than `playMoveRecordedDots` It works much faster than `playMoveRecordedDots` although it leads to some code duplication Introduce some other minor Dots field optimizations --- cpp/game/board.cpp | 2 +- cpp/game/board.h | 29 ++++-- cpp/game/dotsfield.cpp | 165 ++++++++++++++++++++++++----------- cpp/tests/testdotsstress.cpp | 17 ++-- 4 files changed, 145 insertions(+), 68 deletions(-) diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 5a397358b..0fc43b285 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -826,7 +826,7 @@ bool Board::playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal) //Plays the specified move, assuming it is legal, and returns a MoveRecord for the move Board::MoveRecord Board::playMoveRecorded(const Loc loc, const Player pla) { if (rules.isDots) { - return playMoveAssumeLegalDots(loc, pla); + return playMoveRecordedDots(loc, pla); } uint8_t capDirs = 0; diff --git a/cpp/game/board.h b/cpp/game/board.h index e1159f5a3..d16c7aa53 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -429,8 +429,6 @@ struct Board private: // Dots game data - mutable std::array unconnectedLocationsBuffer = std::array(); - mutable int unconnectedLocationsBufferSize = 0; mutable std::vector closureOrInvalidateLocsBuffer = std::vector(); mutable std::vector territoryLocationsBuffer = std::vector(); mutable std::vector walkStack = std::vector(); @@ -439,15 +437,28 @@ struct Board // Dots game functions [[nodiscard]] bool wouldBeCaptureDots(Loc loc, Player pla) const; [[nodiscard]] bool isSuicideDots(Loc loc, Player pla) const; - MoveRecord playMoveAssumeLegalDots(Loc loc, Player pla); - MoveRecord tryPlayMove(Loc loc, Player pla, bool isSuicideLegal); + void playMoveAssumeLegalDots(Loc loc, Player pla); + MoveRecord playMoveRecordedDots(Loc loc, Player pla); + MoveRecord tryPlayMoveRecordedDots(Loc loc, Player pla, bool isSuicideLegal); void undoDots(MoveRecord& moveRecord); std::vector fillGrounding(Loc loc); - Base captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp, bool& isGrounded); - std::vector tryCapture(Loc loc, Player pla, bool emptyBaseCapturing, bool& atLeastOneRealBaseIsGrounded); - std::vector ground(Player pla, std::vector& emptyBaseInvalidatePositions); - void getUnconnectedLocations(Loc loc, Player pla) const; - void checkAndAddUnconnectedLocation(Player checkPla,Player currentPla,Loc addLoc1,Loc addLoc2) const; + void captureWhenEmptyTerritoryBecomesRealBase(Loc initLoc, Player opp, std::vector& bases, bool& isGrounded); + void tryCapture( + Loc loc, + Player pla, + const std::array& unconnectedLocations, + int unconnectedLocationsSize, + bool& atLeastOneRealBaseIsGrounded, + std::vector& bases); + void ground(Player pla, std::vector& emptyBaseInvalidatePositions, std::vector& bases); + std::array getUnconnectedLocations(Loc loc, Player pla, int& size) const; + void checkAndAddUnconnectedLocation( + std::array& unconnectedLocationsBuffer, + int& size, + Player checkPla, + Player currentPla, + Loc addLoc1, + Loc addLoc2) const; void tryGetCounterClockwiseClosure(Loc initialLoc, Loc startLoc, Player pla) const; Base buildBase(const std::vector& closure, Player pla); void getTerritoryLocations(Player pla, Loc firstLoc, bool grounding, bool& createRealBase) const; diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 5055798ae..27b4d6cfe 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -249,7 +249,7 @@ bool Board::isSuicideDots(const Loc loc, const Player pla) const { bool Board::wouldBeCaptureDots(const Loc loc, const Player pla) const { // TODO: optimize and get rid of `const_cast` - auto moveRecord = const_cast(this)->tryPlayMove(loc, pla, false); + auto moveRecord = const_cast(this)->tryPlayMoveRecordedDots(loc, pla, false); bool result = false; @@ -267,13 +267,66 @@ bool Board::wouldBeCaptureDots(const Loc loc, const Player pla) const { return result; } -Board::MoveRecord Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { - MoveRecord result = tryPlayMove(loc, pla, true); +Board::MoveRecord Board::playMoveRecordedDots(const Loc loc, const Player pla) { + const MoveRecord& result = tryPlayMoveRecordedDots(loc, pla, true); assert(result.pla == pla); - return std::move(result); + return result; } -Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLegal) { +void Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { + const State originalState = getState(loc); + + if (loc == PASS_LOC) { + auto initEmptyBaseInvalidateLocations = vector(); + auto bases = vector(); + ground(pla, initEmptyBaseInvalidateLocations, bases); + } else { + colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); + const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; + pos_hash ^= hashValue; + numLegalMoves--; + + bool atLeastOneRealBaseIsGrounded = false; + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, pla, unconnectedLocationsSize); + bool capturing = false; + if (unconnectedLocationsSize >= 2) { + auto bases = vector(); + tryCapture(loc, pla, unconnectedLocations, unconnectedLocationsSize, atLeastOneRealBaseIsGrounded, bases); + capturing = !bases.empty(); + } + + const Color opp = getOpp(pla); + if (!capturing) { + if (getEmptyTerritoryColor(originalState) == opp) { + vector oppBases; + captureWhenEmptyTerritoryBecomesRealBase(loc, opp, oppBases, atLeastOneRealBaseIsGrounded); + } + } else if (isWithinEmptyTerritory(originalState, opp)) { + invalidateAdjacentEmptyTerritoryIfNeeded(loc); + } + + if (pla == P_BLACK) { + whiteScoreIfBlackGrounds++; + } else if (pla == P_WHITE) { + blackScoreIfWhiteGrounds++; + } + + if (atLeastOneRealBaseIsGrounded) { + fillGrounding(loc); + } else if( + const Player locActivePlayer = getColor(loc); // Can't use pla because of a possible suicidal move + isGroundedOrWall(getState(Location::xm1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xym1(loc, x_size)), locActivePlayer) || + isGroundedOrWall(getState(Location::xp1y(loc)), locActivePlayer) || + isGroundedOrWall(getState(Location::xyp1(loc, x_size)), locActivePlayer) + ) { + fillGrounding(loc); + } + } +} + +Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool isSuicideLegal) { State originalState = getState(loc); vector bases; @@ -281,8 +334,7 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe vector newGroundingLocations; if (loc == PASS_LOC) { - initEmptyBaseInvalidateLocations = vector(); - bases = ground(pla, initEmptyBaseInvalidateLocations); + ground(pla, initEmptyBaseInvalidateLocations, bases); } else { colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; @@ -290,13 +342,17 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe numLegalMoves--; bool atLeastOneRealBaseIsGrounded = false; - bases = tryCapture(loc, pla, false, atLeastOneRealBaseIsGrounded); + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, pla, unconnectedLocationsSize); + if (unconnectedLocationsSize >= 2) { + tryCapture(loc, pla, unconnectedLocations, unconnectedLocationsSize, atLeastOneRealBaseIsGrounded, bases); + } const Color opp = getOpp(pla); if (bases.empty()) { if (getEmptyTerritoryColor(originalState) == opp) { if (isSuicideLegal) { - bases.push_back(captureWhenEmptyTerritoryBecomesRealBase(loc, opp, atLeastOneRealBaseIsGrounded)); + captureWhenEmptyTerritoryBecomesRealBase(loc, opp, bases, atLeastOneRealBaseIsGrounded); } else { colors[loc] = originalState; pos_hash ^= hashValue; @@ -304,11 +360,9 @@ Board::MoveRecord Board::tryPlayMove(Loc loc, Player pla, const bool isSuicideLe return {}; } } - } else { - if (isWithinEmptyTerritory(originalState, opp)) { - invalidateAdjacentEmptyTerritoryIfNeeded(loc); - initEmptyBaseInvalidateLocations = vector(closureOrInvalidateLocsBuffer); - } + } else if (isWithinEmptyTerritory(originalState, opp)) { + invalidateAdjacentEmptyTerritoryIfNeeded(loc); + initEmptyBaseInvalidateLocations = vector(closureOrInvalidateLocsBuffer); } if (pla == P_BLACK) { @@ -419,7 +473,11 @@ vector Board::fillGrounding(const Loc loc) { return groundedLocs; } -Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, const Player opp, bool& isGrounded) { +void Board::captureWhenEmptyTerritoryBecomesRealBase( + const Loc initLoc, + const Player opp, + vector& bases, + bool& isGrounded) { Loc loc = initLoc; // Searching for an opponent dot that makes a closure that contains the `initialPosition`. @@ -430,34 +488,40 @@ Board::Base Board::captureWhenEmptyTerritoryBecomesRealBase(const Loc initLoc, c // Try to peek an active opposite player dot if (getColor(loc) != opp) continue; - vector oppBases = tryCapture(loc, opp, true, isGrounded); - // The found base always should be real and include the `iniLoc` - for (const Base& oppBase : oppBases) { - if (oppBase.is_real) { - return oppBase; + int unconnectedLocationsSize = 0; + const std::array unconnectedLocations = getUnconnectedLocations(loc, opp, unconnectedLocationsSize); + if (unconnectedLocationsSize >= 1) { + bases.clear(); + tryCapture(loc, opp, unconnectedLocations, unconnectedLocationsSize, isGrounded, bases); + // The found base always should be real and include the `iniLoc` + for (const Base& oppBase : bases) { + if (oppBase.is_real) { + return; + } } } } assert(false && "Opp empty territory should be enclosed by an outer closure"); - return {}; } -vector Board::tryCapture(const Loc loc, const Player pla, const bool emptyBaseCapturing, bool& atLeastOneRealBaseIsGrounded) { - getUnconnectedLocations(loc, pla); +void Board::tryCapture( + const Loc loc, + const Player pla, + const std::array& unconnectedLocations, + const int unconnectedLocationsSize, + bool& atLeastOneRealBaseIsGrounded, + std::vector& bases) { auto currentClosures = vector>(); - if (const int minNumberOfConnections = emptyBaseCapturing ? 1 : 2; - unconnectedLocationsBufferSize < minNumberOfConnections) return {}; - - for (int index = 0; index < unconnectedLocationsBufferSize; index++) { - const Loc unconnectedLoc = unconnectedLocationsBuffer[index]; + for (int index = 0; index < unconnectedLocationsSize; index++) { + const Loc unconnectedLoc = unconnectedLocations[index]; // Optimization: it doesn't make sense to check the latest unconnected dot // when all previous connections form minimal bases // because the latest always forms a base with maximal square that should be dropped if (const size_t closuresSize = currentClosures.size(); - closuresSize > 0 && closuresSize == unconnectedLocationsBuffer.size() - 1) { + closuresSize > 0 && closuresSize == unconnectedLocations.size() - 1) { break; } @@ -483,11 +547,10 @@ vector Board::tryCapture(const Loc loc, const Player pla, const boo } } - auto resultBases = vector(); atLeastOneRealBaseIsGrounded = false; for (const vector& currentClosure: currentClosures) { Base base = buildBase(currentClosure, pla); - resultBases.push_back(base); + bases.push_back(base); if (!atLeastOneRealBaseIsGrounded && base.is_real) { for (const Loc& closureLoc : currentClosure) { @@ -500,13 +563,10 @@ vector Board::tryCapture(const Loc loc, const Player pla, const boo } } } - - return std::move(resultBases); } -vector Board::ground(const Player pla, vector& emptyBaseInvalidatePositions) { +void Board::ground(const Player pla, vector& emptyBaseInvalidatePositions, vector& bases) { const Color opp = getOpp(pla); - auto resultBases = vector(); for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { @@ -523,33 +583,34 @@ vector Board::ground(const Player pla, vector& emptyBaseInvali } } - resultBases.push_back(createBaseAndUpdateStates(opp, true)); + bases.push_back(createBaseAndUpdateStates(opp, true)); } } } - - return resultBases; } -void Board::getUnconnectedLocations(const Loc loc, const Player pla) const { - const Loc xm1y = loc + adj_offsets[LEFT_INDEX]; - const Loc xym1 = loc + adj_offsets[TOP_INDEX]; - const Loc xp1y = loc + adj_offsets[RIGHT_INDEX]; - const Loc xyp1 = loc + adj_offsets[BOTTOM_INDEX]; +std::array Board::getUnconnectedLocations(const Loc loc, const Player pla, int& size) const { + const Loc xm1y = Location::xm1y(loc); + const Loc xym1 = Location::xym1(loc, x_size); + const Loc xp1y = Location::xp1y(loc); + const Loc xyp1 = Location::xyp1(loc, x_size); + + std::array unconnectedLocationsBuffer; + size = 0; + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xp1y), pla, Location::xp1yp1(loc, x_size), xyp1); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xyp1), pla, Location::xm1yp1(loc, x_size), xm1y); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xm1y), pla, Location::xm1ym1(loc, x_size), xym1); + checkAndAddUnconnectedLocation(unconnectedLocationsBuffer, size, getColor(xym1), pla, Location::xp1ym1(loc, x_size), xp1y); - unconnectedLocationsBufferSize = 0; - checkAndAddUnconnectedLocation(getColor(xp1y), pla, loc + adj_offsets[RIGHT_BOTTOM_INDEX], xyp1); - checkAndAddUnconnectedLocation(getColor(xyp1), pla, loc + adj_offsets[LEFT_BOTTOM_INDEX], xm1y); - checkAndAddUnconnectedLocation(getColor(xm1y), pla, loc + adj_offsets[LEFT_TOP_INDEX], xym1); - checkAndAddUnconnectedLocation(getColor(xym1), pla, loc + adj_offsets[RIGHT_TOP_INDEX], xp1y); + return unconnectedLocationsBuffer; } -void Board::checkAndAddUnconnectedLocation(const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { +void Board::checkAndAddUnconnectedLocation(std::array& unconnectedLocationsBuffer, int& size, const Player checkPla, const Player currentPla, const Loc addLoc1, const Loc addLoc2) const { if (checkPla != currentPla) { if (getColor(addLoc1) == currentPla) { - unconnectedLocationsBuffer[unconnectedLocationsBufferSize++] = addLoc1; + unconnectedLocationsBuffer[size++] = addLoc1; } else if (getColor(addLoc2) == currentPla) { - unconnectedLocationsBuffer[unconnectedLocationsBufferSize++] = addLoc2; + unconnectedLocationsBuffer[size++] = addLoc2; } } } @@ -830,7 +891,7 @@ void Board::makeMoveAndCalculateCapturesAndBases( vector& bases ) const { if(isLegal(loc, pla, isSuicideLegal, false)) { - MoveRecord moveRecord = const_cast(this)->playMoveAssumeLegalDots(loc, pla); + MoveRecord moveRecord = const_cast(this)->playMoveRecordedDots(loc, pla); if(!moveRecord.bases.empty()) { if(moveRecord.bases[0].pla == pla) { diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index 6c296334e..8b56fc3a9 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -233,13 +233,18 @@ void runDotsStressTestsInternal( int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * initialBoard.numLegalMoves; Player pla = P_BLACK; + int currentGameMovesCount = 0; for(short randomMove : randomMoves) { - lastLoc = moveRecords.size() >= tryGroundingAfterMove ? Board::PASS_LOC : randomMove; + lastLoc = currentGameMovesCount >= tryGroundingAfterMove ? Board::PASS_LOC : randomMove; if (board.isLegal(lastLoc, pla, suicideAllowed, false)) { - Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); - movesCount++; - moveRecords.push_back(moveRecord); + if (performExtraChecks) { + Board::MoveRecord moveRecord = board.playMoveRecorded(lastLoc, pla); + moveRecords.push_back(moveRecord); + } else { + board.playMoveAssumeLegal(lastLoc, pla); + } + currentGameMovesCount++; pla = getOpp(pla); } @@ -247,8 +252,7 @@ void runDotsStressTestsInternal( groundingCount++; int scoreDiff; int oppScoreIfGrounding; - Player lastPla = moveRecords.back().pla; - if (lastPla == P_BLACK) { + if (Player lastPla = getOpp(pla); lastPla == P_BLACK) { scoreDiff = board.numBlackCaptures - board.numWhiteCaptures; oppScoreIfGrounding = board.whiteScoreIfBlackGrounds; } else { @@ -276,6 +280,7 @@ void runDotsStressTestsInternal( testAssert(0 == board.numLegalMoves); } + movesCount += currentGameMovesCount; if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { whiteWinsCount++; } else if (whiteScore < 0) { From aa692eca88622a21319c4eaa7c713b58d0278a0d Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:44:57 +0100 Subject: [PATCH 12/42] Introduce `BoardHistory.isPassAliveFinished` Fill app info (katago) during writing SGF --- cpp/dataio/sgf.cpp | 6 ++++++ cpp/game/boardhistory.cpp | 22 +++++++++++++++++----- cpp/game/boardhistory.h | 2 ++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 38aff7cf3..7269a8ad3 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1980,6 +1980,8 @@ void WriteSgf::writeSgf( int xSize = initialBoard.x_size; int ySize = initialBoard.y_size; out << "(;FF[4]"; + out << "AP[katago]"; + out << "GM[" << (rules.isDots ? "40" : "1") << "]"; if(xSize == ySize) out << "SZ[" << xSize << "]"; @@ -2082,6 +2084,10 @@ void WriteSgf::writeSgf( assert(endHist.moveHistory.size() <= startTurnIdx + gameData->whiteValueTargetsByTurn.size()); } + if (endHist.isPassAliveFinished) { + commentOut << "," << "passAliveFinished=true"; + } + if(extraComments.size() > 0) { if(commentOut.str().length() > 0) commentOut << " "; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index c260abc85..fe109ba41 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -58,7 +58,8 @@ BoardHistory::BoardHistory(const Rules& rules) finalWhiteMinusBlackScore(0.0f), isScored(false), isNoResult(false), - isResignation(false) { + isResignation(false), + isPassAliveFinished(false) { for(int i = 0; i < NUM_RECENT_BOARDS; i++) { recentBoards.emplace_back(rules); } @@ -104,7 +105,7 @@ BoardHistory::BoardHistory(const Board& board, Player pla, const Rules& r, int e hasButton(false), isPastNormalPhaseEnd(false), isGameFinished(false),winner(C_EMPTY),finalWhiteMinusBlackScore(0.0f), - isScored(false),isNoResult(false),isResignation(false) + isScored(false),isNoResult(false),isResignation(false),isPassAliveFinished(false) { for(int i = 0; i < NUM_RECENT_BOARDS; i++) { recentBoards.emplace_back(rules); @@ -147,7 +148,7 @@ BoardHistory::BoardHistory(const BoardHistory& other) hasButton(other.hasButton), isPastNormalPhaseEnd(other.isPastNormalPhaseEnd), isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), - isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) + isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation),isPassAliveFinished(other.isPassAliveFinished) { recentBoards = other.recentBoards; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; @@ -199,6 +200,7 @@ BoardHistory& BoardHistory::operator=(const BoardHistory& other) isScored = other.isScored; isNoResult = other.isNoResult; isResignation = other.isResignation; + isPassAliveFinished = other.isPassAliveFinished; return *this; } @@ -231,7 +233,7 @@ BoardHistory::BoardHistory(BoardHistory&& other) noexcept hasButton(other.hasButton), isPastNormalPhaseEnd(other.isPastNormalPhaseEnd), isGameFinished(other.isGameFinished),winner(other.winner),finalWhiteMinusBlackScore(other.finalWhiteMinusBlackScore), - isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation) + isScored(other.isScored),isNoResult(other.isNoResult),isResignation(other.isResignation),isPassAliveFinished(other.isPassAliveFinished) { recentBoards = other.recentBoards; wasEverOccupiedOrPlayed = other.wasEverOccupiedOrPlayed; @@ -280,6 +282,7 @@ BoardHistory& BoardHistory::operator=(BoardHistory&& other) noexcept isScored = other.isScored; isNoResult = other.isNoResult; isResignation = other.isResignation; + isPassAliveFinished = other.isPassAliveFinished; return *this; } @@ -324,6 +327,7 @@ void BoardHistory::clear(const Board& board, Player pla, const Rules& r, int ePh isScored = false; isNoResult = false; isResignation = false; + isPassAliveFinished = false; if (!rules.isDots) { for(int y = 0; y Date: Wed, 17 Sep 2025 22:28:12 +0200 Subject: [PATCH 13/42] Initial support of handicap games (#1) --- cpp/dataio/sgf.cpp | 22 ++++++++++++---------- cpp/game/boardhistory.cpp | 22 ++++++++++++++++------ cpp/game/rules.cpp | 9 +++++---- cpp/tests/testdotsstartposes.cpp | 21 +++++++++++++-------- 4 files changed, 46 insertions(+), 28 deletions(-) diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 7269a8ad3..2ade09b50 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1990,16 +1990,18 @@ void WriteSgf::writeSgf( out << "PB[" << bName << "]"; out << "PW[" << wName << "]"; - if (!rules.isDots) { - if(gameData != NULL) { - out << "HA[" << gameData->handicapForSgf << "]"; - } - else { - BoardHistory histCopy(endHist); - //Always use true for computing the handicap value that goes into an sgf - histCopy.setAssumeMultipleStartingBlackMovesAreHandicap(true); - out << "HA[" << histCopy.computeNumHandicapStones() << "]"; - } + int handicap = 0; + if(gameData != nullptr) { + handicap = gameData->handicapForSgf; + } + else { + BoardHistory histCopy(endHist); + //Always use true for computing the handicap value that goes into an sgf + histCopy.setAssumeMultipleStartingBlackMovesAreHandicap(true); + handicap = histCopy.computeNumHandicapStones(); + } + if (!rules.isDots || handicap != 0) { // Preserve backward compatibility and always fill `HA` for Go Game + out << "HA[" << handicap << "]"; } out << "KM[" << rules.komi << "]"; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index fe109ba41..879250d54 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -395,18 +395,28 @@ void BoardHistory::setOverrideNumHandicapStones(int n) { whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); } -static int numHandicapStonesOnBoardHelper(const Board& board, int blackNonPassTurnsToStart) { +static int numHandicapStonesOnBoardHelper(const Board& board, const int blackNonPassTurnsToStart) { int startBoardNumBlackStones = 0; int startBoardNumWhiteStones = 0; + + // Ignore start pos that is generated according to rules + const auto startPos = board.rules.generateStartPos(board.rules.startPos, board.x_size, board.y_size); + set startLocs; + for (auto move : startPos) { + startLocs.insert(move.loc); + } + for(int y = 0; y Rules::generateStartPos(const int startPos, const int x_size, offsetY = (y_size - 3) / 3; } // Consider the index and size of the cross - const int sideOffsetX = (x_size - 1) - (offsetX + 1); - const int sideOffsetY = (y_size - 1) - (offsetY + 1); + const int sideOffsetX = x_size - 1 - (offsetX + 1); + const int sideOffsetY = y_size - 1 - (offsetY + 1); addCross(offsetX, offsetY, x_size, false, moves); addCross(sideOffsetX, offsetY, x_size, false, moves); addCross(sideOffsetX, sideOffsetY, x_size, false, moves); @@ -792,13 +792,14 @@ int Rules::tryRecognizeStartPos(int size_x, int size_y, vector& placementM auto generateStartPosSortAndCompare = [&](const int startPos) -> bool { auto startPosMoves = generateStartPos(startPos, size_x, size_y); - if(startPosMoves.size() != placementMoves.size()) { + // Detect start pos properly in case of a handicap (the placement has more stones than expected start pos) + if (startPosMoves.size() > placementMoves.size()) { return false; } sortByLoc(startPosMoves); - for(size_t i = 0; i < placementMoves.size(); i++) { + for(size_t i = 0; i < startPosMoves.size(); i++) { if(placementMoves[i].loc != startPosMoves[i].loc || placementMoves[i].pla != startPosMoves[i].pla) return false; } diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index b2d9aa774..74d76b0c6 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -22,10 +22,13 @@ void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const Board& boar testAssert(startPos == newRules.startPos); } -void checkStartPos(const string& description, const int startPos, const int x_size, const int y_size, const string& expectedBoard) { +void checkStartPos(const string& description, const int startPos, const int x_size, const int y_size, const string& expectedBoard, const vector& extraMoves = {}) { cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; auto board = Board(x_size, y_size, Rules(true, startPos, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + for (const XYMove& extraMove : extraMoves) { + board.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, board.x_size), extraMove.player); + } std::ostringstream oss; Board::printBoard(oss, board, Board::NULL_LOC, nullptr); @@ -53,18 +56,20 @@ HASH: EC100709447890A116AFC8952423E3DD 1 O X )"); + checkStartPos("Extra dots with cross (for instance, a handicap game)", Rules::START_POS_CROSS, 4, 4, R"( +HASH: A130436FBD93FF473AB4F3B84DD304DB + 1 2 3 4 + 4 . . . . + 3 . X O . + 2 . O X . + 1 . . X . +)", {XYMove(2, 3, P_BLACK)}); + checkStartPosNotRecognized("Not enough dots for cross", R"( .... .xo. .o.. .... -)"); - - checkStartPosNotRecognized("Extra dots with cross", R"( -.... -.xo. -.ox. -..o. )"); checkStartPosNotRecognized("Reversed cross shouldn't be recognized", R"( From ea1abdd2fcb4455ca246d399e1902dbe594eb39d Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 21 Sep 2025 13:26:04 +0200 Subject: [PATCH 14/42] Don't use `doEndGameIfAllPassAlive` in the main game loop because the game always can be finished by grounding And it's unclear how to calculate the resulting score if the game is not finished by grounding, but the winner is already known --- cpp/program/play.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index fb327f953..51383c5ff 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -2447,10 +2447,15 @@ FinishedGameData* GameRunner::runGame( clearBotBeforeSearchThisGame = true; } - //In 2% of games, don't autoterminate the game upon all pass alive, to just provide a tiny bit of training data on positions that occur - //as both players must wrap things up manually, because within the search we don't autoterminate games, meaning that the NN will get - //called on positions that occur after the game would have been autoterminated. - bool doEndGameIfAllPassAlive = playSettings.forSelfPlay ? gameRand.nextBool(0.98) : true; + bool doEndGameIfAllPassAlive; + if (rules.isDots) { + doEndGameIfAllPassAlive = false; + } else { + //In 2% of games, don't autoterminate the game upon all pass alive, to just provide a tiny bit of training data on positions that occur + //as both players must wrap things up manually, because within the search we don't autoterminate games, meaning that the NN will get + //called on positions that occur after the game would have been autoterminated. + doEndGameIfAllPassAlive = playSettings.forSelfPlay ? gameRand.nextBool(0.98) : true; + } Search* botB; Search* botW; From 5b82bbc1cec5875c4fdd4da3c75a58cd50d83b05 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 21 Sep 2025 17:17:32 +0200 Subject: [PATCH 15/42] Implement randomization of start pos, fix #2 --- cpp/dataio/sgf.cpp | 23 ++-- cpp/game/board.cpp | 50 +++++--- cpp/game/board.h | 13 +- cpp/game/boardhistory.cpp | 5 +- cpp/game/common.h | 2 + cpp/game/rules.cpp | 201 ++++++++++++++++++++++++------- cpp/game/rules.h | 27 +++-- cpp/neuralnet/nninputs.cpp | 72 +++++++---- cpp/program/play.cpp | 11 ++ cpp/program/play.h | 1 + cpp/tests/testdotsbasic.cpp | 2 +- cpp/tests/testdotsextra.cpp | 24 +++- cpp/tests/testdotsstartposes.cpp | 68 +++++++---- cpp/tests/testdotsstress.cpp | 37 +++--- cpp/tests/testdotsutils.cpp | 6 +- cpp/tests/testdotsutils.h | 4 +- 16 files changed, 391 insertions(+), 155 deletions(-) diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 2ade09b50..5de990129 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -151,7 +151,11 @@ static Rules getRulesFromSgf(const SgfNode& rootNode, const int xSize, const int vector placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); - rules.startPos = Rules::tryRecognizeStartPos(xSize, ySize, placementMoves, true); + bool randomized; + rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, true, randomized); + if (randomized && !rules.startPosIsRandom) { + propertyFail("Defined start pos is randomized but RU says it shouldn't"); + } return rules; } @@ -1769,7 +1773,11 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); - rules.startPos = Rules::tryRecognizeStartPos(xSize, ySize, placementMoves, true); + bool randomized; + rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, true, randomized); + if (randomized && !rules.startPosIsRandom) { + f("Defined start pos is randomized but RU says it shouldn't"); + } return rules; } @@ -1798,13 +1806,14 @@ BoardHistory CompactSgf::setupInitialBoardAndHist(const Rules& initialRules, Pla nextPla = moves[0].pla; auto board = Board(xSize,ySize,initialRules); - if (initialRules.startPos == Rules::START_POS_EMPTY) { - bool suc = board.setStonesFailIfNoLibs(placements); - if(!suc) + const int numOfStartPosStones = initialRules.getNumOfStartPosStones(); + for (int i = 0; i < placements.size(); i++) { + const Move placement = placements[i]; + if(const bool suc = board.setStoneFailIfNoLibs(placement.loc, placement.pla, i < numOfStartPosStones); !suc) throw StringError("setupInitialBoardAndHist: initial board position contains invalid stones or zero-liberty stones"); } - BoardHistory hist = BoardHistory(board,nextPla,initialRules,0); - if (int numStonesOnBoard = board.numStonesOnBoard(); hist.initialTurnNumber < numStonesOnBoard) + auto hist = BoardHistory(board,nextPla,initialRules,0); + if (const int numStonesOnBoard = board.numStonesOnBoard(); hist.initialTurnNumber < numStonesOnBoard) hist.initialTurnNumber = numStonesOnBoard; return hist; } diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 0fc43b285..4e3ad4f93 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -12,8 +12,6 @@ #include #include -#include "../core/rand.h" - using namespace std; //STATIC VARS----------------------------------------------------------------------------- @@ -137,11 +135,9 @@ Board::Board(const Board& other) { memcpy(colors, other.colors, sizeof(Color)*MAX_ARR_SIZE); ko_loc = other.ko_loc; - if (!other.rules.isDots) { - chain_data = other.chain_data; - chain_head = other.chain_head; - next_in_chain = other.next_in_chain; - } + chain_data = other.chain_data; + chain_head = other.chain_head; + next_in_chain = other.next_in_chain; // empty_list = other.empty_list; pos_hash = other.pos_hash; @@ -150,6 +146,7 @@ Board::Board(const Board& other) { blackScoreIfWhiteGrounds = other.blackScoreIfWhiteGrounds; whiteScoreIfBlackGrounds = other.whiteScoreIfBlackGrounds; numLegalMoves = other.numLegalMoves; + start_pos_moves = other.start_pos_moves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); visited_data.resize(other.visited_data.size(), false); } @@ -198,11 +195,6 @@ void Board::init(const int xS, const int yS, const Rules& initRules) } Location::getAdjacentOffsets(adj_offsets, x_size, isDots()); - - const vector placement = Rules::generateStartPos(rules.startPos, x_size, y_size); - for (const Move& move : placement) { - playMoveAssumeLegal(move.loc, move.pla); - } } void Board::initHash() @@ -757,7 +749,7 @@ bool Board::setStone(Loc loc, Color color) return true; } -bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { +bool Board::setStoneFailIfNoLibs(Loc loc, Color color, const bool startPos) { Color colorAtLoc = getColor(loc); if(loc < 0 || loc >= MAX_ARR_SIZE || colorAtLoc == C_WALL) return false; @@ -771,12 +763,18 @@ bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { if(isSuicide(loc,color) || wouldBeCapture(loc,color)) return false; playMoveAssumeLegal(loc,color); + if (startPos) { + start_pos_moves.emplace_back(loc, color); + } } else if(color == C_EMPTY) removeSingleStone(loc); else { assert(colorAtLoc == getOpp(color)); removeSingleStone(loc); + if (startPos) { + start_pos_moves.emplace_back(loc, color); + } if(isSuicide(loc,color) || wouldBeCapture(loc,color)) { playMoveAssumeLegal(loc,getOpp(color)); ko_loc = oldKoLoc; @@ -789,25 +787,31 @@ bool Board::setStoneFailIfNoLibs(Loc loc, Color color) { return true; } -bool Board::setStonesFailIfNoLibs(std::vector placements) { +void Board::setStartPos(Rand& rand) { + const vector startPos = Rules::generateStartPos(rules.startPos, rules.startPosIsRandom ? &rand : nullptr, x_size, y_size); + bool success = setStonesFailIfNoLibs(startPos, true); + assert(success); +} + +bool Board::setStonesFailIfNoLibs(const std::vector& placements, const bool startPos) { std::set locs; for(const Move& placement: placements) { if(locs.find(placement.loc) != locs.end()) return false; locs.insert(placement.loc); } + //First empty out all locations that we plan to set. //This guarantees avoiding any intermediate liberty issues. for(const Move& placement: placements) { - bool suc = setStoneFailIfNoLibs(placement.loc, C_EMPTY); - if(!suc) + if(bool suc = setStoneFailIfNoLibs(placement.loc, C_EMPTY, startPos); !suc) return false; } //Now set all the stones we wanted. for(const Move& placement: placements) { - bool suc = setStoneFailIfNoLibs(placement.loc, placement.pla); - if(!suc) + if(bool suc = setStoneFailIfNoLibs(placement.loc, placement.pla, startPos); !suc) { return false; + } } return true; } @@ -2509,6 +2513,16 @@ bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool ch if (numLegalMoves != other.numLegalMoves) { return false; } + if (start_pos_moves.size() != other.start_pos_moves.size()) { + return false; + } + for (int i = 0; i < start_pos_moves.size(); i++) { + const Move start_pose_move = start_pos_moves[i]; + const Move other_start_pos_move = other.start_pos_moves[i]; + if (start_pose_move.loc != other_start_pos_move.loc || start_pose_move.pla != other_start_pos_move.pla) { + return false; + } + } if (checkRules && rules != other.rules) { return false; } diff --git a/cpp/game/board.h b/cpp/game/board.h index d16c7aa53..f162a89bf 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -7,12 +7,11 @@ #ifndef GAME_BOARD_H_ #define GAME_BOARD_H_ -#include - #include "../core/global.h" #include "../core/hash.h" #include "../external/nlohmann_json/json.hpp" #include "rules.h" +#include "../core/rand.h" #ifndef COMPILE_MAX_BOARD_LEN #define COMPILE_MAX_BOARD_LEN 39 @@ -307,15 +306,20 @@ struct Board //Returns false if location or color were out of range. bool setStone(Loc loc, Color color); + // Set the start pos and use the provided random in case of randomization is used + // It should be called strictly before handicap placement + void setStartPos(Rand& rand); //Sets the specified stone, including overwriting existing stones, but only if doing so will //not result in any captures or zero liberty groups. //Returns false if location or color were out of range, or if would cause a zero liberty group. //In case of failure, will restore the position, but may result in chain ids or ordering in the board changing. - bool setStoneFailIfNoLibs(Loc loc, Color color); + //If startPos is true, adds the move to start pos moves to distinguish between start pos and handicap stones + bool setStoneFailIfNoLibs(Loc loc, Color color, bool startPos = false); //Same, but sets multiple stones, and only requires that the final configuration contain no zero-liberty groups. //If it does contain a zero liberty group, fails and returns false and leaves the board in an arbitrarily changed but valid state. //Also returns false if any location is specified more than once. - bool setStonesFailIfNoLibs(std::vector placements); + //If startPos is true, adds the placements to start pos moves to distinguish between start pos and handicap stones + bool setStonesFailIfNoLibs(const std::vector& placements, bool startPos = false); //Attempts to play the specified move. Returns true if successful, returns false if the move was illegal. bool playMove(Loc loc, Player pla, bool isMultiStoneSuicideLegal); @@ -425,6 +429,7 @@ struct Board std::vector chain_data; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. std::vector chain_head; //Where is the head of this chain? Undefined if EMPTY or WALL std::vector next_in_chain; //Location of next stone in chain. Circular linked list. Undefined if EMPTY or WALL + std::vector start_pos_moves; //Moves that are played at the very beginning of the game private: diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 879250d54..5fb435b60 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -399,10 +399,9 @@ static int numHandicapStonesOnBoardHelper(const Board& board, const int blackNon int startBoardNumBlackStones = 0; int startBoardNumWhiteStones = 0; - // Ignore start pos that is generated according to rules - const auto startPos = board.rules.generateStartPos(board.rules.startPos, board.x_size, board.y_size); + // Ignore start pos moves that are generated according to rules set startLocs; - for (auto move : startPos) { + for (auto move : board.start_pos_moves) { startLocs.insert(move.loc); } diff --git a/cpp/game/common.h b/cpp/game/common.h index cde12ddea..95ca124b4 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -7,7 +7,9 @@ const std::string DOTS_KEY = "dots"; const std::string DOTS_CAPTURE_EMPTY_BASE_KEY = "dotsCaptureEmptyBase"; const std::string DOTS_CAPTURE_EMPTY_BASES_KEY = "dotsCaptureEmptyBases"; const std::string START_POS_KEY = "startPos"; +const std::string START_POS_RANDOM_KEY = "startPosIsRandom"; const std::string START_POSES_KEY = "startPoses"; +const std::string START_POSES_ARE_RANDOM_KEY = "startPosesAreRandom"; const std::string BLACK_SCORE_IF_WHITE_GROUNDS_KEY = "blackScoreIfWhiteGrounds"; const std::string WHITE_SCORE_IF_BLACK_GROUNDS_KEY = "whiteScoreIfBlackGrounds"; diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index 7f91e80c9..8f3da0d06 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -4,6 +4,7 @@ #include +#include "../core/rand.h" #include "board.h" using namespace std; @@ -14,12 +15,13 @@ const Rules Rules::DEFAULT_GO = Rules(false); Rules::Rules() : Rules(false) {} -Rules::Rules(const bool initIsDots, const int startPos, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : - Rules(initIsDots, startPos, 0, 0, 0, true, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} +Rules::Rules(const bool initIsDots, const int startPos, const bool startPosIsRandom, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : + Rules(initIsDots, startPos, startPosIsRandom, 0, 0, 0, true, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} Rules::Rules(const bool initIsDots) : Rules( initIsDots, - initIsDots ? START_POS_EMPTY : 0, + initIsDots ? START_POS_CROSS : 0, + false, initIsDots ? 0 : KO_POSITIONAL, initIsDots ? 0 : SCORING_AREA, initIsDots ? 0 : TAX_NONE, @@ -42,12 +44,13 @@ Rules::Rules( int whbRule, bool pOk, float km -) : Rules(false, 0, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { +) : Rules(false, 0, false, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { } Rules::Rules( bool isDots, int startPosRule, + bool startPosIsRandom, int kRule, int sRule, int tRule, @@ -61,6 +64,7 @@ Rules::Rules( ) : isDots(isDots), startPos(startPosRule), + startPosIsRandom(startPosIsRandom), dotsCaptureEmptyBases(dotsCaptureEmptyBases), dotsFreeCapturedDots(dotsFreeCapturedDots), koRule(kRule), @@ -93,6 +97,7 @@ bool Rules::equals(const Rules& other, const bool ignoreSgfDefinedProps) const { return (ignoreSgfDefinedProps ? true : isDots == other.isDots) && (ignoreSgfDefinedProps ? true : startPos == other.startPos) && + startPosIsRandom == other.startPosIsRandom && koRule == other.koRule && scoringRule == other.scoringRule && taxRule == other.taxRule && @@ -160,7 +165,9 @@ set Rules::startPosStrings() { int Rules::getNumOfStartPosStones() const { switch (startPos) { - case START_POS_EMPTY: return 0; + case START_POS_EMPTY: + case START_POS_CUSTOM: + return 0; case START_POS_CROSS: return 4; case START_POS_CROSS_2: return 8; case START_POS_CROSS_4: return 16; @@ -278,6 +285,9 @@ string Rules::toString(const bool includeSgfDefinedProperties) const { if (includeSgfDefinedProperties && startPos != START_POS_EMPTY) { out << START_POS_KEY << writeStartPosRule(startPos); } + if (startPosIsRandom) { + out << START_POS_RANDOM_KEY << startPosIsRandom; + } out << "sui" << multiStoneSuicideLegal; if (!isDots) { if (hasButton != DEFAULT_GO.hasButton) @@ -315,6 +325,8 @@ json Rules::toJsonHelper(bool omitKomi, bool omitDefaults) const { } else { if (!omitDefaults || dotsCaptureEmptyBases != DEFAULT_DOTS.dotsCaptureEmptyBases) ret[DOTS_CAPTURE_EMPTY_BASE_KEY] = dotsCaptureEmptyBases; + if (!omitDefaults || startPosIsRandom != DEFAULT_DOTS.startPosIsRandom) + ret[START_POS_RANDOM_KEY] = startPosIsRandom; } if(!omitKomi) ret["komi"] = komi; @@ -348,7 +360,7 @@ string Rules::toJsonStringNoKomiMaybeOmitStuff() const { Rules Rules::updateRules(const string& k, const string& v, Rules oldRules) { Rules rules = oldRules; string key = Global::trim(k); - string value = Global::trim(Global::toUpper(v)); + const string value = Global::trim(Global::toUpper(v)); if(key == DOTS_KEY) rules.isDots = Global::stringToBool(value); else if(key == "ko") rules.koRule = Rules::parseKoRule(value); else if(key == "score") rules.scoringRule = Rules::parseScoringRule(value); @@ -485,6 +497,8 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) string key = iter.key(); if (key == START_POS_KEY) rules.startPos = Rules::parseStartPos(iter.value().get()); + else if (key == START_POS_RANDOM_KEY) + rules.startPosIsRandom = iter.value().get(); else if (key == DOTS_CAPTURE_EMPTY_BASE_KEY) rules.dotsCaptureEmptyBases = iter.value().get(); else if(key == "ko") @@ -630,6 +644,12 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) else throw IOError("Could not parse rules: " + sOrig); continue; } + if(startsWithAndStrip(s,START_POS_RANDOM_KEY)) { + if(startsWithAndStrip(s,"1")) rules.startPosIsRandom = true; + else if(startsWithAndStrip(s,"0")) rules.startPosIsRandom = false; + else throw IOError("Could not parse rules: " + sOrig); + continue; + } //Unknown rules format else throw IOError("Could not parse rules: " + sOrig); @@ -703,7 +723,19 @@ string Rules::toStringNoSgfDefinedPropertiesMaybeNice() const { return toStringNoSgfDefinedProps(); } -std::vector Rules::generateStartPos(const int startPos, const int x_size, const int y_size) { +double nextRandomOffset(Rand& rand) { + return rand.nextDouble(4, 7); +} + +int nextRandomOffsetX(Rand& rand, int x_size) { + return static_cast(round(nextRandomOffset(rand) / 39.0 * x_size)); +} + +int nextRandomOffsetY(Rand& rand, int y_size) { + return static_cast(round(nextRandomOffset(rand) / 32.0 * y_size)); +} + +std::vector Rules::generateStartPos(const int startPos, Rand* rand, const int x_size, const int y_size) { std::vector moves; switch (startPos) { case START_POS_EMPTY: @@ -724,22 +756,51 @@ std::vector Rules::generateStartPos(const int startPos, const int x_size, break; case START_POS_CROSS_4: if (x_size >= 4 && y_size >= 4) { - int offsetX; - int offsetY; - if (x_size == 39 && y_size == 32) { - offsetX = 11; - offsetY = 10; + int offsetX1; + int offsetY1; + int offsetX2; + int offsetY2; + int offsetX3; + int offsetY3; + int offsetX4; + int offsetY4; + + if (rand == nullptr) { + if (x_size == 39 && y_size == 32) { + offsetX1 = 11; + offsetY1 = 10; + } else { + offsetX1 = (x_size - 3) / 3; + offsetY1 = (y_size - 3) / 3; + } + // Consider the index and size of the cross + const int sideOffsetX = x_size - 1 - (offsetX1 + 1); + const int sideOffsetY = y_size - 1 - (offsetY1 + 1); + + offsetX2 = sideOffsetX; + offsetY2 = offsetY1; + offsetX3 = sideOffsetX; + offsetY3 = sideOffsetY; + offsetX4 = offsetX1; + offsetY4 = sideOffsetY; } else { - offsetX = (x_size - 3) / 3; - offsetY = (y_size - 3) / 3; + const int middleX = x_size / 2 - 1; + const int middleY = y_size / 2 - 1; + + offsetX1 = middleX - nextRandomOffsetX(*rand, x_size); + offsetY1 = middleY - nextRandomOffsetY(*rand, y_size); + offsetX2 = middleX + nextRandomOffsetX(*rand, x_size); + offsetY2 = middleY - nextRandomOffsetY(*rand, y_size); + offsetX3 = middleX + nextRandomOffsetX(*rand, x_size); + offsetY3 = middleY + nextRandomOffsetY(*rand, y_size); + offsetX4 = middleX - nextRandomOffsetX(*rand, x_size); + offsetY4 = middleY + nextRandomOffsetY(*rand, y_size); } - // Consider the index and size of the cross - const int sideOffsetX = x_size - 1 - (offsetX + 1); - const int sideOffsetY = y_size - 1 - (offsetY + 1); - addCross(offsetX, offsetY, x_size, false, moves); - addCross(sideOffsetX, offsetY, x_size, false, moves); - addCross(sideOffsetX, sideOffsetY, x_size, false, moves); - addCross(offsetX, sideOffsetY, x_size, false, moves); + + addCross(offsetX1, offsetY1, x_size, false, moves); + addCross(offsetX2, offsetY2, x_size, false, moves); + addCross(offsetX3, offsetY3, x_size, false, moves); + addCross(offsetX4, offsetY4, x_size, false, moves); } break; default: @@ -777,42 +838,98 @@ void Rules::addCross(const int x, const int y, const int x_size, const bool rota } } -int Rules::tryRecognizeStartPos(int size_x, int size_y, vector& placementMoves, const bool emptyIfFailed) { +int Rules::tryRecognizeStartPos( + const vector& placementMoves, + const int x_size, + const int y_size, + const bool emptyIfFailed, + bool& randomized) { + randomized = false; // Empty or unknown start pos is static by default + if(placementMoves.empty()) return START_POS_EMPTY; - int result = emptyIfFailed ? START_POS_EMPTY : -1; + int result = emptyIfFailed ? START_POS_EMPTY : START_POS_CUSTOM; + + const int stride = x_size + 1; + auto placement = vector(stride * (y_size + 2), C_EMPTY); + + for (const auto move : placementMoves) { + placement[move.loc] = move.pla; + } + + auto recognizedCrossesMoves = vector(); + + for (const auto move : placementMoves) { + const int x = Location::getX(move.loc, x_size); + const int y = Location::getY(move.loc, x_size); + + const int xy = Location::getLoc(x, y, x_size); + const Player pla = placement[xy]; + if (pla == C_EMPTY) continue; + + const Player opp = getOpp(pla); + if (opp == C_EMPTY) continue; + + if (x + 1 > x_size) continue; + const int xp1y = Location::getLoc(x + 1, y, x_size); + if (placement[xp1y] != opp) continue; - // Sort locs because initial pos is invariant to moves order + if (y + 1 > y_size) continue; + const int xp1yp1 = Location::getLoc(x + 1, y + 1, x_size); + if (placement[xp1yp1] != pla) continue; + + const int xyp1 = Location::getLoc(x, y + 1, x_size); + if (placement[xyp1] != opp) continue; + + recognizedCrossesMoves.emplace_back(xy, pla); + recognizedCrossesMoves.emplace_back(xp1y, opp); + recognizedCrossesMoves.emplace_back(xp1yp1, pla); + recognizedCrossesMoves.emplace_back(xyp1, opp); + + // Clear the placement because the recognized cross is already stored + placement[xy] = C_EMPTY; + placement[xp1y] = C_EMPTY; + placement[xp1yp1] = C_EMPTY; + placement[xyp1] = C_EMPTY; + } + + // Sort locs because start pos is invariant to moves order auto sortByLoc = [&](vector& moves) { std::sort(moves.begin(), moves.end(), [](const Move& move1, const Move& move2) { return move1.loc < move2.loc; }); }; - sortByLoc(placementMoves); + sortByLoc(recognizedCrossesMoves); - auto generateStartPosSortAndCompare = [&](const int startPos) -> bool { - auto startPosMoves = generateStartPos(startPos, size_x, size_y); + // Try to match strictly and set up randomized if failed. + auto detectRandomization = [&](const int expectedStartPos) -> void { + auto staticStartPosMoves = generateStartPos(expectedStartPos, nullptr, x_size, y_size); - // Detect start pos properly in case of a handicap (the placement has more stones than expected start pos) - if (startPosMoves.size() > placementMoves.size()) { - return false; - } + assert(staticStartPosMoves.size() == recognizedCrossesMoves.size()); - sortByLoc(startPosMoves); + sortByLoc(staticStartPosMoves); - for(size_t i = 0; i < startPosMoves.size(); i++) { - if(placementMoves[i].loc != startPosMoves[i].loc || placementMoves[i].pla != startPosMoves[i].pla) - return false; + for(size_t i = 0; i < staticStartPosMoves.size(); i++) { + if(staticStartPosMoves[i].loc != recognizedCrossesMoves[i].loc || staticStartPosMoves[i].pla != recognizedCrossesMoves[i].pla) { + randomized = true; + break; + } } - return true; + result = expectedStartPos; }; - if(generateStartPosSortAndCompare(START_POS_CROSS)) { - result = START_POS_CROSS; - } else if(generateStartPosSortAndCompare(START_POS_CROSS_2)) { - result = START_POS_CROSS_2; - } else if(generateStartPosSortAndCompare(START_POS_CROSS_4)) { - result = START_POS_CROSS_4; + switch (recognizedCrossesMoves.size()) { + case 4: + detectRandomization(START_POS_CROSS); + break; + case 8: + detectRandomization(START_POS_CROSS_2); + break; + case 16: + detectRandomization(START_POS_CROSS_4); + break; + default:; + break; } return result; diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 72d085c68..f55fdaff6 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -4,6 +4,7 @@ #include "common.h" #include "../core/global.h" #include "../core/hash.h" +#include "../core/rand.h" #include "../external/nlohmann_json/json.hpp" @@ -15,8 +16,12 @@ struct Rules { static constexpr int START_POS_CROSS = 1; static constexpr int START_POS_CROSS_2 = 2; static constexpr int START_POS_CROSS_4 = 3; + static constexpr int START_POS_CUSTOM = 4; int startPos; + // Enables random shuffling of start pos. Currently, it works only for CROSS_4 + bool startPosIsRandom; + static const int KO_SIMPLE = 0; static const int KO_POSITIONAL = 1; static const int KO_SITUATIONAL = 2; @@ -54,7 +59,7 @@ struct Rules { bool friendlyPassOk; Rules(); - Rules(bool initIsDots, int startPos, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); + Rules(bool initIsDots, int startPos, bool startPosIsRandom, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); explicit Rules(bool initIsDots); Rules( int koRule, @@ -108,15 +113,20 @@ struct Rules { static Rules updateRules(const std::string& key, const std::string& value, Rules priorRules); - static std::vector generateStartPos(int startPos, int x_size, int y_size); + static std::vector generateStartPos(int startPos, Rand* rand, int x_size, int y_size); /** - * @param size_x size of field - * @param size_y size of field - * @param placementMoves initial placement moves, they can be sorted, that's why it's not const - * @param emptyIfFailed - * @return -1 if the recognition is failed + * @param placementMoves placement moves that we are trying to recognize. + * @param x_size size of field + * @param y_size size of field + * @param emptyIfFailed returns empty start pos if recognition is failed. It's useful for detecting start pos from SGF when handicap stones are placed + * @param randomized if we recognize a start pos, but it doesn't match the strict position, set it up to `true` */ - static int tryRecognizeStartPos(int size_x, int size_y, std::vector& placementMoves, bool emptyIfFailed); + static int tryRecognizeStartPos( + const std::vector& placementMoves, + int x_size, + int y_size, + bool emptyIfFailed, + bool& randomized); friend std::ostream& operator<<(std::ostream& out, const Rules& rules); std::string toString() const; @@ -143,6 +153,7 @@ struct Rules { Rules( bool isDots, int startPosRule, + bool startPosIsRandom, int kRule, int sRule, int tRule, diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index e9a278d88..6dda69b43 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -681,34 +681,58 @@ Loc SymmetryHelpers::getSymLoc(Loc loc, int xSize, int ySize, int symmetry) { return getSymLoc(Location::getX(loc,xSize), Location::getY(loc,xSize), xSize, ySize, symmetry); } - Board SymmetryHelpers::getSymBoard(const Board& board, int symmetry) { - bool transpose = (symmetry & 0x4) != 0; - bool flipX = (symmetry & 0x2) != 0; - bool flipY = (symmetry & 0x1) != 0; - Board symBoard( - transpose ? board.y_size : board.x_size, - transpose ? board.x_size : board.y_size, - board.rules - ); + const bool transpose = (symmetry & 0x4) != 0; + const bool flipX = (symmetry & 0x2) != 0; + const bool flipY = (symmetry & 0x1) != 0; + const int sym_x_size = transpose ? board.y_size : board.x_size; + const int sym_y_size = transpose ? board.x_size : board.y_size; + + auto getSymLoc = [&](const int x, const int y) { + int symX = flipX ? board.x_size - x - 1 : x; + int symY = flipY ? board.y_size - y - 1 : y; + if(transpose) + std::swap(symX,symY); + return Location::getLoc(symX,symY,sym_x_size); + }; + + Rules symRules(board.rules); + vector sym_start_pos_moves; + + if (!board.start_pos_moves.empty()) { + sym_start_pos_moves.reserve(board.start_pos_moves.size()); + for (const auto start_pos_move : board.start_pos_moves) { + const Loc loc = start_pos_move.loc; + const int x = Location::getX(loc, board.x_size); + const int y = Location::getY(loc, board.x_size); + sym_start_pos_moves.emplace_back(getSymLoc(x, y), start_pos_move.pla); + } + bool randomized; + symRules.startPos = Rules::tryRecognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, false, randomized); + symRules.startPosIsRandom = randomized; + } + + Board symBoard(sym_x_size, sym_y_size, symRules); + symBoard.setStonesFailIfNoLibs(sym_start_pos_moves, true); + Loc symKoLoc = Board::NULL_LOC; for(int y = 0; y(allowedStartPosRules.size()))]; } + if (!allowedStartPosRandomRules.empty()) { + rules.startPosIsRandom = allowedStartPosRandomRules[rand.nextUInt(static_cast(allowedStartPosRandomRules.size()))]; + } if (dotsGame) { rules.dotsCaptureEmptyBases = allowedCaptureEmtpyBasesRules[rand.nextUInt(static_cast(allowedCaptureEmtpyBasesRules.size()))]; @@ -621,6 +630,8 @@ void GameInitializer::createGameSharedUnsynchronized( int xSize = allowedBSizes[bSizeIdx].first; int ySize = allowedBSizes[bSizeIdx].second; board = Board(xSize,ySize,rules); + board.setStartPos(rand); + pla = P_BLACK; hist.clear(board,pla,rules,0); hist.setInitialTurnNumber(rules.getNumOfStartPosStones()); diff --git a/cpp/program/play.h b/cpp/program/play.h index 9b7fde7e9..0849f8b58 100644 --- a/cpp/program/play.h +++ b/cpp/program/play.h @@ -139,6 +139,7 @@ class GameInitializer { bool dotsGame; std::vector allowedCaptureEmtpyBasesRules; std::vector allowedStartPosRules; + std::vector allowedStartPosRandomRules; std::vector allowedKoRuleStrs; std::vector allowedScoringRuleStrs; diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index caa297b2b..0aec8aea4 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -12,7 +12,7 @@ void checkDotsField(const string& description, const string& input, bool capture auto moveRecords = vector(); - Board initialBoard = parseDotsField(input, captureEmptyBases, freeCapturedDots, {}); + Board initialBoard = parseDotsField(input, false, captureEmptyBases, freeCapturedDots, {}); Board board = Board(initialBoard); diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp index 223aaba2d..3ad341be5 100644 --- a/cpp/tests/testdotsextra.cpp +++ b/cpp/tests/testdotsextra.cpp @@ -8,13 +8,13 @@ using namespace std; using namespace TestCommon; void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardInput, const vector& extraMoves, const int symmetry) { - Board transformedBoard = SymmetryHelpers::getSymBoard(initBoard, symmetry); + const Board transformedBoard = SymmetryHelpers::getSymBoard(initBoard, symmetry); Board expectedBoard = parseDotsFieldDefault(expectedSymmetryBoardInput); for (const XYMove& extraMove : extraMoves) { expectedBoard.playMoveAssumeLegal(SymmetryHelpers::getSymLoc(extraMove.x, extraMove.y, initBoard, symmetry), extraMove.player); } expect(SymmetryHelpers::symmetryToString(symmetry).c_str(), Board::toStringSimple(transformedBoard, '\n'), Board::toStringSimple(expectedBoard, '\n')); - testAssert(transformedBoard.isEqualForTesting(expectedBoard, true, true)); + testAssert(transformedBoard.isEqualForTesting(expectedBoard, true, true, true)); } void Tests::runDotsSymmetryTests() { @@ -104,6 +104,26 @@ xo.. )", { XYMove(4, 1, P_WHITE)}, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); + + cout << "Check dots symmetry with start pos" << endl; + auto board = Board(5, 4, Rules(true, Rules::START_POS_CROSS, false, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + board.setStartPos(DOTS_RANDOM); + board.playMoveAssumeLegal(Location::getLoc(1, 2, board.x_size), P_BLACK); + + const auto rotatedBoard = SymmetryHelpers::getSymBoard(board, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_X); + + auto expectedBoard = Board(4, 5, Rules(true, Rules::START_POS_CROSS, true, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 2, expectedBoard.x_size), P_WHITE, true); + expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 3, expectedBoard.x_size), P_BLACK, true); + expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 3, expectedBoard.x_size), P_WHITE, true); + expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 2, expectedBoard.x_size), P_BLACK, true); + expectedBoard.playMoveAssumeLegal(Location::getLoc(1, 1, expectedBoard.x_size), P_BLACK); + + expect("Dots symmetry with start pos", Board::toStringSimple(rotatedBoard, '\n'), Board::toStringSimple(expectedBoard, '\n')); + testAssert(rotatedBoard.isEqualForTesting(expectedBoard, true, true, true)); + + const auto unrotatedBoard = SymmetryHelpers::getSymBoard(rotatedBoard, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y); + testAssert(board.isEqualForTesting(unrotatedBoard, true, true, true)); } string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 74d76b0c6..63d736617 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -10,7 +10,7 @@ using namespace std; using namespace std::chrono; using namespace TestCommon; -void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const Board& board) { +void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startPosIsRandom, const Board& board) { std::ostringstream sgfStringStream; const BoardHistory boardHistory(board, P_BLACK, board.rules, 0); WriteSgf::writeSgf(sgfStringStream, "black", "white", boardHistory, {}); @@ -20,12 +20,14 @@ void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const Board& boar const auto deserializedSgf = Sgf::parse(sgfString); const Rules newRules = deserializedSgf->getRulesOrFail(); testAssert(startPos == newRules.startPos); + testAssert(startPosIsRandom == newRules.startPosIsRandom); } -void checkStartPos(const string& description, const int startPos, const int x_size, const int y_size, const string& expectedBoard, const vector& extraMoves = {}) { +void checkStartPos(const string& description, const int startPos, const bool startPosIsRandom, const int x_size, const int y_size, const string& expectedBoard = "", const vector& extraMoves = {}) { cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; - auto board = Board(x_size, y_size, Rules(true, startPos, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + auto board = Board(x_size, y_size, Rules(true, startPos, startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + board.setStartPos(DOTS_RANDOM); for (const XYMove& extraMove : extraMoves) { board.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, board.x_size), extraMove.player); } @@ -33,30 +35,45 @@ void checkStartPos(const string& description, const int startPos, const int x_si std::ostringstream oss; Board::printBoard(oss, board, Board::NULL_LOC, nullptr); - expect(description.c_str(), oss, expectedBoard); + if (!expectedBoard.empty()) { + expect(description.c_str(), oss, expectedBoard); + } - writeToSgfAndCheckStartPosFromSgfProp(startPos, board); + writeToSgfAndCheckStartPosFromSgfProp(startPos, startPosIsRandom, board); } -void checkStartPosNotRecognized(const string& description, const string& inputBoard) { - const Board board = parseDotsFieldDefault(inputBoard); +void checkStartPosRecognition(const string& description, const int expectedStartPos, const int startPosIsRandom, const string& inputBoard) { + const Board board = parseDotsField(inputBoard, startPosIsRandom, false, false, {}); cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; - writeToSgfAndCheckStartPosFromSgfProp(0, board); + writeToSgfAndCheckStartPosFromSgfProp(expectedStartPos, startPosIsRandom, board); +} + +void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { + const auto generatedMoves = Rules::generateStartPos(startPos, startPosIsRandom ? &DOTS_RANDOM : nullptr, 39, 32); + bool actualRandomized; + testAssert(startPos == Rules::tryRecognizeStartPos(generatedMoves, 39, 32, false, actualRandomized)); + // We can't reliably check in case of randomization is not detected because random generator can + // generate static poses in rare cases. + if (actualRandomized) { + testAssert(startPosIsRandom); + } } void Tests::runDotsStartPosTests() { cout << "Running dots start pos tests" << endl; - checkStartPos("Cross on minimal size", Rules::START_POS_CROSS, 2, 2, R"( + Rand rand("runDotsStartPosTests"); + + checkStartPos("Cross on minimal size", Rules::START_POS_CROSS, false, 2, 2, R"( HASH: EC100709447890A116AFC8952423E3DD 1 2 2 X O 1 O X )"); - checkStartPos("Extra dots with cross (for instance, a handicap game)", Rules::START_POS_CROSS, 4, 4, R"( + checkStartPos("Extra dots with cross (for instance, a handicap game)", Rules::START_POS_CROSS, false, 4, 4, R"( HASH: A130436FBD93FF473AB4F3B84DD304DB 1 2 3 4 4 . . . . @@ -65,21 +82,21 @@ HASH: A130436FBD93FF473AB4F3B84DD304DB 1 . . X . )", {XYMove(2, 3, P_BLACK)}); - checkStartPosNotRecognized("Not enough dots for cross", R"( + checkStartPosRecognition("Not enough dots for cross", 0, false, R"( .... .xo. .o.. .... )"); - checkStartPosNotRecognized("Reversed cross shouldn't be recognized", R"( + checkStartPosRecognition("Reversed cross should be recognized as random", Rules::START_POS_CROSS, true, R"( .... .ox. .xo. .... )"); - checkStartPos("Cross on odd size", Rules::START_POS_CROSS, 3, 3, R"( + checkStartPos("Cross on odd size", Rules::START_POS_CROSS, false, 3, 3, R"( HASH: 3B29F9557D2712A5BC982D218680927D 1 2 3 3 . X O @@ -87,7 +104,7 @@ HASH: 3B29F9557D2712A5BC982D218680927D 1 . . . )"); - checkStartPos("Cross on standard size", Rules::START_POS_CROSS, 39, 32, R"( + checkStartPos("Cross on standard size", Rules::START_POS_CROSS, false, 39, 32, R"( HASH: 516E1ABBA0D6B69A0B3D17C9E34E52F7 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -124,14 +141,14 @@ HASH: 516E1ABBA0D6B69A0B3D17C9E34E52F7 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . )"); - checkStartPos("Double cross on minimal size", Rules::START_POS_CROSS_2, 4, 2, R"( + checkStartPos("Double cross on minimal size", Rules::START_POS_CROSS_2, false, 4, 2, R"( HASH: 43FD769739F2AA27A8A1DAB1F4278229 1 2 3 4 2 X O O X 1 O X X O )"); - checkStartPos("Double cross on odd size", Rules::START_POS_CROSS_2, 5, 3, R"( + checkStartPos("Double cross on odd size", Rules::START_POS_CROSS_2, false, 5, 3, R"( HASH: AAA969B8135294A3D1ADAA07BEA9A987 1 2 3 4 5 3 . X O O X @@ -139,7 +156,7 @@ HASH: AAA969B8135294A3D1ADAA07BEA9A987 1 . . . . . )"); - checkStartPos("Double cross", Rules::START_POS_CROSS_2, 6, 4, R"( + checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 6, 4, R"( HASH: D599CEA39B1378D29883145CA4C016FC 1 2 3 4 5 6 4 . . . . . . @@ -148,7 +165,7 @@ HASH: D599CEA39B1378D29883145CA4C016FC 1 . . . . . . )"); - checkStartPos("Double cross", Rules::START_POS_CROSS_2, 7, 4, R"( + checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 7, 4, R"( HASH: 249F175819EA8FDE47F8676E655A06DE 1 2 3 4 5 6 7 4 . . . . . . . @@ -157,7 +174,7 @@ HASH: 249F175819EA8FDE47F8676E655A06DE 1 . . . . . . . )"); - checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, 39, 32, R"( + checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, false, 39, 32, R"( HASH: CAD72FD407955308CEFCBD7A9B14B35B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -194,7 +211,7 @@ HASH: CAD72FD407955308CEFCBD7A9B14B35B 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . )"); - checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 5, 5, R"( + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 5, 5, R"( HASH: 0C2DD637AAE5FA7E1469BF5829BE922B 1 2 3 4 5 5 X O . X O @@ -204,7 +221,7 @@ HASH: 0C2DD637AAE5FA7E1469BF5829BE922B 1 O X . O X )"); - checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 7, 7, R"( + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 7, 7, R"( HASH: 89CBCA85E94AF1B6C376E6BCBC443A48 1 2 3 4 5 6 7 7 . . . . . . . @@ -216,7 +233,7 @@ HASH: 89CBCA85E94AF1B6C376E6BCBC443A48 1 . . . . . . . )"); - checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, 8, 8, R"( + checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 8, 8, R"( HASH: 445D50D7A61C47CE2730BBB97A2B3C96 1 2 3 4 5 6 7 8 8 . . . . . . . . @@ -229,7 +246,7 @@ HASH: 445D50D7A61C47CE2730BBB97A2B3C96 1 . . . . . . . . )"); - checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, 39, 32, R"( + checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, false, 39, 32, R"( HASH: 2A9AE7F967F17B42D9B9CB45B735E9C6 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -265,4 +282,9 @@ HASH: 2A9AE7F967F17B42D9B9CB45B735E9C6 2 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 1 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . )"); + + checkStartPos("Random quadruple cross on standard size", Rules::START_POS_CROSS_4, true, 39, 32); + + checkGenerationAndRecognition(Rules::START_POS_CROSS_4, false); + checkGenerationAndRecognition(Rules::START_POS_CROSS_4, true); } \ No newline at end of file diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index 8b56fc3a9..df838f089 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -166,13 +166,13 @@ void runDotsStressTestsInternal( int gamesCount, bool dotsGame, int startPos, + bool startPosIsRandom, bool dotsCaptureEmptyBase, float komi, bool suicideAllowed, float groundingStartCoef, float groundingEndCoef, - bool performExtraChecks - ) { + bool performExtraChecks) { assert(groundingStartCoef >= 0 && groundingStartCoef <= 1); assert(groundingEndCoef >= 0 && groundingEndCoef <= 1); assert(groundingEndCoef >= groundingStartCoef); @@ -180,6 +180,7 @@ void runDotsStressTestsInternal( cout << " Random games" << endl; cout << " Game type: " << (dotsGame ? "Dots" : "Go") << endl; cout << " Start position: " << Rules::writeStartPosRule(startPos) << endl; + cout << " Start position is random: " << boolalpha << startPosIsRandom << endl; if (dotsGame) { cout << " Capture empty bases: " << boolalpha << dotsCaptureEmptyBase << endl; } @@ -198,23 +199,18 @@ void runDotsStressTestsInternal( Rand rand("runDotsStressTests"); - Rules rules = dotsGame ? Rules(dotsGame, startPos, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); - auto initialBoard = Board(x_size, y_size, rules); + Rules rules = dotsGame ? Rules(dotsGame, startPos, startPosIsRandom, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); + int numLegalMoves = x_size * y_size - rules.getNumOfStartPosStones(); vector randomMoves = vector(); - randomMoves.reserve(initialBoard.numLegalMoves); + randomMoves.reserve(numLegalMoves); - for(int y = 0; y < initialBoard.y_size; y++) { - for(int x = 0; x < initialBoard.x_size; x++) { - Loc loc = Location::getLoc(x, y, initialBoard.x_size); - if (initialBoard.getColor(loc) == C_EMPTY) { // Filter out initial poses - randomMoves.push_back(Location::getLoc(x, y, initialBoard.x_size)); - } + for(int y = 0; y < y_size; y++) { + for(int x = 0; x < x_size; x++) { + randomMoves.push_back(Location::getLoc(x, y, x_size)); } } - assert(randomMoves.size() == initialBoard.numLegalMoves); - int movesCount = 0; int blackWinsCount = 0; int whiteWinsCount = 0; @@ -227,11 +223,13 @@ void runDotsStressTestsInternal( rand.shuffle(randomMoves); moveRecords.clear(); - auto board = Board(initialBoard.x_size, initialBoard.y_size, rules); + auto initialBoard = Board(x_size, y_size, rules); + initialBoard.setStartPos(DOTS_RANDOM); + auto board = initialBoard; Loc lastLoc = Board::NULL_LOC; - int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * initialBoard.numLegalMoves; + int tryGroundingAfterMove = (groundingStartCoef + rand.nextDouble() * (groundingEndCoef - groundingStartCoef)) * numLegalMoves; Player pla = P_BLACK; int currentGameMovesCount = 0; for(short randomMove : randomMoves) { @@ -318,7 +316,7 @@ void Tests::runDotsStressTests() { cout << "Running dots stress tests" << endl; cout << " Max territory" << endl; - Board board = Board(39, 32, Rules::DEFAULT_DOTS); + auto board = Board(39, 32, Rules(true, Rules::START_POS_EMPTY, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; @@ -328,8 +326,9 @@ void Tests::runDotsStressTests() { testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); testAssert(0 == board.numLegalMoves); - runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, true); - runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, 0.5f, false, 0.8f, 1.0f, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, false, 0.0f, true, 0.8f, 1.0f, true); + runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, true, 0.5f, false, 0.8f, 1.0f, true); - runDotsStressTestsInternal(39, 32, 100000, true, Rules::START_POS_CROSS, false, 0.0f, true, 0.8f, 1.0f, false); + runDotsStressTestsInternal(39, 32, 50000, true, Rules::START_POS_CROSS, false, false, 0.0f, true, 0.8f, 1.0f, false); + runDotsStressTestsInternal(39, 32, 50000, true, Rules::START_POS_CROSS_4, true, false, 0.0f, true, 0.8f, 1.0f, false); } \ No newline at end of file diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp index 9dada066d..4fbd95d0d 100644 --- a/cpp/tests/testdotsutils.cpp +++ b/cpp/tests/testdotsutils.cpp @@ -3,10 +3,10 @@ using namespace std; Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { - return parseDotsField(input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); + return parseDotsField(input, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); } -Board parseDotsField(const string& input, const bool captureEmptyBases, +Board parseDotsField(const string& input, const bool startPosIsRandom, const bool captureEmptyBases, const bool freeCapturedDots, const vector& extraMoves) { int currentXSize = 0; int xSize = -1; @@ -27,7 +27,7 @@ Board parseDotsField(const string& input, const bool captureEmptyBases, } } - Board result = Board::parseBoard(xSize, ySize, input, Rules(true, Rules::START_POS_EMPTY, captureEmptyBases, freeCapturedDots)); + Board result = Board::parseBoard(xSize, ySize, input, Rules(true, Rules::START_POS_EMPTY, startPosIsRandom, captureEmptyBases, freeCapturedDots)); for(const XYMove& extraMove : extraMoves) { result.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, result.x_size), extraMove.player); } diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index 108bf824a..012ac7bc8 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -4,6 +4,8 @@ using namespace std; +inline Rand DOTS_RANDOM("DOTS_RANDOM"); + struct XYMove { int x; int y; @@ -62,5 +64,5 @@ struct BoardWithMoveRecords { Board parseDotsFieldDefault(const string& input, const vector& extraMoves = {}); -Board parseDotsField(const string& input, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); +Board parseDotsField(const string& input, bool startPosIsRandom, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); From fb78da9c08f46f61ebd919685168514c3f9d41b6 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:50:36 +0100 Subject: [PATCH 16/42] Tidy up code by using default parameters for some functions in the game dir --- cpp/command/genbook.cpp | 4 +- cpp/command/selfplay.cpp | 6 +- cpp/core/config_parser.cpp | 132 +++++--------------------------- cpp/core/config_parser.h | 30 +++----- cpp/dataio/sgf.cpp | 2 +- cpp/dataio/trainingwrite.cpp | 54 ++++++++----- cpp/dataio/trainingwrite.h | 4 +- cpp/game/board.cpp | 15 ++-- cpp/game/board.h | 8 +- cpp/game/boardhistory.cpp | 4 - cpp/game/boardhistory.h | 3 +- cpp/game/rules.cpp | 22 +----- cpp/game/rules.h | 10 +-- cpp/program/play.cpp | 6 +- cpp/search/asyncbot.cpp | 2 +- cpp/search/search.cpp | 102 ++++++++++++------------ cpp/search/search.h | 18 +---- cpp/tests/testboardbasic.cpp | 6 +- cpp/tests/testdotsbasic.cpp | 2 +- cpp/tests/testdotsextra.cpp | 12 +-- cpp/tests/testdotsstress.cpp | 2 +- cpp/tests/testrules.cpp | 2 +- cpp/tests/testsymmetries.cpp | 4 +- cpp/tests/testtrainingwrite.cpp | 44 +++++------ 24 files changed, 172 insertions(+), 322 deletions(-) diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index 079e65f98..e1eb5c630 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -444,8 +444,8 @@ int MainCmds::genbook(const vector& args) { if(!bonusInitialBoard.isEqualForTesting(book->getInitialHist().getRecentBoard(0), false, false)) throw StringError( "Book initial board and initial board in bonus sgf file do not match\n" + - Board::toStringSimple(book->getInitialHist().getRecentBoard(0),'\n') + "\n" + - Board::toStringSimple(bonusInitialBoard,'\n') + Board::toStringSimple(book->getInitialHist().getRecentBoard(0)) + "\n" + + Board::toStringSimple(bonusInitialBoard) ); if(bonusInitialPla != book->initialPla) throw StringError( diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index cc8521148..3cb66062a 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -214,9 +214,9 @@ int MainCmds::selfplay(const vector& args) { //Note that this inputsVersion passed here is NOT necessarily the same as the one used in the neural net self play, it //simply controls the input feature version for the written data - TrainingDataWriter* tdataWriter = new TrainingDataWriter( - tdataOutputDir, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); - ofstream* sgfOut = NULL; + auto tdataWriter = new TrainingDataWriter( + tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); + ofstream* sgfOut = nullptr; if(sgfOutputDir.length() > 0) { sgfOut = new ofstream(); FileUtils::open(*sgfOut, sgfOutputDir + "/" + Global::uint64ToHexString(rand.nextUInt64()) + ".sgfs"); diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index 3d8f60cf7..8cd36f60d 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -2,7 +2,6 @@ #include "../core/fileutils.h" -#include #include #include @@ -113,8 +112,6 @@ void ConfigParser::processIncludedFile(const std::string &fname) { baseDirs.pop_back(); } - - bool ConfigParser::parseKeyValue(const std::string& trimmedLine, std::string& key, std::string& value) { // Parse trimmed line, taking into account comments and quoting. key.clear(); @@ -569,7 +566,7 @@ vector ConfigParser::getStrings(const string& key, const set& po return values; } -bool ConfigParser::getBoolOrDefault(const std::string& key, bool defaultValue) { +bool ConfigParser::getBoolOrDefault(const std::string& key, const bool defaultValue) { if (contains(key)) { return getBool(key); } @@ -606,14 +603,7 @@ enabled_t ConfigParser::getEnabled(const string& key) { return x; } -int ConfigParser::getInt(const string& key) { - string value = getString(key); - int x; - if(!Global::tryStringToInt(value,x)) - throw IOError("Could not parse '" + value + "' as int for key '" + key + "' in config file " + fileName); - return x; -} -int ConfigParser::getInt(const string& key, int min, int max) { +int ConfigParser::getInt(const string& key, const int min, const int max) { assert(min <= max); string value = getString(key); int x; @@ -623,19 +613,8 @@ int ConfigParser::getInt(const string& key, int min, int max) { throw IOError("Key '" + key + "' must be in the range " + Global::intToString(min) + " to " + Global::intToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInts(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInts(const string& key, int min, int max) { + +vector ConfigParser::getInts(const string& key, const int min, const int max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getInts(const string& key, int min, int max) { } return ret; } -vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, int min, int max1, int max2) { + +vector> ConfigParser::getNonNegativeIntDashedPairs(const string& key, const int min, const int max1, const int max2) { std::vector pairStrs = getStrings(key); std::vector> ret; for(const string& pairStr: pairStrs) { @@ -678,15 +658,7 @@ vector> ConfigParser::getNonNegativeIntDashedPairs(const stri return ret; } - -int64_t ConfigParser::getInt64(const string& key) { - string value = getString(key); - int64_t x; - if(!Global::tryStringToInt64(value,x)) - throw IOError("Could not parse '" + value + "' as int64_t for key '" + key + "' in config file " + fileName); - return x; -} -int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { +int64_t ConfigParser::getInt64(const string& key, const int64_t min, const int64_t max) { assert(min <= max); string value = getString(key); int64_t x; @@ -696,19 +668,8 @@ int64_t ConfigParser::getInt64(const string& key, int64_t min, int64_t max) { throw IOError("Key '" + key + "' must be in the range " + Global::int64ToString(min) + " to " + Global::int64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t max) { + +vector ConfigParser::getInt64s(const string& key, const int64_t min, const int64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getInt64s(const string& key, int64_t min, int64_t return ret; } - -uint64_t ConfigParser::getUInt64(const string& key) { - string value = getString(key); - uint64_t x; - if(!Global::tryStringToUInt64(value,x)) - throw IOError("Could not parse '" + value + "' as uint64_t for key '" + key + "' in config file " + fileName); - return x; -} -uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) { +uint64_t ConfigParser::getUInt64(const string& key, const uint64_t min, const uint64_t max) { assert(min <= max); string value = getString(key); uint64_t x; @@ -741,19 +694,8 @@ uint64_t ConfigParser::getUInt64(const string& key, uint64_t min, uint64_t max) throw IOError("Key '" + key + "' must be in the range " + Global::uint64ToString(min) + " to " + Global::uint64ToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getUInt64s(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint64_t max) { + +vector ConfigParser::getUInt64s(const string& key, const uint64_t min, const uint64_t max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getUInt64s(const string& key, uint64_t min, uint6 return ret; } - -float ConfigParser::getFloat(const string& key) { - string value = getString(key); - float x; - if(!Global::tryStringToFloat(value,x)) - throw IOError("Could not parse '" + value + "' as float for key '" + key + "' in config file " + fileName); - return x; -} -float ConfigParser::getFloat(const string& key, float min, float max) { +float ConfigParser::getFloat(const string& key, const float min, const float max) { assert(min <= max); string value = getString(key); float x; @@ -788,19 +722,8 @@ float ConfigParser::getFloat(const string& key, float min, float max) { throw IOError("Key '" + key + "' must be in the range " + Global::floatToString(min) + " to " + Global::floatToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getFloats(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { + +vector ConfigParser::getFloats(const string& key, const float min, const float max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i ConfigParser::getFloats(const string& key, float min, float max) { return ret; } - -double ConfigParser::getDouble(const string& key) { - string value = getString(key); - double x; - if(!Global::tryStringToDouble(value,x)) - throw IOError("Could not parse '" + value + "' as double for key '" + key + "' in config file " + fileName); - return x; -} -double ConfigParser::getDouble(const string& key, double min, double max) { +double ConfigParser::getDouble(const string& key, const double min, const double max) { assert(min <= max); string value = getString(key); double x; @@ -837,19 +752,8 @@ double ConfigParser::getDouble(const string& key, double min, double max) { throw IOError("Key '" + key + "' must be in the range " + Global::doubleToString(min) + " to " + Global::doubleToString(max) + " in config file " + fileName); return x; } -vector ConfigParser::getDoubles(const string& key) { - vector values = getStrings(key); - vector ret; - for(size_t i = 0; i ConfigParser::getDoubles(const string& key, double min, double max) { + +vector ConfigParser::getDoubles(const string& key, const double min, const double max) { vector values = getStrings(key); vector ret; for(size_t i = 0; i& possibles); - int getInt(const std::string& key, int min, int max); - int64_t getInt64(const std::string& key, int64_t min, int64_t max); - uint64_t getUInt64(const std::string& key, uint64_t min, uint64_t max); - float getFloat(const std::string& key, float min, float max); - double getDouble(const std::string& key, double min, double max); + int getInt(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + int64_t getInt64(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + uint64_t getUInt64(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + float getFloat(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + double getDouble(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); std::vector getStrings(const std::string& key); std::vector getStringsNonEmptyTrim(const std::string& key); std::vector getBools(const std::string& key); - std::vector getInts(const std::string& key); - std::vector getInt64s(const std::string& key); - std::vector getUInt64s(const std::string& key); - std::vector getFloats(const std::string& key); - std::vector getDoubles(const std::string& key); std::vector getStrings(const std::string& key, const std::set& possibles); - std::vector getInts(const std::string& key, int min, int max); - std::vector getInt64s(const std::string& key, int64_t min, int64_t max); - std::vector getUInt64s(const std::string& key, uint64_t min, uint64_t max); - std::vector getFloats(const std::string& key, float min, float max); - std::vector getDoubles(const std::string& key, double min, double max); + std::vector getInts(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); + std::vector getInt64s(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); + std::vector getUInt64s(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); + std::vector getFloats(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); + std::vector getDoubles(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); std::vector> getNonNegativeIntDashedPairs(const std::string& key, int min, int max1, int max2); diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 5de990129..faf5853be 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -2014,7 +2014,7 @@ void WriteSgf::writeSgf( } out << "KM[" << rules.komi << "]"; - out << "RU[" << (tryNicerRulesString ? rules.toStringNoSgfDefinedPropertiesMaybeNice() : rules.toStringNoSgfDefinedProps()) << "]"; + out << "RU[" << (tryNicerRulesString ? rules.toStringNoSgfDefinedPropertiesMaybeNice() : rules.toString(false)) << "]"; printGameResult(out,endHist,overrideFinishedWhiteScore); bool hasAB = false; diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index 5114177ac..823e4337b 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -36,7 +36,7 @@ SidePosition::SidePosition(const Rules& rules) playoutDoublingAdvantage(0.0) {} -SidePosition::SidePosition(const Board& b, const BoardHistory& h, Player p, int numNNChangesSoFar) +SidePosition::SidePosition(const Board& b, const BoardHistory& h, const Player p, const int numNNChangesSoFar) :board(b), hist(h), pla(p), @@ -265,7 +265,9 @@ static const int GLOBAL_TARGET_NUM_CHANNELS = 64; static const int VALUE_SPATIAL_TARGET_NUM_CHANNELS = 5; static const int QVALUE_SPATIAL_TARGET_NUM_CHANNELS = 3; -TrainingWriteBuffers::TrainingWriteBuffers(int iVersion, int maxRws, int numBChannels, int numFChannels, int xLen, int yLen, bool includeMetadata) +TrainingWriteBuffers::TrainingWriteBuffers( + const int iVersion, int maxRws, int numBChannels, int numFChannels, int xLen, int yLen, + const bool includeMetadata) :inputsVersion(iVersion), maxRows(maxRws), numBinaryChannels(numBChannels), @@ -298,7 +300,7 @@ void TrainingWriteBuffers::clear() { } //Copy floats that are all 0-1 into bits, packing 8 to a byte, big-endian-style within each byte. -static void packBits(const float* binaryFloats, int len, uint8_t* bits) { +static void packBits(const float* binaryFloats, const int len, uint8_t* bits) { for(int i = 0; i < len; i += 8) { if(i + 8 <= len) { bits[i >> 3] = @@ -320,18 +322,22 @@ static void packBits(const float* binaryFloats, int len, uint8_t* bits) { } } -static void zeroPolicyTarget(int policySize, int16_t* target) { +static void zeroPolicyTarget(const int policySize, int16_t* target) { for(int pos = 0; pos& policyTargetMoves, int policySize, int dataXLen, int dataYLen, int boardXSize, int16_t* target) { +static void fillPolicyTarget(const vector& policyTargetMoves, + const int policySize, + const int dataXLen, + const int dataYLen, + const int boardXSize, int16_t* target) { zeroPolicyTarget(policySize,target); size_t size = policyTargetMoves.size(); for(size_t i = 0; i& policyTargetMoves, //Clamps a value to integer in [-120,120] to pack down to 8 bits. //Randomizes to make sure the expectation is exactly correct. -static int8_t clampToRadius120(float x, Rand& rand) { +static int8_t clampToRadius120(const float x, Rand& rand) { //We need to pack this down to 8 bits, so map into [-120,120]. //Randomize to ensure the expectation is exactly correct. int low = (int)floor(x); @@ -356,7 +362,7 @@ static int8_t clampToRadius120(float x, Rand& rand) { if(lambda == 0.0f) return (int8_t)low; else return (int8_t)(rand.nextBool(lambda) ? high : low); } -static int16_t clampToRadius32000(float x, Rand& rand) { +static int16_t clampToRadius32000(const float x, Rand& rand) { //We need to pack this down to 16 bits, so clamp into an integer [-32000,32000]. //Randomize to ensure the expectation is exactly correct. int low = (int)floor(x); @@ -369,7 +375,12 @@ static int16_t clampToRadius32000(float x, Rand& rand) { else return (int16_t)(rand.nextBool(lambda) ? high : low); } -static void fillQValueTarget(const vector& whiteQValueTargets, Player nextPlayer, int policySize, int dataXLen, int dataYLen, int boardXSize, int16_t* cPosTarget, Rand& rand) { +static void fillQValueTarget(const vector& whiteQValueTargets, + const Player nextPlayer, + const int policySize, + const int dataXLen, + const int dataYLen, + const int boardXSize, int16_t* cPosTarget, Rand& rand) { for(int i = 0; i < QVALUE_SPATIAL_TARGET_NUM_CHANNELS * policySize; i++) { cPosTarget[i] = 0; } @@ -395,7 +406,10 @@ static void fillQValueTarget(const vector& whiteQValueTargets, } } -static void fillValueTDTargets(const vector& whiteValueTargetsByTurn, int idx, Player nextPlayer, double nowFactor, float* buf) { +static void fillValueTDTargets(const vector& whiteValueTargetsByTurn, + const int idx, + const Player nextPlayer, + const double nowFactor, float* buf) { double winValue = 0.0; double lossValue = 0.0; double noResultValue = 0.0; @@ -939,22 +953,22 @@ void TrainingWriteBuffers::writeToTextOstream(ostream& out) { //------------------------------------------------------------------------------------- -TrainingDataWriter::TrainingDataWriter(const string& outDir, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const string& randSeed) - : TrainingDataWriter(outDir,NULL,iVersion,maxRowsPerFile,firstFileMinRandProp,dataXLen,dataYLen,1,randSeed) -{} -TrainingDataWriter::TrainingDataWriter(ostream* dbgOut, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyEvery, const string& randSeed) - : TrainingDataWriter(string(),dbgOut,iVersion,maxRowsPerFile,firstFileMinRandProp,dataXLen,dataYLen,onlyEvery,randSeed) -{} - -TrainingDataWriter::TrainingDataWriter(const string& outDir, ostream* dbgOut, int iVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyEvery, const string& randSeed) - :outputDir(outDir),inputsVersion(iVersion),rand(randSeed),writeBuffers(NULL),debugOut(dbgOut),debugOnlyWriteEvery(onlyEvery),rowCount(0) +TrainingDataWriter::TrainingDataWriter(const string& outputDir, ostream* debugOut, + const int inputsVersion, + const int maxRowsPerFile, + const double firstFileMinRandProp, + const int dataXLen, + const int dataYLen, + const string& randSeed, + const int onlyWriteEvery) + :outputDir(outputDir),inputsVersion(inputsVersion),rand(randSeed),writeBuffers(nullptr),debugOut(debugOut),debugOnlyWriteEvery(onlyWriteEvery),rowCount(0) { //Note that this inputsVersion is for data writing, it might be different than the inputsVersion used // to feed into a model during selfplay const int numBinaryChannels = NNInputs::getNumberOfSpatialFeatures(inputsVersion); const int numGlobalChannels = NNInputs::getNumberOfGlobalFeatures(inputsVersion); - const bool hasMetadataInput = false; + constexpr bool hasMetadataInput = false; writeBuffers = new TrainingWriteBuffers( inputsVersion, maxRowsPerFile, diff --git a/cpp/dataio/trainingwrite.h b/cpp/dataio/trainingwrite.h index 23d043603..2034032fd 100644 --- a/cpp/dataio/trainingwrite.h +++ b/cpp/dataio/trainingwrite.h @@ -311,9 +311,7 @@ struct TrainingWriteBuffers { class TrainingDataWriter { public: - TrainingDataWriter(const std::string& outputDir, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed); - TrainingDataWriter(std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyWriteEvery, const std::string& randSeed); - TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, int onlyWriteEvery, const std::string& randSeed); + TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1); ~TrainingDataWriter(); void writeGame(const FinishedGameData& data); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 4e3ad4f93..7b3a5c780 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -2483,11 +2483,9 @@ void Board::checkConsistency() const { } } -bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const { - return isEqualForTesting(other, checkNumCaptures, checkSimpleKo, true); -} - -bool Board::isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo, bool checkRules) const { +bool Board::isEqualForTesting(const Board& other, const bool checkNumCaptures, + const bool checkSimpleKo, + const bool checkRules) const { checkConsistency(); other.checkConsistency(); if(x_size != other.x_size) @@ -2830,11 +2828,8 @@ string Board::toStringSimple(const Board& board, char lineDelimiter) { return s; } -Board Board::parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter) { - return parseBoard(xSize, ySize, s, Rules::DEFAULT_GO, lineDelimiter); -} - -Board Board::parseBoard(int xSize, int ySize, const string& s, const Rules& rules, char lineDelimiter) { +Board Board::parseBoard(const int xSize, + const int ySize, const string& s, const Rules& rules, const char lineDelimiter) { Board board(xSize,ySize,rules); vector lines = Global::split(Global::trim(s),lineDelimiter); diff --git a/cpp/game/board.h b/cpp/game/board.h index f162a89bf..67ce4c640 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -389,14 +389,12 @@ struct Board void checkConsistency() const; //For the moment, only used in testing since it does extra consistency checks. //If we need a version to be used in "prod", we could make an efficient version maybe as operator==. - bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo) const; - bool isEqualForTesting(const Board& other, bool checkNumCaptures, bool checkSimpleKo, bool checkRules) const; + bool isEqualForTesting(const Board& other, bool checkNumCaptures = true, bool checkSimpleKo = true, bool checkRules = true) const; - static Board parseBoard(int xSize, int ySize, const std::string& s, char lineDelimiter = '\n'); - static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules, char lineDelimiter = '\n'); + static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules = Rules::DEFAULT_GO, char lineDelimiter = '\n'); std::string toString() const; static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); - static std::string toStringSimple(const Board& board, char lineDelimiter); + static std::string toStringSimple(const Board& board, char lineDelimiter = '\n'); static nlohmann::json toJson(const Board& board); static Board ofJson(const nlohmann::json& data); diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 5fb435b60..c9eafe9b1 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -1007,10 +1007,6 @@ bool BoardHistory::makeBoardMoveTolerant(Board& board, Loc moveLoc, Player moveP return true; } -void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable) { - makeBoardMoveAssumeLegal(board,moveLoc,movePla,rootKoHashTable,false); -} - void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore) { Hash128 posHashBeforeMove = board.pos_hash; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 108903550..a3ab88d99 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -162,8 +162,7 @@ struct BoardHistory { //even if the move violates superko or encore ko recapture prohibitions, or is past when the game is ended. //This allows for robustness when this code is being used for analysis or with external data sources. //preventEncore artifically prevents any move from entering or advancing the encore phase when using territory scoring. - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable); - void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore); + void makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player movePla, const KoHashTable* rootKoHashTable, bool preventEncore = false); //Make a move with legality checking, but be mostly tolerant and allow moves that can still be handled but that may not technically //be legal. This is intended for reading moves from SGFs and such where maybe we're getting moves that were played in a different //ruleset than ours. Returns true if successful, false if was illegal even unter tolerant rules. diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index 8f3da0d06..c17ca0d9e 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -265,14 +265,6 @@ ostream& operator<<(ostream& out, const Rules& rules) { return out; } -string Rules::toString() const { - return toString(true); -} - -string Rules::toStringNoSgfDefinedProps() const { - return toString(false); -} - string Rules::toString(const bool includeSgfDefinedProperties) const { ostringstream out; if (!isDots) { @@ -668,19 +660,11 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) return rules; } -Rules Rules::parseRules(const string& sOrig) { - return parseRules(sOrig,false); -} - -Rules Rules::parseRules(const string& sOrig, bool isDots) { +Rules Rules::parseRules(const string& sOrig, const bool isDots) { return parseRulesHelper(sOrig,true,isDots); } -Rules Rules::parseRulesWithoutKomi(const string& sOrig, float komi) { - return parseRulesWithoutKomi(sOrig,komi,false); -} - -Rules Rules::parseRulesWithoutKomi(const string& sOrig, float komi, bool isDots) { +Rules Rules::parseRulesWithoutKomi(const string& sOrig, const float komi, const bool isDots) { Rules rules = parseRulesHelper(sOrig,false,isDots); rules.komi = komi; return rules; @@ -720,7 +704,7 @@ string Rules::toStringNoSgfDefinedPropertiesMaybeNice() const { return "StoneScoring"; if(equalsIgnoringSgfDefinedProps(parseRulesHelper("NewZealand",false, isDots))) return "NewZealand"; - return toStringNoSgfDefinedProps(); + return toString(false); } double nextRandomOffset(Rand& rand) { diff --git a/cpp/game/rules.h b/cpp/game/rules.h index f55fdaff6..f22fef300 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -104,10 +104,8 @@ struct Rules { static std::set startPosStrings(); int getNumOfStartPosStones() const; - static Rules parseRules(const std::string& sOrig); - static Rules parseRules(const std::string& sOrig, bool isDots); - static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi); - static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi, bool isDots); + static Rules parseRules(const std::string& sOrig, bool isDots = false); + static Rules parseRulesWithoutKomi(const std::string& sOrig, float komi, bool isDots = false); static bool tryParseRules(const std::string& sOrig, Rules& buf, bool isDots); static bool tryParseRulesWithoutKomi(const std::string& sOrig, Rules& buf, float komi, bool isDots); @@ -129,9 +127,7 @@ struct Rules { bool& randomized); friend std::ostream& operator<<(std::ostream& out, const Rules& rules); - std::string toString() const; - std::string toStringNoSgfDefinedProps() const; - std::string toString(bool includeSgfDefinedProperties) const; + std::string toString(bool includeSgfDefinedProperties = true) const; std::string toStringNoSgfDefinedPropertiesMaybeNice() const; std::string toJsonString() const; std::string toJsonStringNoKomi() const; diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index e8e15626c..1e240f2d7 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -2471,12 +2471,12 @@ FinishedGameData* GameRunner::runGame( Search* botB; Search* botW; if(botSpecB.botIdx == botSpecW.botIdx) { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, hist.rules); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed, nullptr, hist.rules); botW = botB; } else { - botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", hist.rules); - botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", hist.rules); + botB = new Search(botSpecB.baseParams, botSpecB.nnEval, &logger, seed + "@B", nullptr, hist.rules); + botW = new Search(botSpecW.baseParams, botSpecW.nnEval, &logger, seed + "@W", nullptr, hist.rules); } if(afterInitialization != nullptr) { if(botSpecB.botIdx == botSpecW.botIdx) { diff --git a/cpp/search/asyncbot.cpp b/cpp/search/asyncbot.cpp index 3e383f4aa..393a2cf3f 100644 --- a/cpp/search/asyncbot.cpp +++ b/cpp/search/asyncbot.cpp @@ -55,7 +55,7 @@ AsyncBot::AsyncBot( analyzeCallback(), searchBegunCallback() { - search = new Search(params,nnEval,humanEval,l,randSeed,rules); + search = new Search(params,nnEval,l,randSeed,humanEval,rules); searchThread = std::thread(searchThreadLoop,this,l); } diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index afe26c5c4..a85147b95 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -65,55 +65,49 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; -Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed) - :Search(params,nnEval,NULL,lg,rSeed,Rules::DEFAULT_GO) -{} -Search::Search(SearchParams params, NNEvaluator* nnEval, Logger* lg, const string& rSeed, const Rules& rules) - :Search(params,nnEval,NULL,lg,rSeed,rules) -{} -Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, Logger* lg, const string& rSeed, const Rules& rules) - :rootPla(P_BLACK), - rootBoard(rules), - rootHistory(rules), - rootGraphHash(), - rootHintLoc(Board::NULL_LOC), - avoidMoveUntilByLocBlack(),avoidMoveUntilByLocWhite(),avoidMoveUntilRescaleRoot(false), - rootSymmetries(), - rootPruneOnlySymmetries(), - rootSafeArea(NULL), - recentScoreCenter(0.0), - mirroringPla(C_EMPTY), - mirrorAdvantage(0.0), - mirrorCenterSymmetryError(1e10), - alwaysIncludeOwnerMap(false), - searchParams(params),numSearchesBegun(0),searchNodeAge(0), - plaThatSearchIsFor(C_EMPTY),plaThatSearchIsForLastSearch(C_EMPTY), - lastSearchNumPlayouts(0), - effectiveSearchTimeCarriedOver(0.0), - randSeed(rSeed), - rootKoHashTable(NULL), - valueWeightDistribution(NULL), - patternBonusTable(NULL), - externalPatternBonusTable(nullptr), - evalCache(nullptr), - nonSearchRand(rSeed + string("$nonSearchRand")), - logger(lg), - nnEvaluator(nnEval), - humanEvaluator(humanEval), - nnXLen(), - nnYLen(), - policySize(), - rootNode(NULL), - nodeTable(NULL), - mutexPool(NULL), - subtreeValueBiasTable(NULL), - numThreadsSpawned(0), - threads(NULL), - threadTasks(NULL), - threadTasksRemaining(NULL), - oldNNOutputsToCleanUpMutex(), - oldNNOutputsToCleanUp() -{ +Search::Search(const SearchParams ¶ms, NNEvaluator* nnEval, Logger* lg, const string& randSeed, NNEvaluator* humanEval, const Rules& rules) + : rootPla(P_BLACK), + rootBoard(rules), + rootHistory(rules), + rootGraphHash(), + rootHintLoc(Board::NULL_LOC), + avoidMoveUntilByLocBlack(), avoidMoveUntilByLocWhite(), avoidMoveUntilRescaleRoot(false), + rootSymmetries(), + rootPruneOnlySymmetries(), + rootSafeArea(NULL), + recentScoreCenter(0.0), + mirroringPla(C_EMPTY), + mirrorAdvantage(0.0), + mirrorCenterSymmetryError(1e10), + alwaysIncludeOwnerMap(false), + searchParams(params), numSearchesBegun(0), searchNodeAge(0), + plaThatSearchIsFor(C_EMPTY), plaThatSearchIsForLastSearch(C_EMPTY), + lastSearchNumPlayouts(0), + effectiveSearchTimeCarriedOver(0.0), + randSeed(randSeed), + rootKoHashTable(NULL), + valueWeightDistribution(NULL), + normToTApproxZ(0), + patternBonusTable(NULL), + externalPatternBonusTable(nullptr), + evalCache(nullptr), + nonSearchRand(randSeed + string("$nonSearchRand")), + logger(lg), + nnEvaluator(nnEval), + humanEvaluator(humanEval), + nnXLen(), + nnYLen(), + policySize(), + rootNode(NULL), + nodeTable(NULL), + mutexPool(NULL), + subtreeValueBiasTable(NULL), + numThreadsSpawned(0), + threads(NULL), + threadTasks(NULL), + threadTasksRemaining(NULL), + oldNNOutputsToCleanUpMutex(), + oldNNOutputsToCleanUp() { assert(logger != NULL); nnXLen = nnEval->getNNXLen(); nnYLen = nnEval->getNNYLen(); @@ -121,13 +115,13 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, assert(nnYLen > 0 && nnYLen <= NNPos::MAX_BOARD_LEN_Y); policySize = NNPos::getPolicySize(nnXLen, nnYLen); - if(humanEvaluator != NULL) { - if(humanEvaluator->getNNXLen() != nnXLen || humanEvaluator->getNNYLen() != nnYLen) + if (humanEvaluator != NULL) { + if (humanEvaluator->getNNXLen() != nnXLen || humanEvaluator->getNNYLen() != nnYLen) throw StringError("Search::init - humanEval has different nnXLen or nnYLen"); } assert(rootHistory.rules.isDots == rootBoard.isDots()); - rootHistory.clear(rootBoard,rootPla,rules,0); + rootHistory.clear(rootBoard, rootPla, rules, 0); if (!rules.isDots) { rootKoHashTable = new KoHashTable(); rootKoHashTable->recompute(rootHistory); @@ -136,8 +130,8 @@ Search::Search(SearchParams params, NNEvaluator* nnEval, NNEvaluator* humanEval, rootSafeArea = new Color[Board::MAX_ARR_SIZE]; valueWeightDistribution = new DistributionTable( - [](double z) { return FancyMath::tdistpdf(z,VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, - [](double z) { return FancyMath::tdistcdf(z,VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, + [](double z) { return FancyMath::tdistpdf(z, VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, + [](double z) { return FancyMath::tdistcdf(z, VALUE_WEIGHT_DEGREES_OF_FREEDOM); }, -50.0, 50.0, 2000 diff --git a/cpp/search/search.h b/cpp/search/search.h index 901a19506..f59511c34 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -182,24 +182,12 @@ struct Search { //Note - randSeed controls a few things in the search, but a lot of the randomness actually comes from //random symmetries of the neural net evaluations, see nneval.h Search( - SearchParams params, - NNEvaluator* nnEval, - Logger* logger, - const std::string& randSeed - ); - Search( - SearchParams params, + const SearchParams ¶ms, NNEvaluator* nnEval, Logger* logger, const std::string& randSeed, - const Rules& rules); - Search( - SearchParams params, - NNEvaluator* nnEval, - NNEvaluator* humanEval, - Logger* lg, - const std::string& rSeed, - const Rules& rules); + NNEvaluator* humanEval = nullptr, + const Rules& rules = Rules::DEFAULT_GO); ~Search(); Search(const Search&) = delete; diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index 172c19299..68b88a516 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -2519,9 +2519,9 @@ oxxxxx.xo //if(rep < 100) // hist.printDebugInfo(cout,board); - testAssert(boardCopy.isEqualForTesting(board, true, true)); - testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0), true, true)); - testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0), true, true)); + testAssert(boardCopy.isEqualForTesting(board)); + testAssert(boardCopy.isEqualForTesting(histCopy.getRecentBoard(0))); + testAssert(histCopy.getRecentBoard(0).isEqualForTesting(hist.getRecentBoard(0))); testAssert(BoardHistory::getSituationRulesAndKoHash(boardCopy,histCopy,pla,drawEquivalentWinsForWhite) == hist.getSituationRulesAndKoHash(board,hist,pla,drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_BLACK, drawEquivalentWinsForWhite)); testAssert(histCopy.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite) == hist.currentSelfKomi(P_WHITE, drawEquivalentWinsForWhite)); diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index 0aec8aea4..839ef5f82 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -23,7 +23,7 @@ void checkDotsField(const string& description, const string& input, bool capture board.undo(moveRecords.back()); moveRecords.pop_back(); } - testAssert(initialBoard.isEqualForTesting(board, true, true)); + testAssert(initialBoard.isEqualForTesting(board)); } void checkDotsFieldDefault(const string& description, const string& input, const std::function& check) { diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp index 3ad341be5..c68b803b4 100644 --- a/cpp/tests/testdotsextra.cpp +++ b/cpp/tests/testdotsextra.cpp @@ -13,8 +13,8 @@ void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardIn for (const XYMove& extraMove : extraMoves) { expectedBoard.playMoveAssumeLegal(SymmetryHelpers::getSymLoc(extraMove.x, extraMove.y, initBoard, symmetry), extraMove.player); } - expect(SymmetryHelpers::symmetryToString(symmetry).c_str(), Board::toStringSimple(transformedBoard, '\n'), Board::toStringSimple(expectedBoard, '\n')); - testAssert(transformedBoard.isEqualForTesting(expectedBoard, true, true, true)); + expect(SymmetryHelpers::symmetryToString(symmetry).c_str(), Board::toStringSimple(transformedBoard), Board::toStringSimple(expectedBoard)); + testAssert(transformedBoard.isEqualForTesting(expectedBoard)); } void Tests::runDotsSymmetryTests() { @@ -119,11 +119,11 @@ SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 2, expectedBoard.x_size), P_BLACK, true); expectedBoard.playMoveAssumeLegal(Location::getLoc(1, 1, expectedBoard.x_size), P_BLACK); - expect("Dots symmetry with start pos", Board::toStringSimple(rotatedBoard, '\n'), Board::toStringSimple(expectedBoard, '\n')); - testAssert(rotatedBoard.isEqualForTesting(expectedBoard, true, true, true)); + expect("Dots symmetry with start pos", Board::toStringSimple(rotatedBoard), Board::toStringSimple(expectedBoard)); + testAssert(rotatedBoard.isEqualForTesting(expectedBoard)); const auto unrotatedBoard = SymmetryHelpers::getSymBoard(rotatedBoard, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y); - testAssert(board.isEqualForTesting(unrotatedBoard, true, true, true)); + testAssert(board.isEqualForTesting(unrotatedBoard)); } string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { @@ -288,7 +288,7 @@ std::pair getCapturingAndBases( } // Make sure we didn't change an internal state during calculating - testAssert(board.isEqualForTesting(copy, true, true)); + testAssert(board.isEqualForTesting(copy)); return {capturesStringStream.str(), basesStringStream.str()}; } diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index df838f089..418c096fb 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -293,7 +293,7 @@ void runDotsStressTestsInternal( moveRecords.pop_back(); } - testAssert(initialBoard.isEqualForTesting(board, true, false)); + testAssert(initialBoard.isEqualForTesting(board)); } } diff --git a/cpp/tests/testrules.cpp b/cpp/tests/testrules.cpp index 58ae09228..cfff353ce 100644 --- a/cpp/tests/testrules.cpp +++ b/cpp/tests/testrules.cpp @@ -5401,7 +5401,7 @@ Last moves pass pass pass pass H7 G9 F9 H7 testAssert(rules[i] == parsed); Rules parsed2; - suc = Rules::tryParseRulesWithoutKomi(rules[i].toStringNoSgfDefinedProps(), parsed2, rules[i].komi, rules[i].isDots); + suc = Rules::tryParseRulesWithoutKomi(rules[i].toString(false), parsed2, rules[i].komi, rules[i].isDots); testAssert(suc); testAssert(rules[i] == parsed2); diff --git a/cpp/tests/testsymmetries.cpp b/cpp/tests/testsymmetries.cpp index 5703ccf35..f686f9269 100644 --- a/cpp/tests/testsymmetries.cpp +++ b/cpp/tests/testsymmetries.cpp @@ -413,7 +413,7 @@ x.xxo.... Loc symLocComb = SymmetryHelpers::getSymLoc(loc,board,symmetryComposed); Loc symLocCombManual = SymmetryHelpers::getSymLoc(SymmetryHelpers::getSymLoc(loc,board,symmetry1),SymmetryHelpers::getSymBoard(board,symmetry1),symmetry2); out << "Symmetry " << symmetry1 << " + " << symmetry2 << " = " << symmetryComposed << endl; - testAssert(symBoardCombManual.isEqualForTesting(symBoardComb,true,true)); + testAssert(symBoardCombManual.isEqualForTesting(symBoardComb)); testAssert(symLocComb == symLocCombManual); } } @@ -588,7 +588,7 @@ x.xxo.... out << "SYMMETRY " << symmetry << endl; out << boardA << endl; out << boardB << endl; - testAssert(boardA.isEqualForTesting(boardB,true,true)); + testAssert(boardA.isEqualForTesting(boardB)); } string expected = R"%%( SYMMETRY 0 diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index d73e5b7ba..342d8bf8a 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -9,6 +9,15 @@ using namespace std; using namespace TestCommon; +TrainingDataWriter createTestTrainingDataWriter( + const int inputVersion, + const int nnXLen, + const int nnYLen, + const string& seed, + const int onlyWriteEvery) { + return TrainingDataWriter(string(), &cout, inputVersion, 256, 1.0f, nnXLen, nnYLen, seed, onlyWriteEvery); +} + static NNEvaluator* startNNEval( const string& modelFile, const string& seed, Logger& logger, int defaultSymmetry, bool inputsUseNHWC, bool useNHWC, bool useFP16 @@ -67,13 +76,9 @@ void Tests::runTrainingWriteTests() { cout << "Running training write tests" << endl; NeuralNet::globalInitialize(); - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 5; - - const bool logToStdout = true; - const bool logToStderr = false; - const bool logTime = false; + constexpr bool logToStdout = true; + constexpr bool logToStderr = false; + constexpr bool logTime = false; Logger logger(nullptr, logToStdout, logToStderr, logTime); auto run = [&]( @@ -83,7 +88,7 @@ void Tests::runTrainingWriteTests() { int boardXLen, int boardYLen, bool cheapLongSgf ) { - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, nnXLen, nnYLen, seedBase + "dwriter", 5); NNEvaluator* nnEval = startNNEval("/dev/null",seedBase+"nneval",logger,0,inputsNHWC,useNHWC,false); @@ -1056,8 +1061,6 @@ xxxxxxxx. cout << "====================================================================================================" << endl; cout << "Testing turnnumber and early temperatures" << endl; - int maxRows = 256; - double firstFileMinRandProp = 1.0; int debugOnlyWriteEvery = 1; int inputsVersion = 7; @@ -1142,7 +1145,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board(9,9,rules); @@ -1173,7 +1176,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board(9,9,rules); @@ -1203,9 +1206,6 @@ xxxxxxxx. cout << "====================================================================================================" << endl; cout << "Testing no result" << endl; - int maxRows = 256; - double firstFileMinRandProp = 1.0; - int debugOnlyWriteEvery = 1; int inputsVersion = 7; SearchParams params = SearchParams::forTestsV2(); @@ -1272,9 +1272,9 @@ xxxxxxxx. botSpec.nnEval = nnEval; botSpec.baseParams = params; - string seed = "seed-testing-temperature"; - { + string seed = "seed-testing-temperature"; + int debugOnlyWriteEvery = 1; cout << "Turn number initial 0 selfplay with high temperatures" << endl; nnEval->clearCache(); nnEval->clearStats(); @@ -1284,7 +1284,7 @@ xxxxxxxx. GameRunner* gameRunner = new GameRunner(cfg, seed, playSettings, logger); auto shouldStop = []() noexcept { return false; }; WaitableFlag* shouldPause = nullptr; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, 9, 9, debugOnlyWriteEvery, seed); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 9, 9, seed, debugOnlyWriteEvery); Sgf::PositionSample startPosSample; startPosSample.board = Board::parseBoard(9,9,R"%%( @@ -2723,7 +2723,6 @@ oox.x.... NeuralNet::globalCleanup(); } - void Tests::runSekiTrainWriteTests(const string& modelFile) { bool inputsNHWC = true; bool useNHWC = false; @@ -2732,9 +2731,6 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { cout << "Running test for how a seki gets recorded" << endl; NeuralNet::globalInitialize(); - int nnXLen = 13; - int nnYLen = 13; - const bool logToStdout = true; const bool logToStderr = false; const bool logTime = false; @@ -2744,10 +2740,8 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { auto run = [&](const string& sgfStr, const string& seedBase, const Rules& rules) { int inputsVersion = 6; - int maxRows = 256; - double firstFileMinRandProp = 1.0; int debugOnlyWriteEvery = 1000; - TrainingDataWriter dataWriter(&cout,inputsVersion, maxRows, firstFileMinRandProp, nnXLen, nnYLen, debugOnlyWriteEvery, seedBase+"dwriter"); + TrainingDataWriter dataWriter = createTestTrainingDataWriter(inputsVersion, 13, 13, seedBase + "dwriter", debugOnlyWriteEvery); nnEval->clearCache(); nnEval->clearStats(); From b8b798e81b8cb2de5881135be9ead51397384b18 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 15:59:27 +0100 Subject: [PATCH 17/42] Refine `tryRecognizeStartPos` to ignore handicap and to detect remaining placement moves --- cpp/dataio/sgf.cpp | 6 ++-- cpp/game/rules.cpp | 38 ++++++++++++++++++++--- cpp/game/rules.h | 9 ++++-- cpp/neuralnet/nninputs.cpp | 2 +- cpp/tests/testdotsstartposes.cpp | 4 +-- cpp/tests/testsgf.cpp | 52 +++++++++++++++++++------------- 6 files changed, 77 insertions(+), 34 deletions(-) diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index faf5853be..53de52477 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -152,7 +152,8 @@ static Rules getRulesFromSgf(const SgfNode& rootNode, const int xSize, const int vector placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); bool randomized; - rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, true, randomized); + vector remainingMoves; + rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, randomized, &remainingMoves); if (randomized && !rules.startPosIsRandom) { propertyFail("Defined start pos is randomized but RU says it shouldn't"); } @@ -1774,7 +1775,8 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); bool randomized; - rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, true, randomized); + vector remainingMoves; + rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, randomized, &remainingMoves); if (randomized && !rules.startPosIsRandom) { f("Defined start pos is randomized but RU says it shouldn't"); } diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index c17ca0d9e..af5a35b8f 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -159,7 +159,8 @@ set Rules::startPosStrings() { startPosIdToName[START_POS_EMPTY], startPosIdToName[START_POS_CROSS], startPosIdToName[START_POS_CROSS_2], - startPosIdToName[START_POS_CROSS_4] + startPosIdToName[START_POS_CROSS_4], + startPosIdToName[START_POS_CUSTOM] }; } @@ -826,13 +827,25 @@ int Rules::tryRecognizeStartPos( const vector& placementMoves, const int x_size, const int y_size, - const bool emptyIfFailed, - bool& randomized) { + bool& randomized, + vector* remainingMoves) { randomized = false; // Empty or unknown start pos is static by default + if (remainingMoves != nullptr) { + *remainingMoves = placementMoves; + } + int result = START_POS_EMPTY; + + if(placementMoves.empty()) return result; - if(placementMoves.empty()) return START_POS_EMPTY; + // If all placement moves are black, then it's a handicap game and the start pos is empty + for (const auto placementMove : placementMoves) { + if (placementMove.pla != C_BLACK) { + result = START_POS_CUSTOM; + break; + } + } - int result = emptyIfFailed ? START_POS_EMPTY : START_POS_CUSTOM; + if (result == START_POS_EMPTY) return result; const int stride = x_size + 1; auto placement = vector(stride * (y_size + 2), C_EMPTY); @@ -888,6 +901,7 @@ int Rules::tryRecognizeStartPos( auto detectRandomization = [&](const int expectedStartPos) -> void { auto staticStartPosMoves = generateStartPos(expectedStartPos, nullptr, x_size, y_size); + assert(remainingMoves != nullptr || placementMoves.size() == recognizedCrossesMoves.size()); assert(staticStartPosMoves.size() == recognizedCrossesMoves.size()); sortByLoc(staticStartPosMoves); @@ -899,6 +913,20 @@ int Rules::tryRecognizeStartPos( } } + if (remainingMoves != nullptr) { + for (const auto recognizedMove : recognizedCrossesMoves) { + bool removed = false; + for(auto it = remainingMoves->begin(); it != remainingMoves->end(); ++it) { + if (it->loc == recognizedMove.loc && it->pla == recognizedMove.pla) { + remainingMoves->erase(it); + removed = true; + break; + } + } + assert(removed); + } + } + result = expectedStartPos; }; diff --git a/cpp/game/rules.h b/cpp/game/rules.h index f22fef300..d2f6910c2 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -116,15 +116,16 @@ struct Rules { * @param placementMoves placement moves that we are trying to recognize. * @param x_size size of field * @param y_size size of field - * @param emptyIfFailed returns empty start pos if recognition is failed. It's useful for detecting start pos from SGF when handicap stones are placed * @param randomized if we recognize a start pos, but it doesn't match the strict position, set it up to `true` + * @param remainingMoves Holds moves that remain after start pos recognition, useful for SGF handling. + * If it's null (default), then it's assumed that all placement moves are used in the recognized start pos. */ static int tryRecognizeStartPos( const std::vector& placementMoves, int x_size, int y_size, - bool emptyIfFailed, - bool& randomized); + bool& randomized, + std::vector* remainingMoves = nullptr); friend std::ostream& operator<<(std::ostream& out, const Rules& rules); std::string toString(bool includeSgfDefinedProperties = true) const; @@ -171,10 +172,12 @@ struct Rules { startPosIdToName[START_POS_CROSS] = "CROSS"; startPosIdToName[START_POS_CROSS_2] = "CROSS_2"; startPosIdToName[START_POS_CROSS_4] = "CROSS_4"; + startPosIdToName[START_POS_CUSTOM] = "CUSTOM"; startPosNameToId["EMPTY"] = START_POS_EMPTY; startPosNameToId["CROSS"] = START_POS_CROSS; startPosNameToId["CROSS_2"] = START_POS_CROSS_2; startPosNameToId["CROSS_4"] = START_POS_CROSS_4; + startPosNameToId["CUSTOM"] = START_POS_CUSTOM; } } diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 6dda69b43..757b51465 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -708,7 +708,7 @@ Board SymmetryHelpers::getSymBoard(const Board& board, int symmetry) { sym_start_pos_moves.emplace_back(getSymLoc(x, y), start_pos_move.pla); } bool randomized; - symRules.startPos = Rules::tryRecognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, false, randomized); + symRules.startPos = Rules::tryRecognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, randomized); symRules.startPosIsRandom = randomized; } diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 63d736617..8fcd12388 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -53,7 +53,7 @@ void checkStartPosRecognition(const string& description, const int expectedStart void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { const auto generatedMoves = Rules::generateStartPos(startPos, startPosIsRandom ? &DOTS_RANDOM : nullptr, 39, 32); bool actualRandomized; - testAssert(startPos == Rules::tryRecognizeStartPos(generatedMoves, 39, 32, false, actualRandomized)); + testAssert(startPos == Rules::tryRecognizeStartPos(generatedMoves, 39, 32, actualRandomized)); // We can't reliably check in case of randomization is not detected because random generator can // generate static poses in rare cases. if (actualRandomized) { @@ -82,7 +82,7 @@ HASH: A130436FBD93FF473AB4F3B84DD304DB 1 . . X . )", {XYMove(2, 3, P_BLACK)}); - checkStartPosRecognition("Not enough dots for cross", 0, false, R"( + checkStartPosRecognition("Not enough dots for cross", Rules::START_POS_CUSTOM, false, R"( .... .xo. .o.. diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index 2c3c4271f..de80182f3 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -29,15 +29,27 @@ void Tests::runSgfTests() { const BoardHistory hist = sgf->setupInitialBoardAndHist(rules, pla); const Board& board = hist.initialBoard; - if (rules.startPos == Rules::START_POS_EMPTY) { + bool randomized; + vector remainingPlacementMoves; + const int recognizedStartPos = Rules::tryRecognizeStartPos(sgf->placements, board.x_size, board.y_size, randomized, &remainingPlacementMoves); + testAssert(recognizedStartPos == rules.startPos); + testAssert(randomized == rules.startPosIsRandom); + + if (recognizedStartPos != Rules::START_POS_EMPTY) { + out << "startPos " << Rules::writeStartPosRule(recognizedStartPos); + if (randomized) { + out << " (randomized)"; + } + out << endl; + } + + if (!remainingPlacementMoves.empty()) { out << "placements" << endl; - for(int i = 0; i < sgf->placements.size(); i++) { - Move move = sgf->placements[i]; - out << PlayerIO::colorToChar(move.pla) << " " << Location::toString(move.loc,board) << endl; + for (const auto placementMove : remainingPlacementMoves) { + out << PlayerIO::colorToChar(placementMove.pla) << " " << Location::toString(placementMove.loc, board) << endl; } - } else { - out << "startPos " << Rules::writeStartPosRule(rules.startPos) << endl; } + out << "moves" << endl; for(int i = 0; i < sgf->moves.size(); i++) { Move move = sgf->moves[i]; @@ -155,18 +167,19 @@ void Tests::runSgfTests() { { const char* name = "Basic Dots Sgf parse test"; - string sgfStr = "(;FF[4]GM[40]CA[UTF-8]AP[katago]SZ[10:8]AB[ed][fe]AW[ee][fd];B[ef];W[de];B[df];W[hd];B[ce];W[hf];B[cd];W[ff];B[dc];W[cf];B[hb];W[ic];B[db];W[gg];B[da];W[bg];B[])"; + string sgfStr = "(;FF[4]GM[40]CA[UTF-8]AP[katago]SZ[10:8]AB[ed][fe][ef]AW[ee][fd];W[de];B[df];W[hd];B[ce];W[hf];B[cd];W[ff];B[dc];W[cf];B[hb];W[ic];B[db];W[gg];B[da];W[bg];B[])"; parseAndPrintSgfLinear(sgfStr); string expected = R"( Dots game xSize 10 ySize 8 -depth 18 +depth 17 komi 0 startPos CROSS -moves +placements X E3 +moves O D4 X D3 O H5 @@ -184,22 +197,22 @@ X D8 O B2 X ground Initial board hist -pla Black -HASH: BA8A444F3D6E9FC94A3F4A16C7D2DBA0 +pla White +HASH: 42AC4303D65557034CC3593CB26EA615 1 2 3 4 5 6 7 8 9 10 8 . . . . . . . . . . 7 . . . . . . . . . . 6 . . . . . . . . . . 5 . . . . X O . . . . 4 . . . . O X . . . . - 3 . . . . . . . . . . + 3 . . . . X . . . . . 2 . . . . . . . . . . 1 . . . . . . . . . . Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 White bonus score 0 -Presumed next pla Black +Presumed next pla White Game result 0 Empty 0 0 0 0 Last moves Final board hist @@ -220,7 +233,7 @@ Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 White bonus score 0 Presumed next pla White Game result 1 Black -1 1 0 0 -Last moves E3 D4 D3 H5 C4 H3 C5 F3 D6 C3 H7 J6 D7 G2 D8 B2 ground +Last moves D4 D3 H5 C4 H3 C5 F3 D6 C3 H7 J6 D7 G2 D8 B2 ground )"; expect(name,out,expected); } @@ -656,7 +669,6 @@ xSize 17 ySize 3 depth 5 komi -6.5 -placements moves X F1 O C1 @@ -733,6 +745,7 @@ xSize 9 ySize 9 depth 2 komi 7.5 +startPos CUSTOM placements X B7 X D7 @@ -761,7 +774,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koPOSITIONALscoreAREAtaxNONEsui1komi7.5 +Rules koPOSITIONALscoreAREAtaxNONEstartPosCUSTOMsui1komi7.5 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -790,7 +803,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koPOSITIONALscoreAREAtaxNONEsui1komi7.5 +Rules koPOSITIONALscoreAREAtaxNONEstartPosCUSTOMsui1komi7.5 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -827,7 +840,6 @@ xSize 5 ySize 5 depth 13 komi 24 -placements moves X C3 O C4 @@ -922,7 +934,6 @@ xSize 5 ySize 5 depth 7 komi 24 -placements moves X C3 X B4 @@ -1005,7 +1016,6 @@ xSize 37 ySize 37 depth 14 komi 0 -placements moves X D34 O AJ34 @@ -1169,7 +1179,7 @@ xSize 5 ySize 5 komi 12.5 hasRules true -rules koSIMPLEscoreTERRITORYtaxSEKIsui0komi12.5 +rules koSIMPLEscoreTERRITORYtaxSEKIstartPosCUSTOMsui0komi12.5 handicapValue 5 sgfWinner Black firstPlayerColor X From acc39a598ee6e981208022a7415bd4f2137aa251 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Mon, 29 Sep 2025 15:18:21 +0200 Subject: [PATCH 18/42] Introduce `dataBoardLenX`, `dataBoardLenY` config keys --- cpp/command/selfplay.cpp | 6 ++++-- cpp/command/writetrainingdata.cpp | 13 ++++++++----- cpp/game/common.h | 2 ++ 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index 3cb66062a..03551077d 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -93,6 +93,8 @@ int MainCmds::selfplay(const vector& args) { //Width and height of the board to use when writing data, typically 19 const int dataBoardLen = cfg.getInt("dataBoardLen",3,Board::MAX_LEN); + const int dataBoardLenX = cfg.contains(DATA_LEN_X_KEY) ? cfg.getInt(DATA_LEN_X_KEY,3,Board::MAX_LEN_X) : dataBoardLen; + const int dataBoardLenY = cfg.contains(DATA_LEN_Y_KEY) ? cfg.getInt(DATA_LEN_Y_KEY,3,Board::MAX_LEN_Y) : dataBoardLen; const bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); const int inputsVersion = @@ -137,7 +139,7 @@ int MainCmds::selfplay(const vector& args) { //Returns true if a new net was loaded. auto loadLatestNeuralNetIntoManager = - [inputsVersion,&manager,maxRowsPerTrainFile,firstFileRandMinProp,dataBoardLen, + [inputsVersion,&manager,maxRowsPerTrainFile,firstFileRandMinProp,dataBoardLenX,dataBoardLenY, &modelsDir,&outputDir,&logger,&cfg,numGameThreads, minBoardXSizeUsed,maxBoardXSizeUsed,minBoardYSizeUsed,maxBoardYSizeUsed](const string* lastNetName) -> bool { @@ -215,7 +217,7 @@ int MainCmds::selfplay(const vector& args) { //Note that this inputsVersion passed here is NOT necessarily the same as the one used in the neural net self play, it //simply controls the input feature version for the written data auto tdataWriter = new TrainingDataWriter( - tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLen, dataBoardLen, Global::uint64ToHexString(rand.nextUInt64())); + tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLenX, dataBoardLenY, Global::uint64ToHexString(rand.nextUInt64())); ofstream* sgfOut = nullptr; if(sgfOutputDir.length() > 0) { sgfOut = new ofstream(); diff --git a/cpp/command/writetrainingdata.cpp b/cpp/command/writetrainingdata.cpp index e397a8976..a0527b9ca 100644 --- a/cpp/command/writetrainingdata.cpp +++ b/cpp/command/writetrainingdata.cpp @@ -681,6 +681,8 @@ int MainCmds::writetrainingdata(const vector& args) { const int numTotalThreads = numWorkerThreads * numSearchThreads; const int dataBoardLen = cfg.getInt("dataBoardLen",3,Board::MAX_LEN); + const int dataBoardLenX = cfg.contains(DATA_LEN_X_KEY) ? cfg.getInt(DATA_LEN_X_KEY,3,Board::MAX_LEN_X) : dataBoardLen; + const int dataBoardLenY = cfg.contains(DATA_LEN_Y_KEY) ? cfg.getInt(DATA_LEN_Y_KEY,3,Board::MAX_LEN_Y) : dataBoardLen; const int maxApproxRowsPerTrainFile = cfg.getInt("maxApproxRowsPerTrainFile",1,100000000); const std::vector> allowedBoardSizes = @@ -814,8 +816,8 @@ int MainCmds::writetrainingdata(const vector& args) { (maxApproxRowsPerTrainFile * 4/3 + Board::MAX_PLAY_SIZE * 2 + 100), numBinaryChannels, numGlobalChannels, - dataBoardLen, - dataBoardLen, + dataBoardLenX, + dataBoardLenY, hasMetadataInput ) ); @@ -868,12 +870,13 @@ int MainCmds::writetrainingdata(const vector& args) { return; } - if(xySize.x > dataBoardLen || xySize.y > dataBoardLen) { + if(xySize.x > dataBoardLenX || xySize.y > dataBoardLenY) { logger.write( - "SGF board size > dataBoardLen in " + fileName + ":" + "SGF board sizeX > dataBoardLenX or sizeY > dataBoardLenY in " + fileName + ":" + " " + Global::intToString(xySize.x) + " " + Global::intToString(xySize.y) - + " " + Global::intToString(dataBoardLen) + + " " + Global::intToString(dataBoardLenX) + + " " + Global::intToString(dataBoardLenY) ); reportSgfDone(false,"SGFGreaterThanDataBoardLen"); return; diff --git a/cpp/game/common.h b/cpp/game/common.h index 95ca124b4..4daed7a5c 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -4,6 +4,8 @@ #include "../core/global.h" const std::string DOTS_KEY = "dots"; +const std::string DATA_LEN_X_KEY = "dataBoardLenX"; +const std::string DATA_LEN_Y_KEY = "dataBoardLenY"; const std::string DOTS_CAPTURE_EMPTY_BASE_KEY = "dotsCaptureEmptyBase"; const std::string DOTS_CAPTURE_EMPTY_BASES_KEY = "dotsCaptureEmptyBases"; const std::string START_POS_KEY = "startPos"; From 320c719a0291372f920cf87c17dc9d5e0472e49d Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Tue, 30 Sep 2025 17:02:12 +0200 Subject: [PATCH 19/42] [Python] Initial support for `pos_len_x`, `pos_len_y` in advance to `pos_len` (not finished) --- python/forward_model.py | 6 ++- python/humanslnet_server.py | 2 +- python/katago/game/features.py | 29 +++++++------ python/katago/game/gamestate.py | 21 ++++----- .../katago/train/data_processing_pytorch.py | 29 +++++-------- python/katago/train/load_model.py | 4 +- python/katago/train/metrics_pytorch.py | 35 +++++++-------- python/katago/train/model_pytorch.py | 39 ++++++----------- python/play.py | 12 +++--- python/test.py | 22 ++++------ python/train.py | 43 ++++++++----------- 11 files changed, 108 insertions(+), 134 deletions(-) diff --git a/python/forward_model.py b/python/forward_model.py index f919b0a90..f3b6d8b57 100644 --- a/python/forward_model.py +++ b/python/forward_model.py @@ -33,6 +33,8 @@ parser.add_argument('-npz', help='NPZ file to evaluate', required=True) parser.add_argument('-checkpoint', help='Checkpoint to test', required=False) parser.add_argument('-pos-len', help='Spatial length of expected training data', type=int, required=True) + parser.add_argument('-pos-len-x', help='Spatial width of expected training data (-pos-len if undefined)', type=int, required=False) + parser.add_argument('-pos-len-y', help='Spatial height of expected training data (-pos-len if undefined)', type=int, required=False) parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False) parser.add_argument('-gpu-idx', help='GPU idx', type=int, required=False) @@ -42,6 +44,8 @@ def main(args): npz_file = args["npz"] checkpoint_file = args["checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len use_swa = args["use_swa"] gpu_idx = args["gpu_idx"] @@ -74,7 +78,7 @@ def main(args): # LOAD MODEL --------------------------------------------------------------------- - model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=False) + model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len_x, pos_len_y=pos_len_y, verbose=False) model_config = model.config batch = np.load(npz_file) diff --git a/python/humanslnet_server.py b/python/humanslnet_server.py index ca7b3e33f..93638b1c7 100644 --- a/python/humanslnet_server.py +++ b/python/humanslnet_server.py @@ -21,7 +21,7 @@ def main(): parser.add_argument('-device', help='Device to use, such as cpu or cuda:0', required=True) args = parser.parse_args() - model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device, pos_len=19, verbose=False) + model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device,verbose=False) if swa_model is not None: model = swa_model game_state = None diff --git a/python/katago/game/features.py b/python/katago/game/features.py index 246eefb30..d9451d4d5 100644 --- a/python/katago/game/features.py +++ b/python/katago/game/features.py @@ -6,29 +6,30 @@ from ..train import modelconfigs class Features: - def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): + def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int): self.config = config - self.pos_len = pos_len + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y self.version = modelconfigs.get_version(config) - self.pass_pos = self.pos_len * self.pos_len - self.bin_input_shape = [modelconfigs.get_num_bin_input_features(config), pos_len, pos_len] + self.pass_pos = self.pos_len_x * self.pos_len_y + self.bin_input_shape = [modelconfigs.get_num_bin_input_features(config), pos_len_x, pos_len_y] self.global_input_shape = [modelconfigs.get_num_global_input_features(config)] def xy_to_tensor_pos(self,x,y): - return y * self.pos_len + x + return y * self.pos_len_x + x def loc_to_tensor_pos(self,loc,board): if loc == Board.PASS_LOC: return self.pass_pos - return board.loc_y(loc) * self.pos_len + board.loc_x(loc) + return board.loc_y(loc) * self.pos_len_x + board.loc_x(loc) def tensor_pos_to_loc(self,pos,board): if pos == self.pass_pos: return None - pos_len = self.pos_len + pos_len = self.pos_len_x x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) x = pos % pos_len y = pos // pos_len if x < 0 or x >= x_size or y < 0 or y >= y_size: @@ -38,7 +39,7 @@ def tensor_pos_to_loc(self,pos,board): def sym_tensor_pos(self,pos,symmetry): if pos == self.pass_pos: return pos - pos_len = self.pos_len + pos_len = self.pos_len_x x = pos % pos_len y = pos // pos_len if symmetry >= 4: @@ -61,8 +62,8 @@ def iterLadders(self, board, f): x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) for y in range(y_size): for x in range(x_size): @@ -98,8 +99,8 @@ def fill_row_features(self, board, pla, opp, boards, moves, move_idx, rules, bin x_size = board.x_size y_size = board.y_size - assert(self.pos_len >= x_size) - assert(self.pos_len >= y_size) + assert(self.pos_len_x >= x_size) + assert(self.pos_len_y >= y_size) assert(len(boards) > 0) assert(board.zobrist == boards[move_idx].zobrist) diff --git a/python/katago/game/gamestate.py b/python/katago/game/gamestate.py index 896fbf079..b81300654 100644 --- a/python/katago/game/gamestate.py +++ b/python/katago/game/gamestate.py @@ -89,15 +89,16 @@ def redo(self): def get_input_features(self, features: Features): bin_input_data = np.zeros(shape=[1]+features.bin_input_shape, dtype=np.float32) global_input_data = np.zeros(shape=[1]+features.global_input_shape, dtype=np.float32) - pos_len = features.pos_len + pos_len_x = features.pos_len_x + pos_len_y = features.pos_len_y pla = self.board.pla opp = Board.get_opp(pla) move_idx = len(self.moves) # fill_row_features assumes N(HW)C order but we actually use NCHW order in the model, so work with it and revert bin_input_data = np.transpose(bin_input_data,axes=(0,2,3,1)) - bin_input_data = bin_input_data.reshape([1,pos_len*pos_len,-1]) + bin_input_data = bin_input_data.reshape([1,pos_len_x*pos_len_y,-1]) features.fill_row_features(self.board,pla,opp,self.boards,self.moves,move_idx,self.rules,bin_input_data,global_input_data,idx=0) - bin_input_data = bin_input_data.reshape([1,pos_len,pos_len,-1]) + bin_input_data = bin_input_data.reshape([1,pos_len_x,pos_len_y,-1]) bin_input_data = np.transpose(bin_input_data,axes=(0,3,1,2)) return bin_input_data, global_input_data @@ -106,7 +107,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non from ..train.model_pytorch import Model, ExtraOutputs with torch.no_grad(): model.eval() - features = Features(model.config, model.pos_len) + features = Features(model.config, model.pos_len_x, model.pos_len_y) bin_input_data, global_input_data = self.get_input_features(features) # Currently we don't actually do any symmetries @@ -195,7 +196,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non elif board.would_be_legal(board.pla,move): moves_and_probs1.append((move,policy1[i])) - ownership_flat = ownership.reshape([features.pos_len * features.pos_len]) + ownership_flat = ownership.reshape([features.pos_len_x * features.pos_len_y]) ownership_by_loc = [] board = self.board for y in range(board.y_size): @@ -207,7 +208,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: ownership_by_loc.append((loc,-ownership_flat[pos])) - scoring_flat = scoring.reshape([features.pos_len * features.pos_len]) + scoring_flat = scoring.reshape([features.pos_len_x * features.pos_len_y]) scoring_by_loc = [] board = self.board for y in range(board.y_size): @@ -219,7 +220,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: scoring_by_loc.append((loc,-scoring_flat[pos])) - futurepos0_flat = futurepos[0,:,:].reshape([features.pos_len * features.pos_len]) + futurepos0_flat = futurepos[0,:,:].reshape([features.pos_len_x * features.pos_len_y]) futurepos0_by_loc = [] board = self.board for y in range(board.y_size): @@ -231,7 +232,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: futurepos0_by_loc.append((loc,-futurepos0_flat[pos])) - futurepos1_flat = futurepos[1,:,:].reshape([features.pos_len * features.pos_len]) + futurepos1_flat = futurepos[1,:,:].reshape([features.pos_len_x * features.pos_len_y]) futurepos1_by_loc = [] board = self.board for y in range(board.y_size): @@ -243,7 +244,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: futurepos1_by_loc.append((loc,-futurepos1_flat[pos])) - seki_flat = seki.reshape([features.pos_len * features.pos_len]) + seki_flat = seki.reshape([features.pos_len_x * features.pos_len_y]) seki_by_loc = [] board = self.board for y in range(board.y_size): @@ -255,7 +256,7 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non else: seki_by_loc.append((loc,-seki_flat[pos])) - seki_flat2 = seki2.reshape([features.pos_len * features.pos_len]) + seki_flat2 = seki2.reshape([features.pos_len_x * features.pos_len_y]) seki_by_loc2 = [] board = self.board for y in range(board.y_size): diff --git a/python/katago/train/data_processing_pytorch.py b/python/katago/train/data_processing_pytorch.py index 1ad18c9fd..78a50daeb 100644 --- a/python/katago/train/data_processing_pytorch.py +++ b/python/katago/train/data_processing_pytorch.py @@ -9,17 +9,8 @@ from ..train import modelconfigs -def read_npz_training_data( - npz_files, - batch_size: int, - world_size: int, - rank: int, - pos_len: int, - device, - randomize_symmetries: bool, - include_meta: bool, - model_config: modelconfigs.ModelConfig, -): +def read_npz_training_data(npz_files, batch_size: int, world_size: int, rank: int, pos_len_x: int, pos_len_y: int, device, + randomize_symmetries: bool, include_meta: bool, model_config: modelconfigs.ModelConfig): rand = np.random.default_rng(seed=list(os.urandom(12))) num_bin_features = modelconfigs.get_num_bin_input_features(model_config) num_global_features = modelconfigs.get_num_global_input_features(model_config) @@ -47,10 +38,10 @@ def load_npz_file(npz_file): binaryInputNCHW = np.unpackbits(binaryInputNCHWPacked,axis=2) assert len(binaryInputNCHW.shape) == 3 - assert binaryInputNCHW.shape[2] == ((pos_len * pos_len + 7) // 8) * 8 - binaryInputNCHW = binaryInputNCHW[:,:,:pos_len*pos_len] + assert binaryInputNCHW.shape[2] == ((pos_len_x * pos_len_y + 7) // 8) * 8 + binaryInputNCHW = binaryInputNCHW[:,:, :pos_len_x * pos_len_y] binaryInputNCHW = np.reshape(binaryInputNCHW, ( - binaryInputNCHW.shape[0], binaryInputNCHW.shape[1], pos_len, pos_len + binaryInputNCHW.shape[0], binaryInputNCHW.shape[1], pos_len_x, pos_len_y )).astype(np.float32) assert binaryInputNCHW.shape[1] == num_bin_features @@ -98,10 +89,10 @@ def load_npz_file(npz_file): if randomize_symmetries: symm = int(rand.integers(0,8)) batch_binaryInputNCHW = apply_symmetry(batch_binaryInputNCHW, symm) - batch_policyTargetsNCMove = apply_symmetry_policy(batch_policyTargetsNCMove, symm, pos_len) + batch_policyTargetsNCMove = apply_symmetry_policy(batch_policyTargetsNCMove, symm, pos_len_x, pos_len_y) batch_valueTargetsNCHW = apply_symmetry(batch_valueTargetsNCHW, symm) if include_qvalues: - batch_qValueTargetsNCMove = apply_symmetry_policy(batch_qValueTargetsNCMove, symm, pos_len) + batch_qValueTargetsNCMove = apply_symmetry_policy(batch_qValueTargetsNCMove, symm, pos_len_x, pos_len_y) batch_binaryInputNCHW = batch_binaryInputNCHW.contiguous() batch_policyTargetsNCMove = batch_policyTargetsNCMove.contiguous() @@ -125,14 +116,14 @@ def load_npz_file(npz_file): yield batch -def apply_symmetry_policy(tensor, symm, pos_len): +def apply_symmetry_policy(tensor, symm, pos_len_x, pos_len_y): """Same as apply_symmetry but also handles the pass index""" batch_size = tensor.shape[0] channels = tensor.shape[1] - tensor_without_pass = tensor[:,:,:-1].view((batch_size, channels, pos_len, pos_len)) + tensor_without_pass = tensor[:,:,:-1].view((batch_size, channels, pos_len_x, pos_len_y)) tensor_transformed = apply_symmetry(tensor_without_pass, symm) return torch.cat(( - tensor_transformed.reshape(batch_size, channels, pos_len*pos_len), + tensor_transformed.reshape(batch_size, channels, pos_len_x * pos_len_y), tensor[:,:,-1:] ), dim=2) diff --git a/python/katago/train/load_model.py b/python/katago/train/load_model.py index 6642c758f..cbf582c9a 100644 --- a/python/katago/train/load_model.py +++ b/python/katago/train/load_model.py @@ -39,7 +39,7 @@ def load_swa_model_state_dict(state_dict): return swa_model_state_dict -def load_model(checkpoint_file, use_swa, device, pos_len=19, verbose=False): +def load_model(checkpoint_file, use_swa, device, pos_len_x=19, pos_len_y=19, verbose=False): from ..train.model_pytorch import Model from torch.optim.swa_utils import AveragedModel @@ -54,7 +54,7 @@ def load_model(checkpoint_file, use_swa, device, pos_len=19, verbose=False): model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config,pos_len) + model = Model(model_config, pos_len_x, pos_len_y) model.initialize() # Strip off any "module." from when the model was saved with DDP or other things diff --git a/python/katago/train/metrics_pytorch.py b/python/katago/train/metrics_pytorch.py index eb938c554..9b52a4567 100644 --- a/python/katago/train/metrics_pytorch.py +++ b/python/katago/train/metrics_pytorch.py @@ -25,14 +25,15 @@ class Metrics: def __init__(self, batch_size: int, world_size: int, raw_model: Model): self.n = batch_size self.world_size = world_size - self.pos_len = raw_model.pos_len - self.pos_area = raw_model.pos_len * raw_model.pos_len - self.policy_len = raw_model.pos_len * raw_model.pos_len + 1 + self.pos_len_x = raw_model.pos_len_x + self.pos_len_y = raw_model.pos_len_y + self.pos_area = raw_model.pos_len_x * raw_model.pos_len_y + self.policy_len = raw_model.pos_len_x * raw_model.pos_len_y + 1 self.value_len = 3 self.num_td_values = 3 self.num_futurepos_values = 2 self.num_seki_logits = 4 - self.scorebelief_len = 2 * (self.pos_len*self.pos_len + EXTRA_SCORE_DISTR_RADIUS) + self.scorebelief_len = 2 * (self.pos_len_x * self.pos_len_y + EXTRA_SCORE_DISTR_RADIUS) self.scoremean_multiplier = raw_model.scoremean_multiplier @@ -120,9 +121,9 @@ def loss_ownership_samplewise(self, pred_pretanh, target, weight, mask, mask_sum # This uses a formulation where each batch element cares about its average loss. # In particular this means that ownership loss predictions on small boards "count more" per spot. # Not unlike the way that policy and value loss are also equal-weighted by batch element. - assert pred_pretanh.shape == (self.n, 1, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_pretanh.shape == (self.n, 1, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) pred_logits = pred_pretanh.view(self.n,self.pos_area) * 2.0 target_probs = (1.0 + target.view(self.n,self.pos_area)) / 2.0 @@ -134,9 +135,9 @@ def loss_ownership_samplewise(self, pred_pretanh, target, weight, mask, mask_sum def loss_scoring_samplewise(self, pred_scoring, target, weight, mask, mask_sum_hw, global_weight): - assert pred_scoring.shape == (self.n, 1, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_scoring.shape == (self.n, 1, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) loss = torch.sum(torch.square(pred_scoring.squeeze(1) - target) * mask, dim=(1,2)) / mask_sum_hw @@ -154,9 +155,9 @@ def loss_futurepos_samplewise(self, pred_pretanh, target, weight, mask, mask_sum # causing some scaling with board size. So, I dunno, let's compromise and scale by sqrt(boardarea). # Also, the further out targets should be weighted a little less due to them being higher entropy # due to simply being farther in the future, so multiply by [1,0.25]. - assert pred_pretanh.shape == (self.n, self.num_futurepos_values, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.num_futurepos_values, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_pretanh.shape == (self.n, self.num_futurepos_values, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.num_futurepos_values, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) loss = torch.square(torch.tanh(pred_pretanh) - target) * mask.unsqueeze(1) loss = loss * constant_like([1.0,0.25], loss).view(1,2,1,1) @@ -166,10 +167,10 @@ def loss_futurepos_samplewise(self, pred_pretanh, target, weight, mask, mask_sum def loss_seki_samplewise(self, pred_logits, target, target_ownership, weight, mask, mask_sum_hw, global_weight, is_training, skip_moving_update): assert self.num_seki_logits == 4 - assert pred_logits.shape == (self.n, self.num_seki_logits, self.pos_len, self.pos_len) - assert target.shape == (self.n, self.pos_len, self.pos_len) - assert target_ownership.shape == (self.n, self.pos_len, self.pos_len) - assert mask.shape == (self.n, self.pos_len, self.pos_len) + assert pred_logits.shape == (self.n, self.num_seki_logits, self.pos_len_x, self.pos_len_y) + assert target.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert target_ownership.shape == (self.n, self.pos_len_x, self.pos_len_y) + assert mask.shape == (self.n, self.pos_len_x, self.pos_len_y) assert mask_sum_hw.shape == (self.n,) owned_target = torch.square(target_ownership) diff --git a/python/katago/train/model_pytorch.py b/python/katago/train/model_pytorch.py index f4e0a93d1..4654ac934 100644 --- a/python/katago/train/model_pytorch.py +++ b/python/katago/train/model_pytorch.py @@ -1384,7 +1384,7 @@ def forward(self, x, mask, mask_sum_hw, mask_sum:float, extra_outputs: Optional[ class ValueHead(torch.nn.Module): - def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation, pos_len): + def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation, pos_len_x, pos_len_y): super(ValueHead, self).__init__() self.activation = activation self.conv1 = torch.nn.Conv2d(c_in, c_v1, kernel_size=1, padding="same", bias=False) @@ -1407,8 +1407,9 @@ def __init__(self, c_in, c_v1, c_v2, c_sv2, num_scorebeliefs, config, activation self.conv_futurepos = torch.nn.Conv2d(c_in, 2, kernel_size=1, padding="same", bias=False) self.conv_seki = torch.nn.Conv2d(c_in, 4, kernel_size=1, padding="same", bias=False) - self.pos_len = pos_len - self.scorebelief_mid = self.pos_len*self.pos_len + EXTRA_SCORE_DISTR_RADIUS + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y + self.scorebelief_mid = self.pos_len_x * self.pos_len_y + EXTRA_SCORE_DISTR_RADIUS self.scorebelief_len = self.scorebelief_mid * 2 self.num_scorebeliefs = num_scorebeliefs self.c_sv2 = c_sv2 @@ -1603,7 +1604,7 @@ def forward(self, input_meta, extra_outputs: Optional[ExtraOutputs]): class Model(torch.nn.Module): - def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): + def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int): super(Model, self).__init__() self.config = config @@ -1620,7 +1621,8 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.c_sv2 = config["sbv2_num_channels"] self.num_scorebeliefs = config["num_scorebeliefs"] self.num_total_blocks = len(self.block_kind) - self.pos_len = pos_len + self.pos_len_x = pos_len_x + self.pos_len_y = pos_len_y if config["version"] <= 12: self.td_score_multiplier = 20.0 @@ -1661,7 +1663,7 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): else: self.metadata_encoder = None - self.bin_input_shape = [22, pos_len, pos_len] + self.bin_input_shape = [22, pos_len_x, pos_len_y] self.global_input_shape = [19] self.blocks = torch.nn.ModuleList() @@ -1770,16 +1772,8 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.config, self.activation, ) - self.value_head = ValueHead( - self.c_trunk, - self.c_v1, - self.c_v2, - self.c_sv2, - self.num_scorebeliefs, - self.config, - self.activation, - self.pos_len, - ) + self.value_head = ValueHead(self.c_trunk, self.c_v1, self.c_v2, self.c_sv2, self.num_scorebeliefs, self.config, + self.activation, self.pos_len_x, self.pos_len_y) if self.has_intermediate_head: self.norm_intermediate_trunkfinal = NormMask(self.c_trunk, self.config, fixup_use_gamma=False, is_last_batchnorm=True) self.act_intermediate_trunkfinal = act(self.activation) @@ -1790,16 +1784,9 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len: int): self.config, self.activation, ) - self.intermediate_value_head = ValueHead( - self.c_trunk, - self.c_v1, - self.c_v2, - self.c_sv2, - self.num_scorebeliefs, - self.config, - self.activation, - self.pos_len, - ) + self.intermediate_value_head = ValueHead(self.c_trunk, self.c_v1, self.c_v2, self.c_sv2, + self.num_scorebeliefs, self.config, self.activation, + self.pos_len_x, self.pos_len_y) @property def device(self): diff --git a/python/play.py b/python/play.py index 543542482..10e69ab56 100644 --- a/python/play.py +++ b/python/play.py @@ -58,13 +58,13 @@ np.set_printoptions(linewidth=150) torch.set_printoptions(precision=7,sci_mode=False,linewidth=100000,edgeitems=1000,threshold=1000000) -model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=True) +model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len, pos_len_y=pos_len, verbose=True) if swa_model is not None: model = swa_model model_config = model.config model.eval() -features = Features(model_config, pos_len) +features = Features(model_config, pos_len_x, pos_len_y) # Moves ---------------------------------------------------------------- @@ -413,9 +413,9 @@ def add_attention_visualizations(extra_output_name, extra_output): def get_board_matrix_str(matrix, scale, formatstr): ret = "" - matrix = matrix.reshape([features.pos_len,features.pos_len]) - for y in range(features.pos_len): - for x in range(features.pos_len): + matrix = matrix.reshape([features.pos_len_x, features.pos_len_y]) + for y in range(features.pos_len_y): + for x in range(features.pos_len_x): ret += formatstr % (scale * matrix[y,x]) ret += " " ret += "\n" @@ -456,7 +456,7 @@ def get_policy_matrix_str(matrix, gs, scale, formatstr): ret = '' if command[0] == "boardsize": - if int(command[1]) > features.pos_len: + if int(command[1]) > features.pos_len_x: print("Warning: Trying to set incompatible boardsize %s (!= %d)" % (command[1], N), file=sys.stderr) ret = None board_size = int(command[1]) diff --git a/python/test.py b/python/test.py index b7e54837a..929180529 100644 --- a/python/test.py +++ b/python/test.py @@ -51,6 +51,8 @@ def main(args): config_file = args["config"] checkpoint_file = args["checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len batch_size = args["batch_size"] use_swa = args["use_swa"] max_batches = args["max_batches"] @@ -109,11 +111,11 @@ def main(args): model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config,pos_len) + model = Model(model_config,pos_len_x,pos_len_y) model.initialize() model.to(device) else: - model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len=pos_len, verbose=True) + model, swa_model, _ = load_model(checkpoint_file, use_swa, device=device, pos_len_x=pos_len_x, pos_len_y=pos_len_y, verbose=True) model_config = model.config metrics_obj = Metrics(batch_size,world_size,model) @@ -188,17 +190,11 @@ def log_metrics(prefix, metric_sums, metric_weights, metrics, sum_norms, num_sam num_samples_tested = 0 total_inference_time = 0.0 is_first_batch = True - for batch in data_processing_pytorch.read_npz_training_data( - val_files, - batch_size, - world_size, - rank, - pos_len, - device, - randomize_symmetries=True, - include_meta=model.get_has_metadata_encoder(), - model_config=model_config, - ): + for batch in data_processing_pytorch.read_npz_training_data(val_files, batch_size, world_size, rank, + pos_len_x, pos_len_y, + device, randomize_symmetries=True, + include_meta=model.get_has_metadata_encoder(), + model_config=model_config): if max_batches is not None and num_batches_tested >= max_batches: break diff --git a/python/train.py b/python/train.py index 3ee3f3c29..7a6850399 100755 --- a/python/train.py +++ b/python/train.py @@ -64,8 +64,10 @@ optional_args.add_argument('-exportprefix', help='Prefix to append to names of models', required=False) optional_args.add_argument('-initial-checkpoint', help='If no training checkpoint exists, initialize from this checkpoint', required=False) - required_args.add_argument('-pos-len', help='Spatial edge length of expected training data, e.g. 19 for 19x19 Go', type=int, required=True) required_args.add_argument('-batch-size', help='Per-GPU batch size to use for training', type=int, required=True) + required_args.add_argument('-pos-len', help='Spatial edge length of expected training data, e.g. 19 for 19x19 Go', type=int, required=True) + optional_args.add_argument('-pos-len-x', help='Spatial width of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + optional_args.add_argument('-pos-len-y', help='Spatial height of expected training data. If undefined, `-pos-len` is used', type=int, required=False) optional_args.add_argument('-samples-per-epoch', help='Number of data samples to consider as one epoch', type=int, required=False) optional_args.add_argument('-model-kind', help='String name for what model config to use', required=False) optional_args.add_argument('-lr-scale', help='LR multiplier on the hardcoded schedule', type=float, required=False) @@ -153,6 +155,8 @@ def main(rank: int, world_size: int, args, multi_gpu_device_ids, readpipes, writ initial_checkpoint = args["initial_checkpoint"] pos_len = args["pos_len"] + pos_len_x = args["pos_len_x"] or pos_len + pos_len_y = args["pos_len_y"] or pos_len batch_size = args["batch_size"] samples_per_epoch = args["samples_per_epoch"] model_kind = args["model_kind"] @@ -443,7 +447,7 @@ def load(): assert model_kind is not None, "Model kind is none or unspecified but the model is being created fresh" model_config = modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len) + raw_model = Model(model_config,pos_len_x,pos_len_y) raw_model.initialize() raw_model.to(device) @@ -478,7 +482,7 @@ def load(): state_dict = torch.load(path_to_load_from, map_location=device) model_config = state_dict["config"] if "config" in state_dict else modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len) + raw_model = Model(model_config,pos_len_x,pos_len_y) raw_model.initialize() train_state = {} @@ -1045,17 +1049,11 @@ def detensorify_metrics(metrics): logging.info("This subepoch, using files: " + str(train_files_to_use)) logging.info("Currently up to data row " + str(train_state["total_num_data_rows"])) lookahead_counter = 0 - for batch in data_processing_pytorch.read_npz_training_data( - train_files_to_use, - batch_size, - world_size, - rank, - pos_len=pos_len, - device=device, - randomize_symmetries=True, - include_meta=raw_model.get_has_metadata_encoder(), - model_config=model_config - ): + for batch in data_processing_pytorch.read_npz_training_data(train_files_to_use, batch_size, world_size, + rank, pos_len_x, pos_len_y, + device=device, randomize_symmetries=True, + include_meta=raw_model.get_has_metadata_encoder(), + model_config=model_config): optimizer.zero_grad(set_to_none=True) extra_outputs = None # if raw_model.get_has_metadata_encoder(): @@ -1253,17 +1251,12 @@ def detensorify_metrics(metrics): val_metric_weights = defaultdict(float) val_samples = 0 t0 = time.perf_counter() - for batch in data_processing_pytorch.read_npz_training_data( - val_files, - batch_size, - world_size=1, # Only the main process validates - rank=0, # Only the main process validates - pos_len=pos_len, - device=device, - randomize_symmetries=True, - include_meta=raw_model.get_has_metadata_encoder(), - model_config=model_config - ): + for batch in data_processing_pytorch.read_npz_training_data(val_files, batch_size, world_size=1, rank=0, + pos_len_x=pos_len_x, pos_len_y=pos_len_y, + device=device, + randomize_symmetries=True, + include_meta=raw_model.get_has_metadata_encoder(), + model_config=model_config): model_outputs = ddp_model( batch["binaryInputNCHW"], batch["globalInputNC"], From 81c3f62a012317d9cbc19d5074da6397719c0370 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Tue, 30 Sep 2025 17:29:09 +0200 Subject: [PATCH 20/42] [Python] Fix `read_array_header` for different versions (1, 2) --- python/shuffle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/shuffle.py b/python/shuffle.py index 22f93a4d0..eda21e6c3 100755 --- a/python/shuffle.py +++ b/python/shuffle.py @@ -362,7 +362,11 @@ def get_numpy_npz_headers(filename): wasbad = True print("WARNING: bad file, skipping it: %s (bad array %s)" % (filename,subfilename)) else: - (shape, is_fortran, dtype) = np.lib.format._read_array_header(npyfile,version) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(npyfile) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(npyfile) + (shape, is_fortran, dtype) = header npzheaders[subfilename] = (shape, is_fortran, dtype) if wasbad: return None From afc7bbd5e5ade96af6b3dc74e1983ae772f40381 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Thu, 2 Oct 2025 20:59:53 +0200 Subject: [PATCH 21/42] Fix calculating of empty items that can be captured in `makeMoveAndCalculateCapturesAndBases` Refine `Rules` constructor for Go and Dots --- cpp/game/board.h | 9 ++-- cpp/game/dotsfield.cpp | 35 ++++++-------- cpp/game/rules.cpp | 30 ++++++------ cpp/game/rules.h | 7 ++- cpp/tests/testdotsbasic.cpp | 80 ++++++++++++++++---------------- cpp/tests/testdotsextra.cpp | 73 ++++++++++++++++++++--------- cpp/tests/testdotsstartposes.cpp | 4 +- cpp/tests/testdotsstress.cpp | 6 ++- cpp/tests/testdotsutils.cpp | 6 +-- cpp/tests/testdotsutils.h | 2 +- 10 files changed, 142 insertions(+), 110 deletions(-) diff --git a/cpp/game/board.h b/cpp/game/board.h index 67ce4c640..9a9cc392b 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -383,7 +383,7 @@ struct Board bool isMultiStoneSuicideLegal ) const; - void calculateOneMoveCaptureAndBasePositionsForDots(bool isSuicideLegal, std::vector& captures, std::vector& bases) const; + void calculateOneMoveCaptureAndBasePositionsForDots(std::vector& captures, std::vector& bases) const; //Run some basic sanity checks on the board state, throws an exception if not consistent, for testing/debugging void checkConsistency() const; @@ -468,8 +468,11 @@ struct Board Base createBaseAndUpdateStates(Player basePla, bool isReal); void updateScoreAndHashForTerritory(Loc loc, State state, Player basePla, bool rollback); void invalidateAdjacentEmptyTerritoryIfNeeded(Loc loc); - void makeMoveAndCalculateCapturesAndBases(Player pla, Loc loc, bool isSuicideLegal, - std::vector& captures, std::vector& bases) const; + void makeMoveAndCalculateCapturesAndBases( + Player pla, + Loc loc, + std::vector& captures, + std::vector& bases) const; void setGrounded(Loc loc); void clearGrounded(Loc loc); diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 27b4d6cfe..68ac22a9c 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -886,24 +886,22 @@ void Board::invalidateAdjacentEmptyTerritoryIfNeeded(const Loc loc) { void Board::makeMoveAndCalculateCapturesAndBases( const Player pla, const Loc loc, - const bool isSuicideLegal, vector& captures, - vector& bases - ) const { - if(isLegal(loc, pla, isSuicideLegal, false)) { + vector& bases) const { + if(isLegal(loc, pla, rules.multiStoneSuicideLegal, false)) { MoveRecord moveRecord = const_cast(this)->playMoveRecordedDots(loc, pla); - if(!moveRecord.bases.empty()) { - if(moveRecord.bases[0].pla == pla) { - // Better handling of empty bases? - captures[loc] = captures[loc] | moveRecord.bases[0].pla; - } - } - for(Base& base: moveRecord.bases) { - for(const Loc& rollbackLoc: base.rollback_locations) { - // Consider empty bases as one move bases as well - bases[rollbackLoc] = bases[rollbackLoc] | base.pla; + if (base.is_real) { + const bool suicide = base.pla != pla; + if (!suicide) { + captures[loc] = static_cast(captures[loc] | base.pla); + } + + for(const Loc& rollbackLoc: base.rollback_locations) { + // Consider empty bases as well + bases[rollbackLoc] = static_cast(bases[rollbackLoc] | base.pla); + } } } @@ -911,7 +909,7 @@ void Board::makeMoveAndCalculateCapturesAndBases( } } -void Board::calculateOneMoveCaptureAndBasePositionsForDots(const bool isSuicideLegal, vector& captures, vector& bases) const { +void Board::calculateOneMoveCaptureAndBasePositionsForDots(vector& captures, vector& bases) const { const int fieldSize = (x_size + 1) * (y_size + 1); captures.resize(fieldSize); bases.resize(fieldSize); @@ -922,17 +920,14 @@ void Board::calculateOneMoveCaptureAndBasePositionsForDots(const bool isSuicideL const State state = getState(loc); const Color emptyTerritoryColor = getEmptyTerritoryColor(state); - if (emptyTerritoryColor != C_EMPTY) { - bases[loc] = bases[loc] | emptyTerritoryColor; - } // It doesn't make sense to calculate capturing when dot placed into own empty territory if (emptyTerritoryColor != P_BLACK) { - makeMoveAndCalculateCapturesAndBases(P_BLACK, loc, isSuicideLegal, captures, bases); + makeMoveAndCalculateCapturesAndBases(P_BLACK, loc, captures, bases); } if (emptyTerritoryColor != P_WHITE) { - makeMoveAndCalculateCapturesAndBases(P_WHITE, loc, isSuicideLegal, captures, bases); + makeMoveAndCalculateCapturesAndBases(P_WHITE, loc, captures, bases); } } } diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index af5a35b8f..6e46cc164 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -15,8 +15,20 @@ const Rules Rules::DEFAULT_GO = Rules(false); Rules::Rules() : Rules(false) {} -Rules::Rules(const bool initIsDots, const int startPos, const bool startPosIsRandom, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : - Rules(initIsDots, startPos, startPosIsRandom, 0, 0, 0, true, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} +Rules::Rules(const int startPos, const bool startPosIsRandom, const bool suicide, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : + Rules(true, startPos, startPosIsRandom, 0, 0, 0, suicide, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} + +Rules::Rules( + int kRule, + int sRule, + int tRule, + bool suic, + bool button, + int whbRule, + bool pOk, + float km +) : Rules(false, 0, false, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { +} Rules::Rules(const bool initIsDots) : Rules( initIsDots, @@ -25,7 +37,7 @@ Rules::Rules(const bool initIsDots) : Rules( initIsDots ? 0 : KO_POSITIONAL, initIsDots ? 0 : SCORING_AREA, initIsDots ? 0 : TAX_NONE, - true, + initIsDots, false, initIsDots ? 0 : WHB_ZERO, false, @@ -35,18 +47,6 @@ Rules::Rules(const bool initIsDots) : Rules( ) { } -Rules::Rules( - int kRule, - int sRule, - int tRule, - bool suic, - bool button, - int whbRule, - bool pOk, - float km -) : Rules(false, 0, false, kRule, sRule, tRule, suic, button, whbRule, pOk, km, false, false) { -} - Rules::Rules( bool isDots, int startPosRule, diff --git a/cpp/game/rules.h b/cpp/game/rules.h index d2f6910c2..862604468 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -59,8 +59,9 @@ struct Rules { bool friendlyPassOk; Rules(); - Rules(bool initIsDots, int startPos, bool startPosIsRandom, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); - explicit Rules(bool initIsDots); + // Constructor for Dots + Rules(int startPos, bool startPosIsRandom, bool suicide, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); + // Constructor for Go Rules( int koRule, int scoringRule, @@ -71,6 +72,7 @@ struct Rules { bool friendlyPassOk, float komi ); + explicit Rules(bool initIsDots); ~Rules(); bool operator==(const Rules& other) const; @@ -147,6 +149,7 @@ struct Rules { static const Hash128 ZOBRIST_DOTS_CAPTURE_EMPTY_BASES_HASH; private: + // General constructor Rules( bool isDots, int startPosRule, diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index 839ef5f82..6ad631ef5 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -7,12 +7,16 @@ using namespace std; using namespace TestCommon; -void checkDotsField(const string& description, const string& input, bool captureEmptyBases, bool freeCapturedDots, const std::function& check) { +void checkDotsField(const string& description, const string& input, + const std::function& check, + const bool suicide = Rules::DEFAULT_DOTS.multiStoneSuicideLegal, + const bool captureEmptyBases = Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, + const bool freeCapturedDots = Rules::DEFAULT_DOTS.dotsFreeCapturedDots) { cout << " " << description << endl; auto moveRecords = vector(); - Board initialBoard = parseDotsField(input, false, captureEmptyBases, freeCapturedDots, {}); + Board initialBoard = parseDotsField(input, false, suicide, captureEmptyBases, freeCapturedDots, {}); Board board = Board(initialBoard); @@ -26,14 +30,10 @@ void checkDotsField(const string& description, const string& input, bool capture testAssert(initialBoard.isEqualForTesting(board)); } -void checkDotsFieldDefault(const string& description, const string& input, const std::function& check) { - checkDotsField(description, input, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, check); -} - void Tests::runDotsFieldTests() { cout << "Running dots basic tests: " << endl; - checkDotsFieldDefault("Simple capturing", + checkDotsField("Simple capturing", R"( .x. xox @@ -43,7 +43,7 @@ xox testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Capturing with empty loc inside", + checkDotsField("Capturing with empty loc inside", R"( .oo. ox.. @@ -58,7 +58,7 @@ ox.. testAssert(!boardWithMoveRecords.isLegal(2, 1, P_WHITE)); }); - checkDotsFieldDefault("Triple capture", + checkDotsField("Triple capture", R"( .x.x. xo.ox @@ -69,7 +69,7 @@ xo.ox testAssert(3 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Base inside base inside base", + checkDotsField("Base inside base inside base", R"( .xxxxxxx. x..ooo..x @@ -116,7 +116,7 @@ testAssert(21 == boardWithMoveRecords.board.numWhiteCaptures); // Don't count al testAssert(6 == boardWithMoveRecords.board.numBlackCaptures); // Don't free the captured dot });*/ - checkDotsFieldDefault("Empty bases and suicide", + checkDotsField("Empty bases and suicide", R"( .x..o. x.xo.o @@ -144,7 +144,7 @@ x.xo.o .x..o. x.xo.o ...... -)", true, true, [](const BoardWithMoveRecords& boardWithMoveRecords) { +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { boardWithMoveRecords.playMove(1, 2, P_BLACK); boardWithMoveRecords.playMove(4, 2, P_WHITE); @@ -156,9 +156,9 @@ x.xo.o testAssert(0 == boardWithMoveRecords.board.numWhiteCaptures); testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); -}); +}, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, true, Rules::DEFAULT_DOTS.dotsFreeCapturedDots); - checkDotsFieldDefault("Capture wins suicide", + checkDotsField("Capture wins suicide", R"( .xo. xo.o @@ -169,7 +169,7 @@ xo.o testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Single dot doesn't break searching inside empty base", + checkDotsField("Single dot doesn't break searching inside empty base", R"( .oooo. o....o @@ -181,7 +181,7 @@ o....o testAssert(1 == boardWithMoveRecords.board.numBlackCaptures); }); - checkDotsFieldDefault("Ignored already surrounded territory", + checkDotsField("Ignored already surrounded territory", R"( ..xxx... .x...x.. @@ -198,7 +198,7 @@ x..x..x. testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Invalidation of empty base locations", + checkDotsField("Invalidation of empty base locations", R"( .oox. o..ox @@ -209,7 +209,7 @@ o..ox testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Invalidation of empty base locations ignoring borders", + checkDotsField("Invalidation of empty base locations ignoring borders", R"( ..xxx.... .x...x... @@ -229,7 +229,7 @@ x..x..xo. testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Dangling dots removing", + checkDotsField("Dangling dots removing", R"( .xx.xx. x..xo.x @@ -245,7 +245,7 @@ x..x..x testAssert(!boardWithMoveRecords.isLegal(3, 2, P_WHITE)); }); - checkDotsFieldDefault("Recalculate square during dangling dots removing", + checkDotsField("Recalculate square during dangling dots removing", R"( .ooo.. o...o. @@ -262,7 +262,7 @@ o...o. testAssert(2 == boardWithMoveRecords.board.numBlackCaptures); }); - checkDotsFieldDefault("Base sorting by size", + checkDotsField("Base sorting by size", R"( ..xxx.. .x...x. @@ -278,7 +278,7 @@ x.....x testAssert(2 == boardWithMoveRecords.board.numWhiteCaptures); }); - checkDotsFieldDefault("Number of legal moves", + checkDotsField("Number of legal moves", R"( .... .... @@ -287,7 +287,7 @@ x.....x testAssert(12 == boardWithMoveRecords.board.numLegalMoves); }); - checkDotsFieldDefault("Game over because of absence of legal moves", + checkDotsField("Game over because of absence of legal moves", R"( xxxx xo.x @@ -302,7 +302,7 @@ xx.x void Tests::runDotsGroundingTests() { cout << "Running dots grounding tests:" << endl; - checkDotsFieldDefault("Grounding propagation", + checkDotsField("Grounding propagation", R"( .x.. o.o. @@ -345,7 +345,7 @@ o.o. } ); - checkDotsFieldDefault("Grounding propagation with empty base", + checkDotsField("Grounding propagation with empty base", R"( ..x.. .x.x. @@ -373,7 +373,7 @@ o.o. testAssert(isGrounded(boardWithMoveRecords.getState(2, 3))); }); - checkDotsFieldDefault("Grounding score with grounded base", + checkDotsField("Grounding score with grounded base", R"( .x. xox @@ -386,7 +386,7 @@ xox } ); - checkDotsFieldDefault("Grounding score with ungrounded base", + checkDotsField("Grounding score with ungrounded base", R"( ..... ..o.. @@ -401,7 +401,7 @@ R"( } ); - checkDotsFieldDefault("Grounding score with grounded and ungrounded bases", + checkDotsField("Grounding score with grounded and ungrounded bases", R"( .x..... xox.o.. @@ -417,7 +417,7 @@ xox.o.. } ); - checkDotsFieldDefault("Grounding draw with ungrounded bases", + checkDotsField("Grounding draw with ungrounded bases", R"( ......... ..x...o.. @@ -436,7 +436,7 @@ R"( ); - checkDotsFieldDefault("Grounding of real and empty adjacent bases", + checkDotsField("Grounding of real and empty adjacent bases", R"( ..x.. ..x.. @@ -468,7 +468,7 @@ R"( } ); - checkDotsFieldDefault("Grounding of real base when it touches grounded", + checkDotsField("Grounding of real base when it touches grounded", R"( ..x.. ..x.. @@ -493,7 +493,7 @@ R"( } ); - checkDotsFieldDefault("Base inside base inside base and grounding score", + checkDotsField("Base inside base inside base and grounding score", R"( ....... ..ooo.. @@ -522,7 +522,7 @@ R"( testAssert(4 == boardWithMoveRecords.board.whiteScoreIfBlackGrounds); }); - checkDotsFieldDefault("Ground empty territory in case of dangling dots removing", + checkDotsField("Ground empty territory in case of dangling dots removing", R"( ......... ..xxx.... @@ -542,7 +542,7 @@ R"( //testAssert(isGrounded(boardWithMoveRecords.getState(4, 4))); }); - checkDotsFieldDefault("Simple", + checkDotsField("Simple", R"( ..... .xxo. @@ -568,7 +568,7 @@ R"( } ); - checkDotsFieldDefault("Draw", + checkDotsField("Draw", R"( .x... .xxo. @@ -586,7 +586,7 @@ R"( } ); - checkDotsFieldDefault("Bases", + checkDotsField("Bases", R"( ......... ..xx...x. @@ -605,7 +605,7 @@ R"( } ); - checkDotsFieldDefault("Multiple groups", + checkDotsField("Multiple groups", R"( ...... xxo..o @@ -628,7 +628,7 @@ x...oo } ); - checkDotsFieldDefault("Invalidate empty territory", + checkDotsField("Invalidate empty territory", R"( ...... ..oo.. @@ -657,13 +657,13 @@ R"( } ); - checkDotsFieldDefault("Don't invalidate empty territory for strong connection", + checkDotsField("Don't invalidate empty territory for strong connection", R"( .x. x.x .x. )", [](const BoardWithMoveRecords& boardWithMoveRecords) { - Board board = boardWithMoveRecords.board; + const Board board = boardWithMoveRecords.board; boardWithMoveRecords.playGroundingMove(P_BLACK); testAssert(0 == boardWithMoveRecords.board.numBlackCaptures); diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp index c68b803b4..2b97d8f8b 100644 --- a/cpp/tests/testdotsextra.cpp +++ b/cpp/tests/testdotsextra.cpp @@ -106,13 +106,16 @@ xo.. SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); cout << "Check dots symmetry with start pos" << endl; - auto board = Board(5, 4, Rules(true, Rules::START_POS_CROSS, false, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + const auto originalRules = Rules(Rules::DEFAULT_DOTS.startPos, false, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots); + auto board = Board(5, 4, originalRules); board.setStartPos(DOTS_RANDOM); board.playMoveAssumeLegal(Location::getLoc(1, 2, board.x_size), P_BLACK); const auto rotatedBoard = SymmetryHelpers::getSymBoard(board, SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_X); - auto expectedBoard = Board(4, 5, Rules(true, Rules::START_POS_CROSS, true, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + auto rulesAfterTransformation = originalRules; + rulesAfterTransformation.startPosIsRandom = true; + auto expectedBoard = Board(4, 5, rulesAfterTransformation); expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 2, expectedBoard.x_size), P_WHITE, true); expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 3, expectedBoard.x_size), P_BLACK, true); expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 3, expectedBoard.x_size), P_WHITE, true); @@ -247,24 +250,25 @@ R"( std::pair getCapturingAndBases( const string& boardData, - const bool suicideLegal, + const bool suicide, + const bool captureEmptyBases, const vector& extraMoves ) { - Board board = parseDotsFieldDefault(boardData, extraMoves); + const Board board = parseDotsField(boardData, false, suicide, captureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); const Board& copy(board); vector captures; vector bases; - copy.calculateOneMoveCaptureAndBasePositionsForDots(suicideLegal, captures, bases); + copy.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); std::ostringstream capturesStringStream; std::ostringstream basesStringStream; for (int y = 0; y < copy.y_size; y++) { for (int x = 0; x < copy.x_size; x++) { - Loc loc = Location::getLoc(x, y, copy.x_size); - Color captureColor = captures[loc]; + const Loc loc = Location::getLoc(x, y, copy.x_size); + const Color captureColor = captures[loc]; if (captureColor == C_WALL) { capturesStringStream << PlayerIO::colorToChar(P_BLACK) << PlayerIO::colorToChar(P_WHITE); } else { @@ -296,12 +300,13 @@ std::pair getCapturingAndBases( void checkCapturingAndBase( const string& title, const string& boardData, - const bool suicideLegal, - const vector& extraMoves, const string& expectedCaptures, - const string& expectedBases + const string& expectedBases, + const bool suicide = Rules::DEFAULT_DOTS.multiStoneSuicideLegal, + const bool captureEmptyBases = Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, + const vector& extraMoves = {} ) { - auto [capturing, bases] = getCapturingAndBases(boardData, suicideLegal, extraMoves); + auto [capturing, bases] = getCapturingAndBases(boardData, suicide, captureEmptyBases, extraMoves); cout << (" " + title + ": capturing").c_str() << endl; expect("", capturing, expectedCaptures); cout << (" " + title + ": bases").c_str() << endl; @@ -317,7 +322,7 @@ void Tests::runDotsCapturingTests() { .x...o. xox.oxo ....... -)", true, {}, R"( +)", R"( . . . . . . . . . . . . . . . X . . . O . @@ -337,7 +342,7 @@ xox ... oxo .o. -)", true, {}, R"( +)", R"( . . . . . . . XO . @@ -359,7 +364,7 @@ oxo .x. x.x .x. -)", true, {}, R"( +)", R"( . . . . . . . . . @@ -372,21 +377,45 @@ R"( ); checkCapturingAndBase( -"Empty base no suicide", +"Empty base can be broken", +R"( +.xx. +x..x +x.x. +oxo. +.o.. +)", R"( +. . . . +. . . . +. O . . +. . . . +. . . . +)", +R"( +. . . . +. X X . +. X . . +. O . . +. . . . +)" +); + + checkCapturingAndBase( +"No empty base capturing", R"( .x. x.x -.x. -)", false, {}, R"( +... +)", R"( . . . . . . . . . )", R"( . . . -. X . . . . -)" +. . . +)", Rules::DEFAULT_DOTS.multiStoneSuicideLegal, false ); checkCapturingAndBase( @@ -395,7 +424,7 @@ R"( .x. x.x ... -)", true, {}, R"( +)", R"( . . . . . . . X . @@ -404,7 +433,7 @@ R"( . . . . X . . . . -)" +)", Rules::DEFAULT_DOTS.multiStoneSuicideLegal, true ); checkCapturingAndBase( @@ -415,7 +444,7 @@ o.xo.x ox.ox. ox.ox. .o.x.. -)", true, {}, R"( +)", R"( . . . . . . . . . . . . . . . . . . diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 8fcd12388..80472d7a8 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -26,7 +26,7 @@ void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startP void checkStartPos(const string& description, const int startPos, const bool startPosIsRandom, const int x_size, const int y_size, const string& expectedBoard = "", const vector& extraMoves = {}) { cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; - auto board = Board(x_size, y_size, Rules(true, startPos, startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + auto board = Board(x_size, y_size, Rules(startPos, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); board.setStartPos(DOTS_RANDOM); for (const XYMove& extraMove : extraMoves) { board.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, board.x_size), extraMove.player); @@ -43,7 +43,7 @@ void checkStartPos(const string& description, const int startPos, const bool sta } void checkStartPosRecognition(const string& description, const int expectedStartPos, const int startPosIsRandom, const string& inputBoard) { - const Board board = parseDotsField(inputBoard, startPosIsRandom, false, false, {}); + const Board board = parseDotsField(inputBoard, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, {}); cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index 418c096fb..94a1f0616 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -199,7 +199,9 @@ void runDotsStressTestsInternal( Rand rand("runDotsStressTests"); - Rules rules = dotsGame ? Rules(dotsGame, startPos, startPosIsRandom, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) : Rules(); + Rules rules = dotsGame + ? Rules(startPos, startPosIsRandom, suicideAllowed, dotsCaptureEmptyBase, Rules::DEFAULT_DOTS.dotsFreeCapturedDots) + : Rules(); int numLegalMoves = x_size * y_size - rules.getNumOfStartPosStones(); vector randomMoves = vector(); @@ -316,7 +318,7 @@ void Tests::runDotsStressTests() { cout << "Running dots stress tests" << endl; cout << " Max territory" << endl; - auto board = Board(39, 32, Rules(true, Rules::START_POS_EMPTY, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); + auto board = Board(39, 32, Rules(Rules::START_POS_EMPTY, false, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); for(int y = 0; y < board.y_size; y++) { for(int x = 0; x < board.x_size; x++) { const Player pla = y == 0 || y == board.y_size - 1 || x == 0 || x == board.x_size - 1 ? P_BLACK : P_WHITE; diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp index 4fbd95d0d..5ac0aaa8f 100644 --- a/cpp/tests/testdotsutils.cpp +++ b/cpp/tests/testdotsutils.cpp @@ -3,10 +3,10 @@ using namespace std; Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { - return parseDotsField(input, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); + return parseDotsField(input, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); } -Board parseDotsField(const string& input, const bool startPosIsRandom, const bool captureEmptyBases, +Board parseDotsField(const string& input, const bool startPosIsRandom, const bool suicide, const bool captureEmptyBases, const bool freeCapturedDots, const vector& extraMoves) { int currentXSize = 0; int xSize = -1; @@ -27,7 +27,7 @@ Board parseDotsField(const string& input, const bool startPosIsRandom, const boo } } - Board result = Board::parseBoard(xSize, ySize, input, Rules(true, Rules::START_POS_EMPTY, startPosIsRandom, captureEmptyBases, freeCapturedDots)); + Board result = Board::parseBoard(xSize, ySize, input, Rules(Rules::START_POS_EMPTY, startPosIsRandom, suicide, captureEmptyBases, freeCapturedDots)); for(const XYMove& extraMove : extraMoves) { result.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, result.x_size), extraMove.player); } diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index 012ac7bc8..d6f372496 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -64,5 +64,5 @@ struct BoardWithMoveRecords { Board parseDotsFieldDefault(const string& input, const vector& extraMoves = {}); -Board parseDotsField(const string& input, bool startPosIsRandom, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); +Board parseDotsField(const string& input, bool startPosIsRandom, bool suicide, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); From 25d25e0107095e621355f42d788eddad0be70141 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Thu, 2 Oct 2025 21:02:40 +0200 Subject: [PATCH 22/42] Add an option to `printBoard` that allows to get rid of hash printing Because the hash depends on MAX_SIZE, and it's inconvenient on the experimental stage because the value might vary quite often --- cpp/game/board.cpp | 8 +++++--- cpp/game/board.h | 2 +- cpp/tests/testdotsstartposes.cpp | 15 +-------------- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 7b3a5c780..4e937bcfd 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -2744,10 +2744,12 @@ vector Location::parseSequence(const string& str, const Board& board) { return locs; } -void Board::printBoard(ostream& out, const Board& board, Loc markLoc, const vector* hist) { - if(hist != NULL) +void Board::printBoard(ostream& out, const Board& board, const Loc markLoc, const vector* hist, bool printHash) { + if(hist != nullptr) out << "MoveNum: " << hist->size() << " "; - out << "HASH: " << board.pos_hash << "\n"; + if (printHash) { + out << "HASH: " << board.pos_hash << "\n"; + } bool showCoords = board.isDots() || (board.x_size <= 50 && board.y_size <= 50); if(showCoords) { if (board.isDots()) { diff --git a/cpp/game/board.h b/cpp/game/board.h index 9a9cc392b..c2090abf8 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -393,7 +393,7 @@ struct Board static Board parseBoard(int xSize, int ySize, const std::string& s, const Rules& rules = Rules::DEFAULT_GO, char lineDelimiter = '\n'); std::string toString() const; - static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist); + static void printBoard(std::ostream& out, const Board& board, Loc markLoc, const std::vector* hist, bool printHash = true); static std::string toStringSimple(const Board& board, char lineDelimiter = '\n'); static nlohmann::json toJson(const Board& board); static Board ofJson(const nlohmann::json& data); diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 80472d7a8..4ddfa7d27 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -33,7 +33,7 @@ void checkStartPos(const string& description, const int startPos, const bool sta } std::ostringstream oss; - Board::printBoard(oss, board, Board::NULL_LOC, nullptr); + Board::printBoard(oss, board, Board::NULL_LOC, nullptr, false); if (!expectedBoard.empty()) { expect(description.c_str(), oss, expectedBoard); @@ -67,14 +67,12 @@ void Tests::runDotsStartPosTests() { Rand rand("runDotsStartPosTests"); checkStartPos("Cross on minimal size", Rules::START_POS_CROSS, false, 2, 2, R"( -HASH: EC100709447890A116AFC8952423E3DD 1 2 2 X O 1 O X )"); checkStartPos("Extra dots with cross (for instance, a handicap game)", Rules::START_POS_CROSS, false, 4, 4, R"( -HASH: A130436FBD93FF473AB4F3B84DD304DB 1 2 3 4 4 . . . . 3 . X O . @@ -97,7 +95,6 @@ HASH: A130436FBD93FF473AB4F3B84DD304DB )"); checkStartPos("Cross on odd size", Rules::START_POS_CROSS, false, 3, 3, R"( -HASH: 3B29F9557D2712A5BC982D218680927D 1 2 3 3 . X O 2 . O X @@ -105,7 +102,6 @@ HASH: 3B29F9557D2712A5BC982D218680927D )"); checkStartPos("Cross on standard size", Rules::START_POS_CROSS, false, 39, 32, R"( -HASH: 516E1ABBA0D6B69A0B3D17C9E34E52F7 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -142,14 +138,12 @@ HASH: 516E1ABBA0D6B69A0B3D17C9E34E52F7 )"); checkStartPos("Double cross on minimal size", Rules::START_POS_CROSS_2, false, 4, 2, R"( -HASH: 43FD769739F2AA27A8A1DAB1F4278229 1 2 3 4 2 X O O X 1 O X X O )"); checkStartPos("Double cross on odd size", Rules::START_POS_CROSS_2, false, 5, 3, R"( -HASH: AAA969B8135294A3D1ADAA07BEA9A987 1 2 3 4 5 3 . X O O X 2 . O X X O @@ -157,7 +151,6 @@ HASH: AAA969B8135294A3D1ADAA07BEA9A987 )"); checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 6, 4, R"( -HASH: D599CEA39B1378D29883145CA4C016FC 1 2 3 4 5 6 4 . . . . . . 3 . X O O X . @@ -166,7 +159,6 @@ HASH: D599CEA39B1378D29883145CA4C016FC )"); checkStartPos("Double cross", Rules::START_POS_CROSS_2, false, 7, 4, R"( -HASH: 249F175819EA8FDE47F8676E655A06DE 1 2 3 4 5 6 7 4 . . . . . . . 3 . . X O O X . @@ -175,7 +167,6 @@ HASH: 249F175819EA8FDE47F8676E655A06DE )"); checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, false, 39, 32, R"( -HASH: CAD72FD407955308CEFCBD7A9B14B35B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . @@ -212,7 +203,6 @@ HASH: CAD72FD407955308CEFCBD7A9B14B35B )"); checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 5, 5, R"( -HASH: 0C2DD637AAE5FA7E1469BF5829BE922B 1 2 3 4 5 5 X O . X O 4 O X . O X @@ -222,7 +212,6 @@ HASH: 0C2DD637AAE5FA7E1469BF5829BE922B )"); checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 7, 7, R"( -HASH: 89CBCA85E94AF1B6C376E6BCBC443A48 1 2 3 4 5 6 7 7 . . . . . . . 6 . X O . X O . @@ -234,7 +223,6 @@ HASH: 89CBCA85E94AF1B6C376E6BCBC443A48 )"); checkStartPos("Quadruple cross", Rules::START_POS_CROSS_4, false, 8, 8, R"( -HASH: 445D50D7A61C47CE2730BBB97A2B3C96 1 2 3 4 5 6 7 8 8 . . . . . . . . 7 . X O . . X O . @@ -247,7 +235,6 @@ HASH: 445D50D7A61C47CE2730BBB97A2B3C96 )"); checkStartPos("Quadruple cross on standard size", Rules::START_POS_CROSS_4, false, 39, 32, R"( -HASH: 2A9AE7F967F17B42D9B9CB45B735E9C6 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . From 785572a9327e160d52390e4a37a3c3148d9e2f89 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Thu, 2 Oct 2025 21:34:02 +0200 Subject: [PATCH 23/42] Use `winOrEffectiveDrawByGrounding` instead of `doesGroundingWinGame` It significantly decreases game duration because typically it doesn't make sense to play until the very end in case of zero komi and when all dots are grounded. --- cpp/game/boardhistory.cpp | 2 +- cpp/game/boardhistory.h | 6 ++- cpp/game/dotsboardhistory.cpp | 28 ++++++------- cpp/neuralnet/nneval.cpp | 3 +- cpp/program/play.cpp | 2 +- cpp/tests/testdotsbasic.cpp | 78 +++++++++++++++++++++++++++-------- 6 files changed, 82 insertions(+), 37 deletions(-) diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index c9eafe9b1..a45f77464 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -768,7 +768,7 @@ void BoardHistory::endGameIfAllPassAlive(const Board& board) { assert(rules.isDots == board.isDots()); if (rules.isDots) { - if (const float whiteScoreAfterGrounding = whiteScoreIfGroundingAlive(board); whiteScoreAfterGrounding != 0.0) { + if (const float whiteScoreAfterGrounding = whiteScoreIfGroundingAlive(board); whiteScoreAfterGrounding != std::numeric_limits::quiet_NaN()) { setFinalScoreAndWinner(whiteScoreAfterGrounding); isScored = true; isNoResult = false; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index a3ab88d99..92616cf36 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -175,8 +175,10 @@ struct BoardHistory { void endGameIfAllPassAlive(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); - bool doesGroundingWinGame(const Board& board, Player pla) const; - bool doesGroundingWinGame(const Board& board, Player pla, float& whiteScore) const; + // Effective draw is when there are no ungrounded dots on the field (disregarding Komi) + // We can consider grounding in this case because the further game typically doesn't make sense. + bool winOrEffectiveDrawByGrounding(const Board& board, Player pla, bool considerDraw = true) const; + // Return > 0 if white wins by grounding, < 0 if black wins by grounding, 0 if there are no ungrounded dots and Nan otherwise float whiteScoreIfGroundingAlive(const Board& board) const; void endAndScoreGameNow(const Board& board, Color area[Board::MAX_ARR_SIZE]); diff --git a/cpp/game/dotsboardhistory.cpp b/cpp/game/dotsboardhistory.cpp index c6a77f364..e3f91cc76 100644 --- a/cpp/game/dotsboardhistory.cpp +++ b/cpp/game/dotsboardhistory.cpp @@ -7,7 +7,7 @@ int BoardHistory::countDotsScoreWhiteMinusBlack(const Board& board, Color area[B const float whiteScore = whiteScoreIfGroundingAlive(board); Color groundingPlayer = C_EMPTY; - if (whiteScore > 0.0f) { + if (whiteScore >= 0.0f) { groundingPlayer = C_WHITE; } else if (whiteScore < 0.0f) { groundingPlayer = C_BLACK; @@ -15,34 +15,34 @@ int BoardHistory::countDotsScoreWhiteMinusBlack(const Board& board, Color area[B return board.calculateOwnershipAndWhiteScore(area, groundingPlayer); } -bool BoardHistory::doesGroundingWinGame(const Board& board, const Player pla) const { - float whiteScore; - return doesGroundingWinGame(board, pla, whiteScore); -} - -bool BoardHistory::doesGroundingWinGame(const Board& board, const Player pla, float& whiteScore) const { +bool BoardHistory::winOrEffectiveDrawByGrounding(const Board& board, const Player pla, const bool considerDraw) const { assert(rules.isDots); - whiteScore = whiteScoreIfGroundingAlive(board); - return pla == P_WHITE && whiteScore > 0.0f || pla == P_BLACK && whiteScore < 0.0f; + const float whiteScore = whiteScoreIfGroundingAlive(board); + return considerDraw && whiteScore == 0.0f || pla == P_WHITE && whiteScore > 0.0f || pla == P_BLACK && whiteScore < 0.0f; } float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { assert(rules.isDots); - if (const float fullWhiteScoreIfBlackGrounds = + const float fullWhiteScoreIfBlackGrounds = static_cast(board.whiteScoreIfBlackGrounds) + whiteBonusScore + whiteHandicapBonusScore + rules.komi; - fullWhiteScoreIfBlackGrounds < 0.0f) { + if (fullWhiteScoreIfBlackGrounds < 0.0f) { // Black already won the game return fullWhiteScoreIfBlackGrounds; } - if (const float fullBlackScoreIfWhiteGrounds = + const float fullBlackScoreIfWhiteGrounds = static_cast(board.blackScoreIfWhiteGrounds) - whiteBonusScore - whiteHandicapBonusScore - rules.komi; - fullBlackScoreIfWhiteGrounds < 0.0f) { + if (fullBlackScoreIfWhiteGrounds < 0.0f) { // White already won the game return -fullBlackScoreIfWhiteGrounds; } - return 0.0f; + if (fullWhiteScoreIfBlackGrounds == 0.0f && fullBlackScoreIfWhiteGrounds == 0.0f) { + // Draw by grounding + return 0.0f; + } + + return std::numeric_limits::quiet_NaN(); } \ No newline at end of file diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index da472a224..cd1082f2f 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -861,7 +861,8 @@ void NNEvaluator::evaluate( bool legal; if (loc == Board::PASS_LOC && history.rules.isDots) { // We need at least one legal loc, so choose grounding if it wins the game or there are no legal pos moves. - legal = legalCount == 0 || history.doesGroundingWinGame(board, nextPlayer); + // Also, choose grounding in case of effective draw because the further game makes no sense. + legal = legalCount == 0 || history.winOrEffectiveDrawByGrounding(board, nextPlayer); } else { legal = history.isLegal(board,loc,nextPlayer); } diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 1e240f2d7..e13f8b318 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -831,7 +831,7 @@ static void logSearch(Search* bot, Logger& logger, Loc loc, OtherGameProperties static Loc chooseRandomForkingMove(const NNOutput* nnOutput, const Board& board, const BoardHistory& hist, Player pla, Rand& gameRand, Loc banMove) { double r = gameRand.nextDouble(); - bool allowPass = !hist.rules.isDots || hist.doesGroundingWinGame(board, pla); + bool allowPass = !hist.rules.isDots || hist.winOrEffectiveDrawByGrounding(board, pla); //70% of the time, do a random temperature 1 policy move if(r < 0.70) return PlayUtils::chooseRandomPolicyMove(nnOutput, board, hist, pla, gameRand, 1.0, allowPass, banMove); diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index 6ad631ef5..4bfaecea8 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -688,11 +688,59 @@ void Tests::runDotsBoardHistoryGroundingTests() { .... )"); const auto boardHistory = BoardHistory(board); - float whiteScoreAfterGrounding; - testAssert(!boardHistory.doesGroundingWinGame(board, P_BLACK, whiteScoreAfterGrounding)); - testAssert(!boardHistory.doesGroundingWinGame(board, P_WHITE, whiteScoreAfterGrounding)); - testAssert(0.0f == whiteScoreAfterGrounding); - testAssert(0.0f == boardHistory.whiteScoreIfGroundingAlive(board)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, false)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, false)); + + // No draw because there are some ungrounded dots + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.xo. +.xo. +.ox. +.ox. +)"); + const auto boardHistory = BoardHistory(board); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, false)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, false)); + + // Effective draw because all dots are grounded + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.x.... +xox... +....o. +...oxo +...... +)", {XYMove(1, 2, P_BLACK), XYMove(4, 4, P_WHITE)}); + const auto boardHistory = BoardHistory(board); + + // Also effective draw because all bases are grounded + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); + } + + { + const Board board = parseDotsFieldDefault(R"( +.x.... +xox.x. +...... +....o. +.o.oxo +...... +)", {XYMove(1, 2, P_BLACK), XYMove(4, 5, P_WHITE)}); + const auto boardHistory = BoardHistory(board); + + // No effective draw because there are ungrounded dots + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK, true)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE, true)); } { @@ -705,10 +753,8 @@ void Tests::runDotsBoardHistoryGroundingTests() { board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_WHITE); testAssert(1 == board.numBlackCaptures); const auto boardHistory = BoardHistory(board); - float whiteScoreAfterGrounding; - testAssert(!boardHistory.doesGroundingWinGame(board, P_BLACK, whiteScoreAfterGrounding)); - testAssert(boardHistory.doesGroundingWinGame(board, P_WHITE, whiteScoreAfterGrounding)); - testAssert(1.0f == whiteScoreAfterGrounding); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); testAssert(1.0f == boardHistory.whiteScoreIfGroundingAlive(board)); } @@ -722,10 +768,8 @@ void Tests::runDotsBoardHistoryGroundingTests() { board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_BLACK); testAssert(1 == board.numWhiteCaptures); const auto boardHistory = BoardHistory(board); - float whiteScoreAfterGrounding; - testAssert(boardHistory.doesGroundingWinGame(board, P_BLACK, whiteScoreAfterGrounding)); - testAssert(!boardHistory.doesGroundingWinGame(board, P_WHITE, whiteScoreAfterGrounding)); - testAssert(-1.0f == whiteScoreAfterGrounding); + testAssert(boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); testAssert(-1.0f == boardHistory.whiteScoreIfGroundingAlive(board)); } @@ -740,11 +784,9 @@ void Tests::runDotsBoardHistoryGroundingTests() { board.playMoveAssumeLegal(Location::getLoc(2, 3, board.x_size), P_BLACK); testAssert(1 == board.numWhiteCaptures); const auto boardHistory = BoardHistory(board); - float whiteScoreAfterGrounding; - testAssert(!boardHistory.doesGroundingWinGame(board, P_BLACK, whiteScoreAfterGrounding)); - testAssert(!boardHistory.doesGroundingWinGame(board, P_WHITE, whiteScoreAfterGrounding)); - testAssert(0.0f == whiteScoreAfterGrounding); - testAssert(0.0f == boardHistory.whiteScoreIfGroundingAlive(board)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_BLACK)); + testAssert(!boardHistory.winOrEffectiveDrawByGrounding(board, P_WHITE)); + testAssert(std::isnan(boardHistory.whiteScoreIfGroundingAlive(board))); } } From 63ad75cb4ce36103e7b379eeffc0ba5072a907ad Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 4 Oct 2025 13:37:57 +0200 Subject: [PATCH 24/42] Introduce Global::FLOAT_EPS for more accurate float comparisons --- cpp/core/global.cpp | 8 ++++++++ cpp/core/global.h | 6 +++++- cpp/game/boardhistory.cpp | 6 +++--- cpp/game/dotsboardhistory.cpp | 12 +++++++----- cpp/game/rules.cpp | 2 +- cpp/tests/testdotsstress.cpp | 4 ++-- 6 files changed, 26 insertions(+), 12 deletions(-) diff --git a/cpp/core/global.cpp b/cpp/core/global.cpp index b5620c022..a04d2eda6 100644 --- a/cpp/core/global.cpp +++ b/cpp/core/global.cpp @@ -706,3 +706,11 @@ double Global::roundDynamic(double x, int precision) { double inverseScale = pow(10.0,-roundingMagnitude); return roundStatic(x, inverseScale); } + +bool Global::isEqual(const float f1, const float f2) { + return std::fabs(f1 - f2) <= FLOAT_EPS; +} + +bool Global::isZero(const float f) { + return std::fabs(f) <= FLOAT_EPS; +} diff --git a/cpp/core/global.h b/cpp/core/global.h index 62c34fa12..ad302a6b4 100644 --- a/cpp/core/global.h +++ b/cpp/core/global.h @@ -159,7 +159,11 @@ namespace Global //Round x to this many decimal digits of precision double roundDynamic(double x, int precision); -} + // Float comparison + constexpr float FLOAT_EPS = std::numeric_limits::epsilon(); + bool isEqual(float f1, float f2); + bool isZero(float f); +} // namespace Global struct StringError : public std::exception { std::string message; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index a45f77464..400d315ff 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -722,11 +722,11 @@ int BoardHistory::countTerritoryAreaScoreWhiteMinusBlack(const Board& board, Col return score; } -void BoardHistory::setFinalScoreAndWinner(float score) { +void BoardHistory::setFinalScoreAndWinner(const float score) { finalWhiteMinusBlackScore = score; - if(finalWhiteMinusBlackScore > 0.0f) + if(finalWhiteMinusBlackScore > Global::FLOAT_EPS) winner = C_WHITE; - else if(finalWhiteMinusBlackScore < 0.0f) + else if(finalWhiteMinusBlackScore < -Global::FLOAT_EPS) winner = C_BLACK; else winner = C_EMPTY; diff --git a/cpp/game/dotsboardhistory.cpp b/cpp/game/dotsboardhistory.cpp index e3f91cc76..aed74dec0 100644 --- a/cpp/game/dotsboardhistory.cpp +++ b/cpp/game/dotsboardhistory.cpp @@ -7,7 +7,7 @@ int BoardHistory::countDotsScoreWhiteMinusBlack(const Board& board, Color area[B const float whiteScore = whiteScoreIfGroundingAlive(board); Color groundingPlayer = C_EMPTY; - if (whiteScore >= 0.0f) { + if (whiteScore >= 0.0f) { // Don't use EPS comparison because in case of zero result, there is a draw and any player can ground groundingPlayer = C_WHITE; } else if (whiteScore < 0.0f) { groundingPlayer = C_BLACK; @@ -19,7 +19,9 @@ bool BoardHistory::winOrEffectiveDrawByGrounding(const Board& board, const Playe assert(rules.isDots); const float whiteScore = whiteScoreIfGroundingAlive(board); - return considerDraw && whiteScore == 0.0f || pla == P_WHITE && whiteScore > 0.0f || pla == P_BLACK && whiteScore < 0.0f; + return considerDraw && Global::isZero(whiteScore) || + pla == P_WHITE && whiteScore > Global::FLOAT_EPS || + pla == P_BLACK && whiteScore < -Global::FLOAT_EPS; } float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { @@ -27,19 +29,19 @@ float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { const float fullWhiteScoreIfBlackGrounds = static_cast(board.whiteScoreIfBlackGrounds) + whiteBonusScore + whiteHandicapBonusScore + rules.komi; - if (fullWhiteScoreIfBlackGrounds < 0.0f) { + if (fullWhiteScoreIfBlackGrounds < -Global::FLOAT_EPS) { // Black already won the game return fullWhiteScoreIfBlackGrounds; } const float fullBlackScoreIfWhiteGrounds = static_cast(board.blackScoreIfWhiteGrounds) - whiteBonusScore - whiteHandicapBonusScore - rules.komi; - if (fullBlackScoreIfWhiteGrounds < 0.0f) { + if (fullBlackScoreIfWhiteGrounds < -Global::FLOAT_EPS) { // White already won the game return -fullBlackScoreIfWhiteGrounds; } - if (fullWhiteScoreIfBlackGrounds == 0.0f && fullBlackScoreIfWhiteGrounds == 0.0f) { + if (Global::isZero(fullWhiteScoreIfBlackGrounds) && Global::isZero(fullBlackScoreIfWhiteGrounds)) { // Draw by grounding return 0.0f; } diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index 6e46cc164..67532b3c2 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -111,7 +111,7 @@ bool Rules::equals(const Rules& other, const bool ignoreSgfDefinedProps) const { } bool Rules::gameResultWillBeInteger() const { - bool komiIsInteger = ((int)komi) == komi; + const bool komiIsInteger = Global::isEqual(std::floor(komi), komi); return komiIsInteger != hasButton; } diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index 94a1f0616..ae3cfa2d7 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -281,9 +281,9 @@ void runDotsStressTestsInternal( } movesCount += currentGameMovesCount; - if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > 0.0f) { + if (float whiteScore = board.numBlackCaptures - board.numWhiteCaptures + komi; whiteScore > Global::FLOAT_EPS) { whiteWinsCount++; - } else if (whiteScore < 0) { + } else if (whiteScore < -Global::FLOAT_EPS) { blackWinsCount++; } else { drawsCount++; From 3d60f5e376cbb4f18d7a9387a3747b79f86c9cff Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 4 Oct 2025 17:29:49 +0200 Subject: [PATCH 25/42] Refine Dots NN inputs and scoring --- cpp/neuralnet/modelversion.cpp | 6 ++- cpp/neuralnet/nninputs.cpp | 46 +++++++++++++---- cpp/neuralnet/nninputs.h | 71 ++++++++++++++------------ cpp/neuralnet/nninputsdots.cpp | 92 +++++++++++++++++++++++++++------- cpp/neuralnet/opencltuner.cpp | 11 ++++ 5 files changed, 163 insertions(+), 63 deletions(-) diff --git a/cpp/neuralnet/modelversion.cpp b/cpp/neuralnet/modelversion.cpp index 746c9a27f..a829e743a 100644 --- a/cpp/neuralnet/modelversion.cpp +++ b/cpp/neuralnet/modelversion.cpp @@ -34,7 +34,7 @@ static_assert(NNModelVersion::latestModelVersionImplemented == 17, ""); static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); int NNModelVersion::getInputsVersion(int modelVersion) { - if (modelVersion == 17) + if (modelVersion == defaultModelVersionForDots) return dotsInputsVersion; if(modelVersion >= 8 && modelVersion <= 16) return 7; @@ -52,7 +52,7 @@ int NNModelVersion::getInputsVersion(int modelVersion) { } int NNModelVersion::getNumSpatialFeatures(int modelVersion) { - if(modelVersion == 17) + if(modelVersion == defaultModelVersionForDots) return NNInputs::NUM_FEATURES_SPATIAL_V_DOTS; if(modelVersion >= 8 && modelVersion <= 16) return NNInputs::NUM_FEATURES_SPATIAL_V7; @@ -70,6 +70,8 @@ int NNModelVersion::getNumSpatialFeatures(int modelVersion) { } int NNModelVersion::getNumGlobalFeatures(int modelVersion) { + if(modelVersion == defaultModelVersionForDots) + return NNInputs::NUM_FEATURES_GLOBAL_V_DOTS; if(modelVersion >= 8 && modelVersion <= 16) return NNInputs::NUM_FEATURES_GLOBAL_V7; else if(modelVersion == 7) diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 757b51465..5e7629ded 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -239,16 +239,38 @@ void NNInputs::fillScoring( ) { std::fill_n(scoring, Board::MAX_ARR_SIZE, 0.0f); - if(!groupTax || board.isDots()) { - // TODO: probably it makes sense to implement more accurate scoring for Dots - // That includes dead, empty locations and empty base locations: - // - // Captured enemy's dot: 1.0f - // Captured enemy's empty loc: 0.75f - // Empty base loc: 0.5f - // Empty: 0.0f - // - // Also consider grounding dots? + if (board.isDots()) { + for(int y = 0; y(DotsSpatialFeature::COUNT); + constexpr int NUM_FEATURES_GLOBAL_V_DOTS = static_cast(DotsGlobalFeature::COUNT); Hash128 getHash( const Board& board, const BoardHistory& boardHistory, Player nextPlayer, diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index b99dc6859..6e517bc11 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -12,7 +12,7 @@ void NNInputs::fillRowVDots( assert(board.x_size <= nnXLen); assert(board.y_size <= nnYLen); std::fill_n(rowBin, NUM_FEATURES_SPATIAL_V_DOTS * nnXLen * nnYLen,false); - std::fill_n(rowGlobal, NUM_FEATURES_SPATIAL_V_DOTS, 0.0f); + std::fill_n(rowGlobal, NUM_FEATURES_GLOBAL_V_DOTS, 0.0f); const Player pla = nextPlayer; const Player opp = getOpp(pla); @@ -30,60 +30,114 @@ void NNInputs::fillRowVDots( posStride = 1; } + const Rules& rules = hist.rules; + vector captures; vector bases; - board.calculateOneMoveCaptureAndBasePositionsForDots(hist.rules.multiStoneSuicideLegal, captures, bases); + board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + int deadDotsCount = 0; + + auto setSpatial = [&](const int pos, const DotsSpatialFeature spatialFeature) { + setRowBin(rowBin, pos, static_cast(spatialFeature), 1.0f, posStride, featureStride); + }; + + auto setGlobal = [&](const DotsGlobalFeature globalFeature, const float value = 1.0f) { + rowGlobal[static_cast(globalFeature)] = value; + }; for(int y = 0; y(xSize * ySize); + //Bound komi just in case + if(selfKomi > bArea+NNPos::KOMI_CLIP_RADIUS) + selfKomi = bArea+NNPos::KOMI_CLIP_RADIUS; + if(selfKomi < -bArea-NNPos::KOMI_CLIP_RADIUS) + selfKomi = -bArea-NNPos::KOMI_CLIP_RADIUS; + setGlobal(DotsGlobalFeature::Komi, selfKomi / NNPos::KOMI_CLIP_RADIUS); + + if (rules.multiStoneSuicideLegal) { + setGlobal(DotsGlobalFeature::Suicide); + } + + if (rules.dotsCaptureEmptyBases) { + setGlobal(DotsGlobalFeature::CaptureEmpty); + } + + if (const int startPos = rules.startPos; startPos >= Rules::START_POS_CROSS) { + setGlobal(DotsGlobalFeature::StartPosCross); + if (startPos >= Rules::START_POS_CROSS_2) { + setGlobal(DotsGlobalFeature::StartPosCross2); + if (startPos >= Rules::START_POS_CROSS_4) { + setGlobal(DotsGlobalFeature::StartPosCross4); } } } - int maxTurnsOfHistoryToInclude = 5; - if (hist.isGameFinished) { + if (rules.startPosIsRandom) { + setGlobal(DotsGlobalFeature::StartPosIsRandom); + } + if (hist.winOrEffectiveDrawByGrounding(board, pla)) { + // Train to better understand grounding + setGlobal(DotsGlobalFeature::WinByGrounding); } + + setGlobal(DotsGlobalFeature::FieldSizeKomiParity, 0.0f); // TODO: implement later } \ No newline at end of file diff --git a/cpp/neuralnet/opencltuner.cpp b/cpp/neuralnet/opencltuner.cpp index 58ae3eb4e..8c177a986 100644 --- a/cpp/neuralnet/opencltuner.cpp +++ b/cpp/neuralnet/opencltuner.cpp @@ -3462,6 +3462,17 @@ void OpenCLTuner::autoTuneEverything( modelInfo.modelVersion = 16; modelInfos.push_back(modelInfo); } + { + ModelInfoForTuning modelInfo; + modelInfo.maxConvChannels1x1 = 512; + modelInfo.maxConvChannels3x3 = 512; + modelInfo.trunkNumChannels = 512; + modelInfo.midNumChannels = 256; + modelInfo.regularNumChannels = 192; + modelInfo.gpoolNumChannels = 64; + modelInfo.modelVersion = 17; + modelInfos.push_back(modelInfo); + } for(ModelInfoForTuning modelInfo : modelInfos) { int nnXLen = NNPos::MAX_BOARD_LEN_X; From c420e734021797de1f90d447a219c885db142917 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 16:10:36 +0100 Subject: [PATCH 26/42] Introduce `endGameIfNoLegalMoves` Rename `numLegalMoves` to `numLegalMovesIfSuiAllowed` because it's an expensive operation to calculate the number if sui is on --- cpp/command/misc.cpp | 1 + cpp/game/board.cpp | 6 +++--- cpp/game/board.h | 4 +++- cpp/game/boardhistory.cpp | 13 +++++++++++++ cpp/game/boardhistory.h | 1 + cpp/game/dotsfield.cpp | 26 ++++++++++++++++++-------- cpp/neuralnet/nninputs.cpp | 2 +- cpp/program/play.cpp | 2 ++ cpp/program/playutils.cpp | 1 + cpp/search/localpattern.cpp | 2 +- cpp/tests/testdotsbasic.cpp | 32 ++++++++++++++++++++++++++++---- cpp/tests/testdotsstress.cpp | 4 ++-- 12 files changed, 74 insertions(+), 20 deletions(-) diff --git a/cpp/command/misc.cpp b/cpp/command/misc.cpp index e118b80e1..255a61498 100644 --- a/cpp/command/misc.cpp +++ b/cpp/command/misc.cpp @@ -224,6 +224,7 @@ int MainCmds::evalrandominits(const vector& args) { pla = getOpp(pla); hist.endGameIfAllPassAlive(board); + hist.endGameIfNoLegalMoves(board); if(hist.isGameFinished) break; } diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 4e937bcfd..399fba149 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -145,7 +145,7 @@ Board::Board(const Board& other) { numWhiteCaptures = other.numWhiteCaptures; blackScoreIfWhiteGrounds = other.blackScoreIfWhiteGrounds; whiteScoreIfBlackGrounds = other.whiteScoreIfBlackGrounds; - numLegalMoves = other.numLegalMoves; + numLegalMovesIfSuiAllowed = other.numLegalMovesIfSuiAllowed; start_pos_moves = other.start_pos_moves; memcpy(adj_offsets, other.adj_offsets, sizeof(short)*8); visited_data.resize(other.visited_data.size(), false); @@ -184,7 +184,7 @@ void Board::init(const int xS, const int yS, const Rules& initRules) numWhiteCaptures = 0; blackScoreIfWhiteGrounds = 0; whiteScoreIfBlackGrounds = 0; - numLegalMoves = xS * yS; + numLegalMovesIfSuiAllowed = xS * yS; if (!rules.isDots) { chain_data.resize(MAX_ARR_SIZE); @@ -2508,7 +2508,7 @@ bool Board::isEqualForTesting(const Board& other, const bool checkNumCaptures, if(colors[i] != other.colors[i]) return false; } - if (numLegalMoves != other.numLegalMoves) { + if (numLegalMovesIfSuiAllowed != other.numLegalMovesIfSuiAllowed) { return false; } if (start_pos_moves.size() != other.start_pos_moves.size()) { diff --git a/cpp/game/board.h b/cpp/game/board.h index c2090abf8..3d1bc7d33 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -48,6 +48,8 @@ Color getEmptyTerritoryColor(State s); bool isGrounded(State state); +bool isTerritory(State s); + //Conversions for players and colors namespace PlayerIO { char colorToChar(Color c); @@ -421,7 +423,7 @@ struct Board // Offsets to add to get clockwise traverse short adj_offsets[8]; - int numLegalMoves; + int numLegalMovesIfSuiAllowed; //Every chain of stones has one of its stones arbitrarily designated as the head. std::vector chain_data; //For each head stone, the chaindata for the chain under that head. Undefined otherwise. diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 400d315ff..31599ae19 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -820,6 +820,19 @@ void BoardHistory::endGameIfAllPassAlive(const Board& board) { } } +void BoardHistory::endGameIfNoLegalMoves(const Board& board) { + if (board.numLegalMovesIfSuiAllowed == 0) { + for(int y = 0; y < board.y_size; y++) { + for(int x = 0; x < board.x_size; x++) { + const Loc loc = Location::getLoc(x, y, board.x_size); + assert(!board.isLegal(loc, P_BLACK, rules.multiStoneSuicideLegal, true)); + assert(!board.isLegal(loc, P_WHITE, rules.multiStoneSuicideLegal, true)); + } + } + endAndScoreGameNow(board); + } +} + void BoardHistory::setWinnerByResignation(Player pla) { isGameFinished = true; isPastNormalPhaseEnd = false; diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 92616cf36..103bb94da 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -173,6 +173,7 @@ struct BoardHistory { //Slightly expensive, check if the entire game is all pass-alive-territory, and if so, declare the game finished // For Dots game it's Grounding alive void endGameIfAllPassAlive(const Board& board); + void endGameIfNoLegalMoves(const Board& board); //Score the board as-is. If the game is already finished, and is NOT a no-result, then this should be idempotent. void endAndScoreGameNow(const Board& board); // Effective draw is when there are no ungrounded dots on the field (disregarding Komi) diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 68ac22a9c..0916242f2 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -284,7 +284,9 @@ void Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; pos_hash ^= hashValue; - numLegalMoves--; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed--; + } bool atLeastOneRealBaseIsGrounded = false; int unconnectedLocationsSize = 0; @@ -339,7 +341,9 @@ Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); const Hash128 hashValue = ZOBRIST_BOARD_HASH[loc][pla]; pos_hash ^= hashValue; - numLegalMoves--; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed--; + } bool atLeastOneRealBaseIsGrounded = false; int unconnectedLocationsSize = 0; @@ -356,7 +360,9 @@ Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool } else { colors[loc] = originalState; pos_hash ^= hashValue; - numLegalMoves++; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed++; + } return {}; } } @@ -431,7 +437,9 @@ void Board::undoDots(MoveRecord& moveRecord) { if (!isGroundingMove) { setState(moveRecord.loc, moveRecord.previousState); pos_hash ^= ZOBRIST_BOARD_HASH[moveRecord.loc][moveRecord.pla]; - numLegalMoves++; + if (rules.multiStoneSuicideLegal) { + numLegalMovesIfSuiAllowed++; + } } } @@ -844,10 +852,12 @@ void Board::updateScoreAndHashForTerritory(const Loc loc, const State state, con } if (currentColor == C_EMPTY) { - if (!rollback) { - numLegalMoves--; - } else { - numLegalMoves++; + if (rules.multiStoneSuicideLegal) { + if (!rollback) { + numLegalMovesIfSuiAllowed--; + } else { + numLegalMovesIfSuiAllowed++; + } } pos_hash ^= ZOBRIST_BOARD_HASH[loc][basePla]; } else if (currentColor == baseOppPla) { diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 5e7629ded..2677c9d67 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -765,7 +765,7 @@ Board SymmetryHelpers::getSymBoard(const Board& board, int symmetry) { } else { symBoard.numBlackCaptures = board.numBlackCaptures; symBoard.numWhiteCaptures = board.numWhiteCaptures; - symBoard.numLegalMoves = board.numLegalMoves; + symBoard.numLegalMovesIfSuiAllowed = board.numLegalMovesIfSuiAllowed; symBoard.blackScoreIfWhiteGrounds = board.blackScoreIfWhiteGrounds; symBoard.whiteScoreIfBlackGrounds = board.whiteScoreIfBlackGrounds; } diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index e13f8b318..8ee628d08 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -1550,8 +1550,10 @@ FinishedGameData* Play::runGame( for(int i = 0; iwaitUntilFalse(); if(shouldStop != nullptr && shouldStop()) diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index 8a03ff1ea..fada11223 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -256,6 +256,7 @@ void PlayUtils::initializeGameUsingPolicy( //Rarely, playing the random moves out this way will end the game if(doEndGameIfAllPassAlive) hist.endGameIfAllPassAlive(board); + hist.endGameIfNoLegalMoves(board); if(hist.isGameFinished) break; } diff --git a/cpp/search/localpattern.cpp b/cpp/search/localpattern.cpp index e66fdaa44..5124182bf 100644 --- a/cpp/search/localpattern.cpp +++ b/cpp/search/localpattern.cpp @@ -94,7 +94,7 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p vector captures; vector bases; if (board.isDots()) { - board.calculateOneMoveCaptureAndBasePositionsForDots(true, captures, bases); + board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); } const int dxi = 1; diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index 4bfaecea8..fd4e6ebee 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -284,10 +284,10 @@ x.....x .... .... )", [](const BoardWithMoveRecords& boardWithMoveRecords) { -testAssert(12 == boardWithMoveRecords.board.numLegalMoves); +testAssert(12 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); }); - checkDotsField("Game over because of absence of legal moves", + checkDotsField("No legal moves", R"( xxxx xo.x @@ -295,8 +295,32 @@ xx.x )", [](const BoardWithMoveRecords& boardWithMoveRecords) { boardWithMoveRecords.playMove(2, 2, P_BLACK); testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); - testAssert(0 == boardWithMoveRecords.board.numLegalMoves); - }); + testAssert(0 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); + }, true); + + checkDotsField("Don't rely on legal moves number if suicide is enabled", + R"( +xxxx +xo.x +xx.x +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + boardWithMoveRecords.playMove(2, 2, P_BLACK); + testAssert(1 == boardWithMoveRecords.board.numWhiteCaptures); + // However, if suicide is disallowed, the value can be easily calculated iteratively + // So, it should be initialized by max number of moves + testAssert(12 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); +}, false); + + checkDotsField("Suicidal move is also legal move", +R"( +xxx +x.x +xxx +)", [](const BoardWithMoveRecords& boardWithMoveRecords) { + testAssert(1 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); + boardWithMoveRecords.playMove(1, 1, P_WHITE); + testAssert(0 == boardWithMoveRecords.board.numLegalMovesIfSuiAllowed); +}, true); } void Tests::runDotsGroundingTests() { diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index ae3cfa2d7..1847171cd 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -277,7 +277,7 @@ void runDotsStressTestsInternal( } if (dotsGame && suicideAllowed && lastLoc != Board::PASS_LOC) { - testAssert(0 == board.numLegalMoves); + testAssert(0 == board.numLegalMovesIfSuiAllowed); } movesCount += currentGameMovesCount; @@ -326,7 +326,7 @@ void Tests::runDotsStressTests() { } } testAssert((board.x_size - 2) * (board.y_size - 2) == board.numWhiteCaptures); - testAssert(0 == board.numLegalMoves); + testAssert(0 == board.numLegalMovesIfSuiAllowed); runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS, false, false, 0.0f, true, 0.8f, 1.0f, true); runDotsStressTestsInternal(39, 32, 3000, true, Rules::START_POS_CROSS_4, true, true, 0.5f, false, 0.8f, 1.0f, true); From 43477ad8fb730a2ba284c50c6674f8662a6a39c5 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 16:42:02 +0100 Subject: [PATCH 27/42] [GTP Support for basic GTP commands for Dots game Commands for field setup (size, rules), handicap placement fix #4 Dottify the engine: * Use numbers instead of latin numbers in field output * Player1, Player2 instead of Black and White * Tidy up default for Dots game (39*32 default size) --- cpp/book/book.cpp | 4 +- cpp/command/contribute.cpp | 2 +- cpp/command/demoplay.cpp | 4 +- cpp/command/genbook.cpp | 4 +- cpp/command/gtp.cpp | 201 +++++++++++------------ cpp/command/startposes.cpp | 4 +- cpp/core/global.cpp | 70 +++++++++ cpp/core/global.h | 6 + cpp/dataio/sgf.cpp | 18 +-- cpp/game/board.cpp | 271 ++++++++++++++++---------------- cpp/game/board.h | 12 +- cpp/game/boardhistory.cpp | 62 ++++---- cpp/game/rules.cpp | 51 ++++-- cpp/game/rules.h | 2 +- cpp/program/play.cpp | 4 +- cpp/program/playutils.cpp | 28 ++-- cpp/program/playutils.h | 13 +- cpp/program/setup.cpp | 116 +++++++------- cpp/search/searchresults.cpp | 2 +- cpp/tests/testrules.cpp | 2 +- cpp/tests/testsgf.cpp | 60 +++---- cpp/tests/testtrainingwrite.cpp | 6 +- 22 files changed, 528 insertions(+), 414 deletions(-) diff --git a/cpp/book/book.cpp b/cpp/book/book.cpp index a51ad8116..668f08a1b 100644 --- a/cpp/book/book.cpp +++ b/cpp/book/book.cpp @@ -2869,7 +2869,7 @@ void Book::saveToFile(const string& fileName) const { paramsDump["version"] = bookVersion; paramsDump["initialBoard"] = Board::toJson(initialBoard); paramsDump["initialRules"] = initialRules.toJson(); - paramsDump["initialPla"] = PlayerIO::playerToString(initialPla); + paramsDump["initialPla"] = PlayerIO::playerToString(initialPla, initialRules.isDots); paramsDump["repBound"] = repBound; paramsDump["errorFactor"] = params.errorFactor; paramsDump["costPerMove"] = params.costPerMove; @@ -2939,7 +2939,7 @@ void Book::saveToFile(const string& fileName) const { } else { nodeData["hash"] = node->hash.toString(); - nodeData["pla"] = PlayerIO::playerToString(node->pla); + nodeData["pla"] = PlayerIO::playerToString(node->pla, initialRules.isDots); nodeData["symmetries"] = node->symmetries; nodeData["winLossValue"] = node->thisValuesNotInBook.winLossValue; nodeData["scoreMean"] = node->thisValuesNotInBook.scoreMean; diff --git a/cpp/command/contribute.cpp b/cpp/command/contribute.cpp index 7cd5a0a96..ec0c3a9a0 100644 --- a/cpp/command/contribute.cpp +++ b/cpp/command/contribute.cpp @@ -218,7 +218,7 @@ static void runAndUploadSingleGame( out << "Match: " << botSpecB.botName << " (black) vs " << botSpecW.botName << " (white)" << "\n"; } out << "Rules: " << hist.rules.toJsonString() << "\n"; - out << "Player: " << PlayerIO::playerToString(pla) << "\n"; + out << "Player: " << PlayerIO::playerToString(pla,hist.rules.isDots) << "\n"; out << "Move: " << Location::toString(moveLoc,board) << "\n"; out << "Num Visits: " << search->getRootVisits() << "\n"; if(winLossHist.size() > 0) diff --git a/cpp/command/demoplay.cpp b/cpp/command/demoplay.cpp index 987035d53..daa4fb049 100644 --- a/cpp/command/demoplay.cpp +++ b/cpp/command/demoplay.cpp @@ -34,7 +34,7 @@ static void writeLine( cout << nnYLen << " "; cout << baseHist.rules.komi << " "; if(baseHist.isGameFinished) { - cout << PlayerIO::playerToString(baseHist.winner) << " "; + cout << PlayerIO::playerToString(baseHist.winner, baseHist.rules.isDots) << " "; cout << baseHist.isResignation << " "; cout << baseHist.finalWhiteMinusBlackScore << " "; } @@ -470,7 +470,7 @@ int MainCmds::demoplay(const vector& args) { ostringstream sout; sout << "genmove null location or illegal move!?!" << "\n"; sout << bot->getRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(pla) << "\n"; + sout << "Pla: " << PlayerIO::playerToString(pla, bot->getRootBoard().isDots()) << "\n"; sout << "MoveLoc: " << Location::toString(moveLoc,bot->getRootBoard()) << "\n"; logger.write(sout.str()); cerr << sout.str() << endl; diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index e1eb5c630..d49b75faa 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -450,8 +450,8 @@ int MainCmds::genbook(const vector& args) { if(bonusInitialPla != book->initialPla) throw StringError( "Book initial player and initial player in bonus sgf file do not match\n" + - PlayerIO::playerToString(book->initialPla) + " book \n" + - PlayerIO::playerToString(bonusInitialPla) + " bonus" + PlayerIO::playerToString(book->initialPla,book->initialRules.isDots) + " book \n" + + PlayerIO::playerToString(bonusInitialPla,book->initialRules.isDots) + " bonus" ); } diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 138e17473..e52443269 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -457,9 +457,10 @@ struct GTPEngine { //Specify -1 for the sizes for a default void setOrResetBoardSize(ConfigParser& cfg, Logger& logger, Rand& seedRand, int boardXSize, int boardYSize, bool loggingToStderr) { bool wasDefault = false; + bool isDots = cfg.getBoolOrDefault(DOTS_KEY, false); if(boardXSize == -1 || boardYSize == -1) { - boardXSize = Board::DEFAULT_LEN_X; - boardYSize = Board::DEFAULT_LEN_Y; + boardXSize = isDots ? Board::DEFAULT_LEN_X_DOTS : Board::DEFAULT_LEN_X; + boardYSize = isDots ? Board::DEFAULT_LEN_Y_DOTS : Board::DEFAULT_LEN_Y; wasDefault = true; } @@ -555,9 +556,11 @@ struct GTPEngine { isGenmoveParams = true; Board board(boardXSize,boardYSize,currentRules); - Player pla = P_BLACK; + board.setStartPos(seedRand); + constexpr Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); - vector newMoveHistory; + hist.setInitialTurnNumber(board.numStonesOnBoard()); + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); } @@ -569,7 +572,7 @@ struct GTPEngine { bot->setCopyOfExternalPatternBonusTable(patternBonusTable); } - void setPositionAndRules(Player pla, const Board& board, const BoardHistory& h, const Board& newInitialBoard, Player newInitialPla, const vector newMoveHistory) { + void setPositionAndRules(Player pla, const Board& board, const BoardHistory& h, const Board& newInitialBoard, Player newInitialPla, const vector& newMoveHistory) { BoardHistory hist(h); //Ensure we always have this value correct hist.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); @@ -588,6 +591,7 @@ struct GTPEngine { int newXSize = bot->getRootBoard().x_size; int newYSize = bot->getRootBoard().y_size; Board board(newXSize,newYSize,currentRules); + board.setStartPos(gtpRand); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); vector newMoveHistory; @@ -597,24 +601,23 @@ struct GTPEngine { bool setPosition(const vector& initialStones) { assert(bot->getRootHist().rules == currentRules); - int newXSize = bot->getRootBoard().x_size; - int newYSize = bot->getRootBoard().y_size; + const int newXSize = bot->getRootBoard().x_size; + const int newYSize = bot->getRootBoard().y_size; Board board(newXSize,newYSize,currentRules); - bool suc = board.setStonesFailIfNoLibs(initialStones); - if(!suc) - return false; + board.setStartPos(gtpRand); + if(!board.setStonesFailIfNoLibs(initialStones)) return false; //Sanity check - for(int i = 0; i newMoveHistory; + hist.setInitialTurnNumber(board.numStonesOnBoard()); // Heuristic to guess at what turn this is + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); return true; @@ -1166,8 +1169,9 @@ struct GTPEngine { response = "genmove returned null location or illegal move"; ostringstream sout; sout << "genmove null location or illegal move!?!" << "\n"; - sout << search->getRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(pla) << "\n"; + const auto rootBoard = search->getRootBoard(); + sout << rootBoard << "\n"; + sout << "Pla: " << PlayerIO::playerToString(pla,rootBoard.isDots()) << "\n"; sout << "MoveLoc: " << Location::toString(moveLoc,search->getRootBoard()) << "\n"; logger.write(sout.str()); genmoveTimeSum += genmoveTimer.getSeconds(); @@ -1289,24 +1293,28 @@ struct GTPEngine { if(args.analyzing) { response = "play " + response; } - - return; } - void clearCache() { + void clearCache() const { bot->clearSearch(); bot->clearEvalCache(); nnEval->clearCache(); - if(humanEval != NULL) + if(humanEval != nullptr) humanEval->clearCache(); } + // TODO: Fix handicap placement for Dots game void placeFixedHandicap(int n, string& response, bool& responseIsError) { - int xSize = bot->getRootBoard().x_size; - int ySize = bot->getRootBoard().y_size; + const int xSize = bot->getRootBoard().x_size; + const int ySize = bot->getRootBoard().y_size; Board board(xSize,ySize,currentRules); + board.setStartPos(gtpRand); + vector handicapLocs; try { - PlayUtils::placeFixedHandicap(board,n); + handicapLocs = PlayUtils::generateFixedHandicap(board, n); + for (const auto loc : handicapLocs) { + board.setStoneFailIfNoLibs(loc, P_BLACK); + } } catch(const StringError& e) { responseIsError = true; @@ -1325,18 +1333,13 @@ struct GTPEngine { pla = P_WHITE; response = ""; - for(int y = 0; y newMoveHistory; + const vector newMoveHistory; setPositionAndRules(pla,board,hist,board,pla,newMoveHistory); clearStatsForNewGame(); } @@ -1344,9 +1347,9 @@ struct GTPEngine { void placeFreeHandicap(int n, string& response, bool& responseIsError, Rand& rand) { stopAndWait(); - //If asked to place more, we just go ahead and only place up to 30, or a quarter of the board - int xSize = bot->getRootBoard().x_size; - int ySize = bot->getRootBoard().y_size; + // If asked to place more, we just go ahead and only place up to 30, or a quarter of the board + const int xSize = bot->getRootBoard().x_size; + const int ySize = bot->getRootBoard().y_size; int maxHandicap = xSize*ySize / 4; if(maxHandicap > 30) maxHandicap = 30; @@ -1356,10 +1359,11 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); Board board(xSize,ySize,currentRules); + board.setStartPos(rand); Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); double extraBlackTemperature = 0.25; - PlayUtils::playExtraBlack(bot->getSearchStopAndWait(), n, board, hist, extraBlackTemperature, rand); + const auto& handicapLocs = PlayUtils::playExtraBlack(bot->getSearchStopAndWait(), n, board, hist, extraBlackTemperature, rand); //Also switch the initial player, expecting white should be next. hist.clear(board,P_WHITE,currentRules,0); hist.setAssumeMultipleStartingBlackMovesAreHandicap(assumeMultipleStartingBlackMovesAreHandicap); @@ -1367,13 +1371,8 @@ struct GTPEngine { pla = P_WHITE; response = ""; - for(int y = 0; y& args) { const double initialDelayMoveScale = cfg.contains("delayMoveScale") ? cfg.getDouble("delayMoveScale",0.0,10000.0) : 0.0; const double initialDelayMoveMax = cfg.contains("delayMoveMax") ? cfg.getDouble("delayMoveMax",0.0,1000000.0) : 1000000.0; + int defaultBoardXSize = -1; int defaultBoardYSize = -1; Setup::loadDefaultBoardXYSize(cfg,logger,defaultBoardXSize,defaultBoardYSize); @@ -2225,7 +2225,7 @@ int MainCmds::gtp(const vector& args) { } else if(command == "name") { - response = "KataGo"; + response = "KataGoDots"; } else if(command == "version") { @@ -2967,7 +2967,7 @@ int MainCmds::gtp(const vector& args) { response += "could not parse vertex: '" + pieces[i+1] + "'"; break; } - initialStones.push_back(Move(loc,pla)); + initialStones.emplace_back(loc,pla); } if(!responseIsError) { maybeSaveAvoidPatterns(false); @@ -3134,7 +3134,7 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "Number of handicap stones less than 2: '" + pieces[0] + "'"; } - else if(!engine->bot->getRootBoard().isEmpty()) { + else if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } @@ -3158,7 +3158,7 @@ int MainCmds::gtp(const vector& args) { responseIsError = true; response = "Number of handicap stones less than 2: '" + pieces[0] + "'"; } - else if(!engine->bot->getRootBoard().isEmpty()) { + else if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } @@ -3169,7 +3169,7 @@ int MainCmds::gtp(const vector& args) { } else if(command == "set_free_handicap") { - if(!engine->bot->getRootBoard().isEmpty()) { + if(!engine->bot->getRootBoard().isStartPos()) { responseIsError = true; response = "Board is not empty"; } @@ -3179,14 +3179,15 @@ int MainCmds::gtp(const vector& args) { int xSize = rootBoard->x_size; int ySize = rootBoard->y_size; Board board(xSize,ySize,rootBoard->rules); - for(int i = 0; i& args) { response = "Expected one or two arguments for loadsgf but got '" + Global::concat(pieces," ") + "'"; } else { - string filename = pieces[0]; + const string& filename = pieces[0]; bool parseFailed = false; bool moveNumberSpecified = false; int moveNumber = 0; @@ -3320,7 +3321,7 @@ int MainCmds::gtp(const vector& args) { engine->getCurrentRules(), //Use current rules as default [&logger](const string& msg) { logger.write(msg); cerr << msg << endl; } ); - if(engine->nnEval != NULL) { + if(engine->nnEval != nullptr) { bool rulesWereSupported; Rules supportedRules = engine->nnEval->getSupportedRules(sgfRules,rulesWereSupported); if(!rulesWereSupported) { @@ -3580,59 +3581,65 @@ int MainCmds::gtp(const vector& args) { } if(parsed) { - engine->stopAndWait(); - - int boardSizeX = engine->bot->getRootBoard().x_size; - int boardSizeY = engine->bot->getRootBoard().y_size; - if(boardSizeX != boardSizeY) { + const auto& rootBoard = engine->bot->getRootBoard(); + if (rootBoard.rules.isDots) { responseIsError = true; - response = - "Current board size is " + Global::intToString(boardSizeX) + "x" + Global::intToString(boardSizeY) + - ", no built-in benchmarks for rectangular boards"; - } - else { - std::unique_ptr sgf = nullptr; - try { - string sgfData = TestCommon::getBenchmarkSGFData(boardSizeX); - sgf = CompactSgf::parse(sgfData); - } - catch(const StringError& e) { + response = "Not yet supported for Dots"; + } else { + engine->stopAndWait(); + int boardSizeX = rootBoard.x_size; + int boardSizeY = rootBoard.y_size; + if(boardSizeX != boardSizeY) { responseIsError = true; - response = e.what(); + response = + "Current board size is " + Global::intToString(boardSizeX) + "x" + Global::intToString(boardSizeY) + + ", no built-in benchmarks for rectangular boards"; } - if(sgf != nullptr) { - const PlayUtils::BenchmarkResults* baseline = NULL; - const double secondsPerGameMove = 1.0; - const bool printElo = false; - SearchParams params = engine->getGenmoveParams(); - params.maxTime = 1.0e20; - params.maxPlayouts = ((int64_t)1) << 50; - params.maxVisits = numVisits; - //Make sure the "equals" for GTP is printed out prior to the benchmark line - printGTPResponseHeader(); - + else { + std::unique_ptr sgf = nullptr; try { - PlayUtils::BenchmarkResults results = PlayUtils::benchmarkSearchOnPositionsAndPrint( - params, - *sgf, - 10, - engine->nnEval, - baseline, - secondsPerGameMove, - printElo - ); - (void)results; + string sgfData = TestCommon::getBenchmarkSGFData(boardSizeX); + sgf = CompactSgf::parse(sgfData); } catch(const StringError& e) { responseIsError = true; response = e.what(); - sgf = nullptr; } if(sgf != nullptr) { - //Act of benchmarking will write to stdout with a newline at the end, so we just need one more newline ourselves - //to complete GTP protocol. - suppressResponse = true; - cout << endl; + const PlayUtils::BenchmarkResults* baseline = nullptr; + const double secondsPerGameMove = 1.0; + const bool printElo = false; + SearchParams params = engine->getGenmoveParams(); + params.maxTime = 1.0e20; + params.maxPlayouts = ((int64_t)1) << 50; + params.maxVisits = numVisits; + //Make sure the "equals" for GTP is printed out prior to the benchmark line + printGTPResponseHeader(); + + try { + PlayUtils::BenchmarkResults results = PlayUtils::benchmarkSearchOnPositionsAndPrint( + params, + *sgf, + 10, + engine->nnEval, + baseline, + secondsPerGameMove, + printElo + ); + (void)results; + } + catch(const StringError& e) { + responseIsError = true; + response = e.what(); + + sgf = nullptr; + } + if(sgf != nullptr) { + //Act of benchmarking will write to stdout with a newline at the end, so we just need one more newline ourselves + //to complete GTP protocol. + suppressResponse = true; + cout << endl; + } } } } diff --git a/cpp/command/startposes.cpp b/cpp/command/startposes.cpp index 1af2904f0..4e739492e 100644 --- a/cpp/command/startposes.cpp +++ b/cpp/command/startposes.cpp @@ -2280,10 +2280,10 @@ int MainCmds::viewstartposes(const vector& args) { if(bot != NULL || !checkLegality) { cout << "StartPos: " << s << "/" << startPoses.size() << "\n"; - cout << "Next pla: " << PlayerIO::playerToString(pla) << "\n"; + cout << "Next pla: " << PlayerIO::playerToString(pla, board.isDots()) << "\n"; cout << "Weight: " << startPos.weight << "\n"; cout << "TrainingWeight: " << startPos.trainingWeight << "\n"; - cout << "StartPosInitialNextPla: " << PlayerIO::playerToString(startPos.nextPla) << "\n"; + cout << "StartPosInitialNextPla: " << PlayerIO::playerToString(startPos.nextPla, board.isDots()) << "\n"; cout << "StartPosMoves: "; for(int i = 0; i<(int)startPos.moves.size(); i++) cout << (startPos.moves[i].pla == P_WHITE ? "w" : "b") << Location::toString(startPos.moves[i].loc,board) << " "; diff --git a/cpp/core/global.cpp b/cpp/core/global.cpp index a04d2eda6..a1e74635b 100644 --- a/cpp/core/global.cpp +++ b/cpp/core/global.cpp @@ -16,6 +16,8 @@ #include #include +#include "test.h" + using namespace std; //ERRORS---------------------------------- @@ -117,6 +119,41 @@ string Global::uint64ToHexString(uint64_t x) return s; } +static const std::string CoordAlphabet = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; +static const int CoordAlphabetLength = CoordAlphabet.length(); + +std::string Global::intToCoord(const int x) { + assert(x >= 0); + + int v = x + 1; + std::string out; + while (v > 0) { + out.push_back(CoordAlphabet[(v - 1) % CoordAlphabetLength]); + v = (v - 1) / CoordAlphabetLength; + } + std::reverse(out.begin(), out.end()); + return out; +} + +bool Global::tryCoordToInt(const std::string& coord, int& x) { + if (coord.empty()) return false; + + long long v = 0; + for (const char c : coord) { + const auto pos = CoordAlphabet.find(std::toupper(c)); + if (pos == std::string::npos) return false; + v = v * CoordAlphabetLength + (static_cast(pos) + 1); + if (v > static_cast(std::numeric_limits::max()) + 1LL) { + return false; + } + } + + // convert from 1-based bijective value back to 0-based index + x = static_cast(v - 1); + assert(x >= 0); + return true; +} + string Global::sizeToString(size_t x) { stringstream ss; @@ -714,3 +751,36 @@ bool Global::isEqual(const float f1, const float f2) { bool Global::isZero(const float f) { return std::fabs(f) <= FLOAT_EPS; } + +void Global::runTests() { + testAssert(isEqual(0.0f, 0.0f)); + testAssert(isEqual(FLOAT_EPS, FLOAT_EPS)); + testAssert(isZero(FLOAT_EPS)); + testAssert(isZero(-FLOAT_EPS)); + testAssert(!isEqual(42, 42 + 0.0001f)); + testAssert(!isEqual(-FLOAT_EPS, FLOAT_EPS)); + + auto checkCoordConversion = [](const string& coord, const int value) { + testAssert(toUpper(coord) == intToCoord(value)); + int newValue; + testAssert(tryCoordToInt(coord, newValue)); + testAssert(value == newValue); + }; + + checkCoordConversion("A", 0); + checkCoordConversion("B", 1); + checkCoordConversion("Z", 24); + checkCoordConversion("AA", 25); + checkCoordConversion("AB", 26); + checkCoordConversion("AZ", 49); + checkCoordConversion("az", 49); + checkCoordConversion("BA", 50); + checkCoordConversion("ZZ", 649); + checkCoordConversion("AAA", 650); + + int newValue; + testAssert(!tryCoordToInt("", newValue)); + testAssert(!tryCoordToInt("!", newValue)); + testAssert(tryCoordToInt("ABCDEFG", newValue)); + testAssert(!tryCoordToInt("ABCDEFGH", newValue)); // Too big number -> overflow +} diff --git a/cpp/core/global.h b/cpp/core/global.h index ad302a6b4..bb65afc5f 100644 --- a/cpp/core/global.h +++ b/cpp/core/global.h @@ -76,6 +76,10 @@ namespace Global std::string uint32ToHexString(uint32_t x); std::string uint64ToHexString(uint64_t x); std::string sizeToString(size_t x); + // Convert numbers to letters that are used on real Go board + // It matches the English alphabet except missing I that is ignored to avoid confusion with J + std::string intToCoord(int x); + bool tryCoordToInt(const std::string& coord, int& x); //String to conversions using the standard library parsing int stringToInt(const std::string& str); @@ -163,6 +167,8 @@ namespace Global constexpr float FLOAT_EPS = std::numeric_limits::epsilon(); bool isEqual(float f1, float f2); bool isZero(float f); + + void runTests(); } // namespace Global struct StringError : public std::exception { diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 53de52477..ba1d174dc 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1506,7 +1506,10 @@ static std::unique_ptr maybeParseSgf(const string& str, size_t& pos) { && handicap <= 9 ) { Board board(rootSgf->getRulesOrFail()); - PlayUtils::placeFixedHandicap(board, handicap); + const auto& handicapLocs = PlayUtils::generateFixedHandicap(board, handicap); + for (const auto loc : handicapLocs) { + board.setStoneFailIfNoLibs(loc, P_BLACK); + } // Older fox sgfs used handicaps with side stones on the north and south rather than east and west if(handicap == 6 || handicap == 7) { if(rootSgf->hasRootProperty("DT")) { @@ -1523,15 +1526,10 @@ static std::unique_ptr maybeParseSgf(const string& str, size_t& pos) { } } - for(int y = 0; yaddRootProperty("AB",out.str()); - } - } + for (const auto loc : handicapLocs) { + ostringstream out; + writeSgfLoc(out, loc, board.x_size, board.y_size); + rootSgf->addRootProperty("AB",out.str()); } } return rootSgf; diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 399fba149..f0a43aaa8 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -120,7 +120,7 @@ Board::Base::Base(Player newPla, Board::Board() : Board(Rules::DEFAULT_GO) {} Board::Board(const Rules& rules) { - init(DEFAULT_LEN_X, DEFAULT_LEN_Y, rules); + init(rules.isDots ? DEFAULT_LEN_X_DOTS : DEFAULT_LEN_X, rules.isDots ? DEFAULT_LEN_Y_DOTS : DEFAULT_LEN_Y, rules); } Board::Board(const int x, const int y, const Rules& rules) { @@ -285,6 +285,11 @@ Color Board::getColor(const Loc loc) const { return static_cast(colors[loc] & ACTIVE_MASK); } +Color Board::getPlacedColor(const Loc loc) const { + const State state = getState(loc); + return rules.isDots ? getPlacedDotColor(state) : state; +} + double Board::sqrtBoardArea() const { if (x_size == y_size) { return x_size; @@ -686,42 +691,47 @@ bool Board::isNonPassAliveSelfConnection(Loc loc, Player pla, Color* passAliveAr return false; } -bool Board::isEmpty() const { - for(int y = 0; y < y_size; y++) { - for(int x = 0; x < x_size; x++) { - Loc loc = Location::getLoc(x,y,x_size); - if(colors[loc] != C_EMPTY) - return false; - } - } - return true; +bool Board::isStartPos() const { + int startBoardNumBlackStones, startBoardNumWhiteStones; + numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, false); + return startBoardNumBlackStones == 0 && startBoardNumWhiteStones == 0; } int Board::numStonesOnBoard() const { - int num = 0; - for(int y = 0; y < y_size; y++) { - for(int x = 0; x < x_size; x++) { - const Loc loc = Location::getLoc(x,y,x_size); - if(const Color color = rules.isDots ? getPlacedDotColor(getState(loc)) : colors[loc]; - color == C_BLACK || color == C_WHITE) { - num += 1; - } - } - } - return num; + int startBoardNumBlackStones, startBoardNumWhiteStones; + numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, true); + return startBoardNumBlackStones + startBoardNumWhiteStones; } int Board::numPlaStonesOnBoard(Player pla) const { - int num = 0; + int startBoardNumBlackStones, startBoardNumWhiteStones; + numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, true); + return pla == C_BLACK ? startBoardNumBlackStones : startBoardNumWhiteStones; +} + +void Board::numStartBlackWhiteStones(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, + const bool includeStartLocs) const { + startBoardNumBlackStones = 0; + startBoardNumWhiteStones = 0; + + // Ignore start pos moves that are generated according to rules + set startLocs; + for(auto move: start_pos_moves) { + startLocs.insert(move.loc); + } + for(int y = 0; y < y_size; y++) { for(int x = 0; x < x_size; x++) { - const Loc loc = Location::getLoc(x,y,x_size); - if(const Color color = rules.isDots ? getPlacedDotColor(getState(loc)) : colors[loc]; color == pla) { - num += 1; + if(const Loc loc = Location::getLoc(x, y, x_size)) { + if (includeStartLocs || startLocs.count(loc) == 0) { + if(const Color color = getPlacedColor(loc); color == C_BLACK) + startBoardNumBlackStones += 1; + else if(color == C_WHITE) + startBoardNumWhiteStones += 1; + } } } } - return num; } bool Board::setStone(Loc loc, Color color) @@ -750,7 +760,7 @@ bool Board::setStone(Loc loc, Color color) } bool Board::setStoneFailIfNoLibs(Loc loc, Color color, const bool startPos) { - Color colorAtLoc = getColor(loc); + const Color colorAtLoc = getColor(loc); if(loc < 0 || loc >= MAX_ARR_SIZE || colorAtLoc == C_WALL) return false; if(color != C_BLACK && color != C_WHITE && color != C_EMPTY) @@ -2543,13 +2553,28 @@ char PlayerIO::colorToChar(Color c) } } -string PlayerIO::playerToString(Color c) +char PlayerIO::stateToChar(const State s, const bool isDots) { + if (!isDots) return colorToChar(s); + + const Color activeColor = getActiveColor(s); + const Color placedColor = getPlacedDotColor(s); + const bool captured = activeColor != placedColor; + + switch (placedColor) { + case C_BLACK: return captured ? 'x' : 'X'; + case C_WHITE: return captured ? 'o' : 'O'; + case C_EMPTY: return captured ? '\'' : '.'; + default: return '#'; + } +} + +string PlayerIO::playerToString(const Color c, const bool isDots) { switch(c) { - case C_BLACK: return "Black"; - case C_WHITE: return "White"; + case C_BLACK: return !isDots ? "Black" : "Player1"; + case C_WHITE: return !isDots ? "White" : "Player2"; case C_EMPTY: return "Empty"; - default: return "Wall"; + default: return "Wall"; } } @@ -2559,17 +2584,15 @@ string PlayerIO::playerToStringShort(Color c) case C_BLACK: return "B"; case C_WHITE: return "W"; case C_EMPTY: return "E"; - default: return ""; + default: return "Wall"; } } bool PlayerIO::tryParsePlayer(const string& s, Player& pla) { - string str = Global::toLower(s); - if(str == "black" || str == "b") { + if(const string str = Global::toLower(s); str == "black" || str == "b" || str == "blue" || str == "p1") { pla = P_BLACK; return true; - } - else if(str == "white" || str == "w") { + } else if(str == "white" || str == "w" || str == "red" || str == "r" || str == "p2") { pla = P_WHITE; return true; } @@ -2578,119 +2601,102 @@ bool PlayerIO::tryParsePlayer(const string& s, Player& pla) { Player PlayerIO::parsePlayer(const string& s) { Player pla = C_EMPTY; - bool suc = tryParsePlayer(s,pla); - if(!suc) - throw StringError("Could not parse player: " + s); + if (!tryParsePlayer(s, pla)) throw StringError("Could not parse player: " + s); return pla; } -string Location::toStringMach(Loc loc, int x_size, bool isDots) -{ +string Location::toStringMach(const Loc loc, const int x_size, const bool isDots) { if(loc == Board::PASS_LOC) return isDots ? "ground" : "pass"; if(loc == Board::NULL_LOC) return string("null"); char buf[128]; - sprintf(buf,"(%d,%d)",getX(loc,x_size),getY(loc,x_size)); + sprintf(buf, "(%d,%d)", getX(loc, x_size), getY(loc, x_size)); return string(buf); } -string Location::toString(Loc loc, int x_size, int y_size, bool isDots) -{ - if(x_size > 25*25) - return toStringMach(loc, x_size, isDots); +string Location::toString(const Loc loc, int x_size, int y_size, bool isDots) { if(loc == Board::PASS_LOC) return isDots ? "ground" : "pass"; if(loc == Board::NULL_LOC) return string("null"); - const char* xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; - int x = getX(loc,x_size); - int y = getY(loc,x_size); + const int x = getX(loc, x_size); + const int y = getY(loc, x_size); if(x >= x_size || x < 0 || y < 0 || y >= y_size) return toStringMach(loc, x_size, isDots); - char buf[128]; - if(x <= 24) - sprintf(buf,"%c%d",xChar[x],y_size-y); - else - sprintf(buf,"%c%c%d",xChar[x/25-1],xChar[x%25],y_size-y); - return string(buf); + return (isDots ? (std::to_string(x + 1) + "-") : Global::intToCoord(x)) + std::to_string(y_size - y); } -string Location::toString(Loc loc, const Board& b) { +string Location::toString(const Loc loc, const Board& b) { return toString(loc, b.x_size, b.y_size, b.rules.isDots); } -string Location::toStringMach(Loc loc, const Board& b) { +string Location::toStringMach(const Loc loc, const Board& b) { return toStringMach(loc, b.x_size, b.isDots()); } -static bool tryParseLetterCoordinate(char c, int& x) { - if(c >= 'A' && c <= 'H') - x = c-'A'; - else if(c >= 'a' && c <= 'h') - x = c-'a'; - else if(c >= 'J' && c <= 'Z') - x = c-'A'-1; - else if(c >= 'j' && c <= 'z') - x = c-'a'-1; - else - return false; - return true; -} - -bool Location::tryOfString(const string& str, int x_size, int y_size, Loc& result) { +bool Location::tryOfString(const string& str, const int x_size, const int y_size, Loc& result) { string s = Global::trim(str); if(s.length() < 2) return false; - if(Global::isEqualCaseInsensitive(s,string("pass")) || Global::isEqualCaseInsensitive(s,string("pss")) || - Global::isEqualCaseInsensitive(s,string("ground"))) { + if( + Global::isEqualCaseInsensitive(s, string("pass")) || Global::isEqualCaseInsensitive(s, string("pss")) || + Global::isEqualCaseInsensitive(s, string("ground"))) { result = Board::PASS_LOC; return true; } if(s[0] == '(') { - if(s[s.length()-1] != ')') + if(s[s.length() - 1] != ')') return false; - s = s.substr(1,s.length()-2); - vector pieces = Global::split(s,','); - if(pieces.size() != 2) + + s = s.substr(1, s.length() - 2); + const int commaIndex = s.find(','); + if(commaIndex == std::string::npos) return false; + int x; + if(!Global::tryStringToInt(s.substr(0, commaIndex), x)) + return false; + int y; - bool sucX = Global::tryStringToInt(pieces[0],x); - bool sucY = Global::tryStringToInt(pieces[1],y); - if(!sucX || !sucY) + if(!Global::tryStringToInt(s.substr(commaIndex + 1), y)) return false; - result = Location::getLoc(x,y,x_size); + + result = getLoc(x, y, x_size); return true; } - else { + + if(const auto dashPos = s.find('-'); dashPos != std::string::npos) { int x; - if(!tryParseLetterCoordinate(s[0],x)) + if(!Global::tryStringToInt(s.substr(0, dashPos), x)) return false; - - //Extended format - if((s[1] >= 'A' && s[1] <= 'Z') || (s[1] >= 'a' && s[1] <= 'z')) { - int x1; - if(!tryParseLetterCoordinate(s[1],x1)) - return false; - x = (x+1) * 25 + x1; - s = s.substr(2,s.length()-2); - } - else { - s = s.substr(1,s.length()-1); - } - int y; - bool sucY = Global::tryStringToInt(s,y); - if(!sucY) + if(!Global::tryStringToInt(s.substr(dashPos + 1), y)) return false; - y = y_size - y; - if(x < 0 || y < 0 || x >= x_size || y >= y_size) - return false; - result = Location::getLoc(x,y,x_size); + + result = getLoc(x - 1, y_size - y, x_size); return true; } + + const auto firstDigitIndex = s.find_first_of("0123456789"); + if(firstDigitIndex == std::string::npos) + return false; + + int x; + if(!Global::tryCoordToInt(s.substr(0, firstDigitIndex), x)) + return false; + + int y; + if(!Global::tryStringToInt(s.substr(firstDigitIndex), y)) + return false; + y = y_size - y; + + if(x < 0 || y < 0 || x >= x_size || y >= y_size) + return false; + + result = getLoc(x, y, x_size); + return true; } bool Location::tryOfStringAllowNull(const string& str, int x_size, int y_size, Loc& result) { @@ -2744,67 +2750,66 @@ vector Location::parseSequence(const string& str, const Board& board) { return locs; } -void Board::printBoard(ostream& out, const Board& board, const Loc markLoc, const vector* hist, bool printHash) { +void Board::printBoard( + ostream& out, + const Board& board, + const Loc markLoc, + const vector* hist, + const bool printHash) { if(hist != nullptr) out << "MoveNum: " << hist->size() << " "; - if (printHash) { - out << "HASH: " << board.pos_hash << "\n"; + if(printHash) { + out << "HASH: " << board.pos_hash; + } + if(hist != nullptr || printHash) { + out << endl; } - bool showCoords = board.isDots() || (board.x_size <= 50 && board.y_size <= 50); + const bool showCoords = board.isDots() || (board.x_size <= 50 && board.y_size <= 50); if(showCoords) { - if (board.isDots()) { + if(board.isDots()) { out << " "; for(int x = 0; x < board.x_size; x++) { - out << std::left << std::setw(2) << (x + 1) << ' '; + out << std::left << std::setw(2) << std::to_string(x + 1) << ' '; } } else { - auto xChar = "ABCDEFGHJKLMNOPQRSTUVWXYZ"; out << " "; for(int x = 0; x < board.x_size; x++) { - if(x <= 24) { - out << " "; - out << xChar[x]; - } - else { - out << "A" << xChar[x-25]; - } + out << std::right << std::setw(2) << Global::intToCoord(x); } } out << "\n"; } - for(int y = 0; y < board.y_size; y++) - { + for(int y = 0; y < board.y_size; y++) { if(showCoords) { - out << std::right << std::setw(2) << board.y_size-y << ' '; + out << std::right << std::setw(2) << board.y_size - y << ' '; } - for(int x = 0; x < board.x_size; x++) - { - Loc loc = Location::getLoc(x, y , board.x_size); - const Color color = board.getColor(loc); // TODO: probably it makes sense to implement debug printing for Dots game - char s = PlayerIO::colorToChar(color); - if(color == C_EMPTY && markLoc == loc) + for(int x = 0; x < board.x_size; x++) { + const Loc loc = Location::getLoc(x, y, board.x_size); + const State state = board.getState(loc); + const char s = PlayerIO::stateToChar(state, board.rules.isDots); + if(getActiveColor(state) == C_EMPTY && markLoc == loc) out << '@'; else out << s; bool histMarked = false; - if(hist != NULL) { - size_t start = hist->size() >= 3 ? hist->size()-3 : 0; - for(size_t i = 0; start+i < hist->size(); i++) { - if((*hist)[start+i].loc == loc) { - out << (1+i); + if(hist != nullptr) { + size_t start = hist->size() >= 3 ? hist->size() - 3 : 0; + for(size_t i = 0; start + i < hist->size(); i++) { + if((*hist)[start + i].loc == loc) { + out << (1 + i); histMarked = true; break; } } } - if (!histMarked && board.isDots()) { + if(!histMarked && board.isDots()) { out << ' '; } - if(x < board.x_size-1 && (!histMarked || board.isDots())) + if(x < board.x_size - 1 && (!histMarked || board.isDots())) out << ' '; } out << "\n"; diff --git a/cpp/game/board.h b/cpp/game/board.h index 3d1bc7d33..df14297f6 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -53,10 +53,12 @@ bool isTerritory(State s); //Conversions for players and colors namespace PlayerIO { char colorToChar(Color c); + char stateToChar(State s, bool isDots); std::string playerToStringShort(Player p); - std::string playerToString(Player p); + std::string playerToString(Player p, bool isDots); bool tryParsePlayer(const std::string& s, Player& pla); Player parsePlayer(const std::string& s); + char stateToChar(State s, bool isDots); } namespace Location @@ -137,6 +139,8 @@ struct Board static constexpr int MAX_LEN = std::max(MAX_LEN_X, MAX_LEN_Y); //Maximum edge length allowed for the board static constexpr int DEFAULT_LEN_X = std::min(MAX_LEN_X,19); //Default x edge length for board if unspecified static constexpr int DEFAULT_LEN_Y = std::min(MAX_LEN_Y,19); //Default y edge length for board if unspecified + static constexpr int DEFAULT_LEN_X_DOTS = std::min(MAX_LEN_X, 39); + static constexpr int DEFAULT_LEN_Y_DOTS = std::min(MAX_LEN_Y, 32); static constexpr int MAX_PLAY_SIZE = MAX_LEN_X * MAX_LEN_Y; //Maximum number of playable spaces static constexpr int MAX_ARR_SIZE = getMaxArrSize(MAX_LEN_X, MAX_LEN_Y); //Maximum size of arrays needed @@ -239,6 +243,7 @@ struct Board //Functions------------------------------------ [[nodiscard]] Color getColor(Loc loc) const; + [[nodiscard]] Color getPlacedColor(Loc loc) const; [[nodiscard]] State getState(Loc loc) const; void setState(Loc loc, State state); bool isDots() const; @@ -288,11 +293,12 @@ struct Board bool isAdjacentToChain(Loc loc, Loc chain) const; //Does this connect two pla distinct groups that are not both pass-alive and not within opponent pass-alive area either? bool isNonPassAliveSelfConnection(Loc loc, Player pla, Color* passAliveArea) const; - //Is this board empty? - bool isEmpty() const; + // Is this board empty? + bool isStartPos() const; //Count the number of stones on the board int numStonesOnBoard() const; int numPlaStonesOnBoard(Player pla) const; + void numStartBlackWhiteStones(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, bool includeStartLocs) const; //Get a hash that combines the position of the board with simple ko prohibition and a player to move. Hash128 getSitHashWithSimpleKo(Player pla) const; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 31599ae19..b772cce84 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -387,40 +387,24 @@ void BoardHistory::setInitialTurnNumber(int64_t n) { void BoardHistory::setAssumeMultipleStartingBlackMovesAreHandicap(bool b) { assumeMultipleStartingBlackMovesAreHandicap = b; - whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); + whiteHandicapBonusScore = static_cast(computeWhiteHandicapBonus()); } void BoardHistory::setOverrideNumHandicapStones(int n) { overrideNumHandicapStones = n; - whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); + whiteHandicapBonusScore = static_cast(computeWhiteHandicapBonus()); } -static int numHandicapStonesOnBoardHelper(const Board& board, const int blackNonPassTurnsToStart) { - int startBoardNumBlackStones = 0; - int startBoardNumWhiteStones = 0; - - // Ignore start pos moves that are generated according to rules - set startLocs; - for (auto move : board.start_pos_moves) { - startLocs.insert(move.loc); - } - for(int y = 0; y 0) out << "Game phase: " << encorePhase << endl; out << "Rules: " << rules.toJsonString() << endl; if(whiteHandicapBonusScore != 0) out << "Handicap bonus score: " << whiteHandicapBonusScore << endl; - out << "B stones captured: " << board.numBlackCaptures << endl; - out << "W stones captured: " << board.numWhiteCaptures << endl; + + const auto firstPlayerName = PlayerIO::playerToString(P_BLACK, isDots); + const auto secondPlayerName = PlayerIO::playerToString(P_WHITE, isDots); + if (!isDots) { + out << firstPlayerName << " stones captured: " << board.numBlackCaptures << endl; + out << secondPlayerName << " stones captured: " << board.numWhiteCaptures << endl; + } else { + out << firstPlayerName << " score: " << board.numWhiteCaptures << endl; + out << secondPlayerName << " score: " << board.numBlackCaptures << endl; + } } void BoardHistory::printDebugInfo(ostream& out, const Board& board) const { out << board << endl; - bool isDots = board.rules.isDots; + const bool isDots = board.rules.isDots; if (!isDots) { - out << "Initial pla " << PlayerIO::playerToString(initialPla) << endl; + out << "Initial pla " << PlayerIO::playerToString(initialPla, rules.isDots) << endl; out << "Encore phase " << encorePhase << endl; out << "Turns this phase " << numTurnsThisPhase << endl; out << "Approx valid turns this phase " << numApproxValidTurnsThisPhase << endl; @@ -520,21 +514,21 @@ void BoardHistory::printDebugInfo(ostream& out, const Board& board) const { assert(0.0f == whiteHandicapBonusScore); assert(!hasButton); } - out << "Presumed next pla " << PlayerIO::playerToString(presumedNextMovePla) << endl; + out << "Presumed next pla " << PlayerIO::playerToString(presumedNextMovePla, rules.isDots) << endl; if (!isDots) { out << "Past normal phase end " << isPastNormalPhaseEnd << endl; } else { assert(0 == isPastNormalPhaseEnd); } - out << "Game result " << isGameFinished << " " << PlayerIO::playerToString(winner) << " " + out << "Game result " << isGameFinished << " " << PlayerIO::playerToString(winner, rules.isDots) << " " << finalWhiteMinusBlackScore << " " << isScored << " " << isNoResult << " " << isResignation; if (isPassAliveFinished) { out << " " << isPassAliveFinished; } out << endl; out << "Last moves "; - for(int i = 0; i 0 && moveHistory[0].pla == P_WHITE) + if(initialBoard.isStartPos() && moveHistory.size() > 0 && moveHistory[0].pla == P_WHITE) return true; //Black passed exactly once or white doublemoved int numBlackPasses = 0; diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index 67532b3c2..1138b891e 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -350,20 +350,39 @@ string Rules::toJsonStringNoKomiMaybeOmitStuff() const { return toJsonHelper(true,true).dump(); } -Rules Rules::updateRules(const string& k, const string& v, Rules oldRules) { +Rules Rules::updateRules(const string& k, const string& v, const Rules& oldRules) { Rules rules = oldRules; - string key = Global::trim(k); + const string key = Global::trim(k); const string value = Global::trim(Global::toUpper(v)); - if(key == DOTS_KEY) rules.isDots = Global::stringToBool(value); - else if(key == "ko") rules.koRule = Rules::parseKoRule(value); - else if(key == "score") rules.scoringRule = Rules::parseScoringRule(value); - else if(key == "scoring") rules.scoringRule = Rules::parseScoringRule(value); - else if(key == "tax") rules.taxRule = Rules::parseTaxRule(value); - else if(key == "suicide") rules.multiStoneSuicideLegal = Global::stringToBool(value); - else if(key == "hasButton") rules.hasButton = Global::stringToBool(value); - else if(key == "whiteHandicapBonus") rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(value); - else if(key == "friendlyPassOk") rules.friendlyPassOk = Global::stringToBool(value); - else throw IOError("Unknown rules option: " + key); + + bool parsed = true; + + if (key == DOTS_KEY) rules.isDots = Global::stringToBool(value); + else if (key == "suicide") rules.multiStoneSuicideLegal = Global::stringToBool(value); + else if (key == START_POS_KEY) rules.startPos = parseStartPos(value); + else if (key == START_POS_RANDOM_KEY) rules.startPosIsRandom = Global::stringToBool(value); + else parsed = false; + + if (!parsed) { + parsed = true; + if (!rules.isDots) { + if(key == "ko") rules.koRule = parseKoRule(value); + else if(key == "score" || key == "scoring") rules.scoringRule = parseScoringRule(value); + else if(key == "tax") rules.taxRule = parseTaxRule(value); + else if(key == "hasButton") rules.hasButton = Global::stringToBool(value); + else if(key == "whiteHandicapBonus") rules.whiteHandicapBonusRule = parseWhiteHandicapBonusRule(value); + else if(key == "friendlyPassOk") rules.friendlyPassOk = Global::stringToBool(value); + else parsed = false; + } else { + if (key == DOTS_CAPTURE_EMPTY_BASE_KEY) rules.dotsCaptureEmptyBases = Global::stringToBool(value); + else parsed = false; + } + } + + if (!parsed) { + throw IOError("Unknown rules option: " + key); + } + return rules; } @@ -478,7 +497,7 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) rules.friendlyPassOk = true; rules.komi = 7.5; } - else if(sOrig.length() > 0 && sOrig[0] == '{') { + else if(!sOrig.empty() && sOrig[0] == '{') { //Default if not specified rules = Rules::getDefaultOrTrompTaylorish(isDots); bool komiSpecified = false; @@ -487,8 +506,10 @@ static Rules parseRulesHelper(const string& sOrig, bool allowKomi, bool isDots) json input = json::parse(sOrig); string s; for(json::iterator iter = input.begin(); iter != input.end(); ++iter) { - string key = iter.key(); - if (key == START_POS_KEY) + const string& key = iter.key(); + if (key == DOTS_KEY) + rules.isDots = iter.value().get(); + else if (key == START_POS_KEY) rules.startPos = Rules::parseStartPos(iter.value().get()); else if (key == START_POS_RANDOM_KEY) rules.startPosIsRandom = iter.value().get(); diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 862604468..89b744dc5 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -111,7 +111,7 @@ struct Rules { static bool tryParseRules(const std::string& sOrig, Rules& buf, bool isDots); static bool tryParseRulesWithoutKomi(const std::string& sOrig, Rules& buf, float komi, bool isDots); - static Rules updateRules(const std::string& key, const std::string& value, Rules priorRules); + static Rules updateRules(const std::string& key, const std::string& value, const Rules& oldRules); static std::vector generateStartPos(int startPos, Rand* rand, int x_size, int y_size); /** diff --git a/cpp/program/play.cpp b/cpp/program/play.cpp index 8ee628d08..db21f1ab0 100644 --- a/cpp/program/play.cpp +++ b/cpp/program/play.cpp @@ -795,12 +795,12 @@ pair MatchPairer::getMatchupPairUnsynchronized() { //---------------------------------------------------------------------------------------------------------- -static void failIllegalMove(Search* bot, Logger& logger, Board board, Loc loc) { +static void failIllegalMove(const Search* bot, Logger& logger, const Board& board, Loc loc) { ostringstream sout; sout << "Bot returned null location or illegal move!?!" << "\n"; sout << board << "\n"; sout << bot->getRootBoard() << "\n"; - sout << "Pla: " << PlayerIO::playerToString(bot->getRootPla()) << "\n"; + sout << "Pla: " << PlayerIO::playerToString(bot->getRootPla(),board.isDots()) << "\n"; sout << "Loc: " << Location::toString(loc,bot->getRootBoard()) << "\n"; logger.write(sout.str()); bot->getRootBoard().checkConsistency(); diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index fada11223..f3267c4be 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -264,17 +264,17 @@ void PlayUtils::initializeGameUsingPolicy( //Place black handicap stones, free placement -//Does NOT switch the initial player of the board history to white -void PlayUtils::playExtraBlack( +// Does NOT switch the initial player of the board history to white +vector PlayUtils::playExtraBlack( Search* bot, int numExtraBlack, Board& board, BoardHistory& hist, double temperature, - Rand& gameRand -) { + Rand& gameRand) { Player pla = P_BLACK; + std::vector handicapLocs; if(!hist.isGameFinished) { NNResultBuf buf; for(int i = 0; isetPosition(pla,board,hist); + return handicapLocs; } -void PlayUtils::placeFixedHandicap(Board& board, int n) { - int xSize = board.x_size; - int ySize = board.y_size; +vector PlayUtils::generateFixedHandicap(const Board& board, const int n) { + const int xSize = board.x_size; + const int ySize = board.y_size; if(xSize < 7 || ySize < 7) throw StringError("Board is too small for fixed handicap"); if((xSize % 2 == 0 || ySize % 2 == 0) && n > 4) @@ -312,8 +314,6 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { if(n > 9) throw StringError("Fixed handicap > 9 is not allowed"); - board = Board(xSize,ySize,board.rules); - int xCoords[3]; //Corner, corner, side int yCoords[3]; //Corner, corner, side if(xSize <= 12) { xCoords[0] = 2; xCoords[1] = xSize-3; xCoords[2] = xSize/2; } @@ -321,8 +321,10 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { if(ySize <= 12) { yCoords[0] = 2; yCoords[1] = ySize-3; yCoords[2] = ySize/2; } else { yCoords[0] = 3; yCoords[1] = ySize-4; yCoords[2] = ySize/2; } - auto s = [&](int xi, int yi) { - board.setStone(Location::getLoc(xCoords[xi],yCoords[yi],board.x_size),P_BLACK); + vector locs; + + auto s = [&](const int xi, const int yi) { + locs.push_back(Location::getLoc(xCoords[xi],yCoords[yi],board.x_size)); }; if(n == 2) { s(0,1); s(1,0); } else if(n == 3) { s(0,1); s(1,0); s(0,0); } @@ -333,6 +335,8 @@ void PlayUtils::placeFixedHandicap(Board& board, int n) { else if(n == 8) { s(0,1); s(1,0); s(0,0); s(1,1); s(0,2); s(1,2); s(2,0); s(2,1); } else if(n == 9) { s(0,1); s(1,0); s(0,0); s(1,1); s(0,2); s(1,2); s(2,0); s(2,1); s(2,2); } else { ASSERT_UNREACHABLE; } + + return locs; } double PlayUtils::getHackedLCBForWinrate(const Search* search, const AnalysisData& data, Player pla) { diff --git a/cpp/program/playutils.h b/cpp/program/playutils.h index c4da1c558..48f853056 100644 --- a/cpp/program/playutils.h +++ b/cpp/program/playutils.h @@ -10,17 +10,10 @@ namespace PlayUtils { //Use the given bot to play free handicap stones, modifying the board and hist in the process and setting the bot's position to it. //Does NOT switch the initial player of the board history to white - void playExtraBlack( - Search* bot, - int numExtraBlack, - Board& board, - BoardHistory& hist, - double temperature, - Rand& gameRand - ); + std::vector playExtraBlack(Search* bot, int numExtraBlack, Board& board, BoardHistory& hist, double temperature, Rand& gameRand); - //Set board to empty and place fixed handicap stones, raising an exception if invalid - void placeFixedHandicap(Board& board, int n); + // Generate fixed handicap stones, raising an exception if invalid + std::vector generateFixedHandicap(const Board& board, int n); ExtraBlackAndKomi chooseExtraBlackAndKomi( float base, float stdev, double allowIntegerProb, diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 1438d9ea1..035a1a08e 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -885,72 +885,82 @@ Player Setup::parseReportAnalysisWinrates( throw StringError("Could not parse config value for reportAnalysisWinratesAs: " + sOrig); } -Rules Setup::loadSingleRules( - ConfigParser& cfg, - bool loadKomi -) { - Rules rules; +Rules Setup::loadSingleRules(ConfigParser& cfg, const bool loadKomi) { + const bool dotsGame = cfg.getBoolOrDefault(DOTS_KEY, false); + Rules rules = Rules::getDefault(dotsGame); if(cfg.contains("rules")) { - if(cfg.contains("koRule")) throw StringError("Cannot both specify 'rules' and individual rules like koRule"); - if(cfg.contains("scoringRule")) throw StringError("Cannot both specify 'rules' and individual rules like scoringRule"); + if(cfg.contains(START_POS_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + START_POS_KEY); + if(cfg.contains(START_POS_RANDOM_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + START_POS_RANDOM_KEY); if(cfg.contains("multiStoneSuicideLegal")) throw StringError("Cannot both specify 'rules' and individual rules like multiStoneSuicideLegal"); - if(cfg.contains("hasButton")) throw StringError("Cannot both specify 'rules' and individual rules like hasButton"); - if(cfg.contains("taxRule")) throw StringError("Cannot both specify 'rules' and individual rules like taxRule"); - if(cfg.contains("whiteHandicapBonus")) throw StringError("Cannot both specify 'rules' and individual rules like whiteHandicapBonus"); - if(cfg.contains("friendlyPassOk")) throw StringError("Cannot both specify 'rules' and individual rules like friendlyPassOk"); - if(cfg.contains("whiteBonusPerHandicapStone")) throw StringError("Cannot both specify 'rules' and individual rules like whiteBonusPerHandicapStone"); + + if (dotsGame) { + if (cfg.contains(DOTS_CAPTURE_EMPTY_BASE_KEY)) throw StringError("Cannot both specify 'rules' and individual rules like " + DOTS_CAPTURE_EMPTY_BASE_KEY); + } else { + if(cfg.contains("koRule")) throw StringError("Cannot both specify 'rules' and individual rules like koRule"); + if(cfg.contains("scoringRule")) throw StringError("Cannot both specify 'rules' and individual rules like scoringRule"); + if(cfg.contains("hasButton")) throw StringError("Cannot both specify 'rules' and individual rules like hasButton"); + if(cfg.contains("taxRule")) throw StringError("Cannot both specify 'rules' and individual rules like taxRule"); + if(cfg.contains("whiteHandicapBonus")) throw StringError("Cannot both specify 'rules' and individual rules like whiteHandicapBonus"); + if(cfg.contains("friendlyPassOk")) throw StringError("Cannot both specify 'rules' and individual rules like friendlyPassOk"); + if(cfg.contains("whiteBonusPerHandicapStone")) throw StringError("Cannot both specify 'rules' and individual rules like whiteBonusPerHandicapStone"); + } rules = Rules::parseRules(cfg.getString("rules"), cfg.getBoolOrDefault(DOTS_KEY, false)); } else { - string koRule = cfg.getString("koRule", Rules::koRuleStrings()); - string scoringRule = cfg.getString("scoringRule", Rules::scoringRuleStrings()); - bool multiStoneSuicideLegal = cfg.getBool("multiStoneSuicideLegal"); - bool hasButton = cfg.contains("hasButton") ? cfg.getBool("hasButton") : false; - float komi = 7.5f; - - rules.koRule = Rules::parseKoRule(koRule); - rules.scoringRule = Rules::parseScoringRule(scoringRule); - rules.multiStoneSuicideLegal = multiStoneSuicideLegal; - rules.hasButton = hasButton; - rules.komi = komi; - - if(cfg.contains("taxRule")) { - string taxRule = cfg.getString("taxRule", Rules::taxRuleStrings()); - rules.taxRule = Rules::parseTaxRule(taxRule); + if (cfg.contains(START_POS_KEY)) { + rules.startPos = Rules::parseStartPos(cfg.getString(START_POS_KEY)); } - else { - rules.taxRule = (rules.scoringRule == Rules::SCORING_TERRITORY ? Rules::TAX_SEKI : Rules::TAX_NONE); + if (cfg.contains(START_POS_RANDOM_KEY)) { + rules.startPosIsRandom = cfg.getBool(START_POS_RANDOM_KEY); } + rules.multiStoneSuicideLegal = cfg.getBoolOrDefault("multiStoneSuicideLegal", rules.multiStoneSuicideLegal); + + if (dotsGame) { + rules.dotsCaptureEmptyBases = cfg.getBoolOrDefault(DOTS_CAPTURE_EMPTY_BASE_KEY, rules.dotsCaptureEmptyBases); + } else { + rules.koRule = Rules::parseKoRule(cfg.getString("koRule", Rules::koRuleStrings())); + rules.scoringRule = Rules::parseScoringRule(cfg.getString("scoringRule", Rules::scoringRuleStrings())); + rules.hasButton = cfg.getBoolOrDefault("hasButton", false); + rules.komi = 7.5f; + + if(cfg.contains("taxRule")) { + string taxRule = cfg.getString("taxRule", Rules::taxRuleStrings()); + rules.taxRule = Rules::parseTaxRule(taxRule); + } + else { + rules.taxRule = (rules.scoringRule == Rules::SCORING_TERRITORY ? Rules::TAX_SEKI : Rules::TAX_NONE); + } - if(rules.hasButton && rules.scoringRule != Rules::SCORING_AREA) - throw StringError("Config specifies hasButton=true on a scoring system other than AREA"); - - //Also handles parsing of legacy option whiteBonusPerHandicapStone - if(cfg.contains("whiteBonusPerHandicapStone") && cfg.contains("whiteHandicapBonus")) - throw StringError("May specify only one of whiteBonusPerHandicapStone and whiteHandicapBonus in config"); - else if(cfg.contains("whiteHandicapBonus")) - rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(cfg.getString("whiteHandicapBonus", Rules::whiteHandicapBonusRuleStrings())); - else if(cfg.contains("whiteBonusPerHandicapStone")) { - int whiteBonusPerHandicapStone = cfg.getInt("whiteBonusPerHandicapStone",0,1); - if(whiteBonusPerHandicapStone == 0) - rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + if(rules.hasButton && rules.scoringRule != Rules::SCORING_AREA) + throw StringError("Config specifies hasButton=true on a scoring system other than AREA"); + + //Also handles parsing of legacy option whiteBonusPerHandicapStone + if(cfg.contains("whiteBonusPerHandicapStone") && cfg.contains("whiteHandicapBonus")) + throw StringError("May specify only one of whiteBonusPerHandicapStone and whiteHandicapBonus in config"); + else if(cfg.contains("whiteHandicapBonus")) + rules.whiteHandicapBonusRule = Rules::parseWhiteHandicapBonusRule(cfg.getString("whiteHandicapBonus", Rules::whiteHandicapBonusRuleStrings())); + else if(cfg.contains("whiteBonusPerHandicapStone")) { + int whiteBonusPerHandicapStone = cfg.getInt("whiteBonusPerHandicapStone",0,1); + if(whiteBonusPerHandicapStone == 0) + rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + else + rules.whiteHandicapBonusRule = Rules::WHB_N; + } else - rules.whiteHandicapBonusRule = Rules::WHB_N; - } - else - rules.whiteHandicapBonusRule = Rules::WHB_ZERO; + rules.whiteHandicapBonusRule = Rules::WHB_ZERO; - if(cfg.contains("friendlyPassOk")) { - rules.friendlyPassOk = cfg.getBool("friendlyPassOk"); - } + if(cfg.contains("friendlyPassOk")) { + rules.friendlyPassOk = cfg.getBool("friendlyPassOk"); + } - //Drop default komi to 6.5 for territory rules, and to 7.0 for button - if(rules.scoringRule == Rules::SCORING_TERRITORY) - rules.komi = 6.5f; - else if(rules.hasButton) - rules.komi = 7.0f; + //Drop default komi to 6.5 for territory rules, and to 7.0 for button + if(rules.scoringRule == Rules::SCORING_TERRITORY) + rules.komi = 6.5f; + else if(rules.hasButton) + rules.komi = 7.0f; + } } if(loadKomi) { diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 468a018fc..1c1872296 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -1335,7 +1335,7 @@ void Search::printTreeHelper( return; } if((options.alsoBranch_ && depth == 0) || (!options.alsoBranch_ && depth == options.branch_.size())) { - out << "---" << PlayerIO::playerToString(node.nextPla) << "(" << (node.nextPla == perspectiveToUse ? "^" : "v") << ")---" << endl; + out << "---" << PlayerIO::playerToString(node.nextPla,rootBoard.isDots()) << "(" << (node.nextPla == perspectiveToUse ? "^" : "v") << ")---" << endl; } vector analysisData; diff --git a/cpp/tests/testrules.cpp b/cpp/tests/testrules.cpp index cfff353ce..b141325ef 100644 --- a/cpp/tests/testrules.cpp +++ b/cpp/tests/testrules.cpp @@ -109,7 +109,7 @@ void Tests::runRulesTests() { if(!hist.isGameFinished) o << "Game is not over" << endl; else { - o << "Winner: " << PlayerIO::playerToString(hist.winner) << endl; + o << "Winner: " << PlayerIO::playerToString(hist.winner,hist.rules.isDots) << endl; o << "W-B Score: " << hist.finalWhiteMinusBlackScore << endl; o << "isNoResult: " << hist.isNoResult << endl; o << "isResignation: " << hist.isResignation << endl; diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index de80182f3..fe6882ddd 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -57,12 +57,12 @@ void Tests::runSgfTests() { } out << "Initial board hist " << endl; - out << "pla " << PlayerIO::playerToString(pla) << endl; + out << "pla " << PlayerIO::playerToString(pla,rules.isDots) << endl; hist.printDebugInfo(out,board); auto [finalHist, finalBoard] = sgf->setupBoardAndHistAssumeLegal(rules, pla, sgf->moves.size()); out << "Final board hist " << endl; - out << "pla " << PlayerIO::playerToString(pla) << endl; + out << "pla " << PlayerIO::playerToString(pla,rules.isDots) << endl; finalHist.printDebugInfo(out,finalBoard); { @@ -99,7 +99,7 @@ void Tests::runSgfTests() { } out << "handicapValue " << sgf->getHandicapValue() << endl; - out << "sgfWinner " << PlayerIO::playerToString(sgf->getSgfWinner()) << endl; + out << "sgfWinner " << PlayerIO::playerToString(sgf->getSgfWinner(), sgf->isDotsGame()) << endl; out << "firstPlayerColor " << PlayerIO::colorToChar(sgf->getFirstPlayerColor()) << endl; out << "black rank " << sgf->getRank(P_BLACK) << endl; @@ -133,7 +133,7 @@ void Tests::runSgfTests() { sgf->getPlacements(placements, xySize.x, xySize.y); out << "placements " << placements.size() << endl; for(const Move& move: placements) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; + out << PlayerIO::playerToString(move.pla, sgf->isDotsGame()) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -141,7 +141,7 @@ void Tests::runSgfTests() { sgf->getMoves(moves, xySize.x, xySize.y); out << "moves " << moves.size() << endl; for(const Move& move: moves) { - out << PlayerIO::playerToString(move.pla) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; + out << PlayerIO::playerToString(move.pla, sgf->isDotsGame()) << " " << Location::toString(move.loc, xySize.x, xySize.y, sgf->isDotsGame()) << " "; } out << endl; @@ -178,26 +178,26 @@ depth 17 komi 0 startPos CROSS placements -X E3 +X 5-3 moves -O D4 -X D3 -O H5 -X C4 -O H3 -X C5 -O F3 -X D6 -O C3 -X H7 -O J6 -X D7 -O G2 -X D8 -O B2 +O 4-4 +X 4-3 +O 8-5 +X 3-4 +O 8-3 +X 3-5 +O 6-3 +X 4-6 +O 3-3 +X 8-7 +O 9-6 +X 4-7 +O 7-2 +X 4-8 +O 2-2 X ground Initial board hist -pla White +pla Player2 HASH: 42AC4303D65557034CC3593CB26EA615 1 2 3 4 5 6 7 8 9 10 8 . . . . . . . . . . @@ -212,18 +212,18 @@ HASH: 42AC4303D65557034CC3593CB26EA615 Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 White bonus score 0 -Presumed next pla White +Presumed next pla Player2 Game result 0 Empty 0 0 0 0 Last moves Final board hist -pla White +pla Player2 HASH: AB87C4395AA2D7E5D7B069ACBFA701D5 1 2 3 4 5 6 7 8 9 10 8 . . . X . . . . . . - 7 . . . X . . . O . . + 7 . . . X . . . x . . 6 . . . X . . . . O . - 5 . . X X X O . O . . - 4 . . X X X X . . . . + 5 . . X ' X O . O . . + 4 . . X o o X . . . . 3 . . O X X O . O . . 2 . O . . . . O . . . 1 . . . . . . . . . . @@ -231,9 +231,9 @@ HASH: AB87C4395AA2D7E5D7B069ACBFA701D5 Rules dotsCaptureEmptyBase0startPosCROSSsui1komi0 White bonus score 0 -Presumed next pla White -Game result 1 Black -1 1 0 0 -Last moves D4 D3 H5 C4 H3 C5 F3 D6 C3 H7 J6 D7 G2 D8 B2 ground +Presumed next pla Player2 +Game result 1 Player1 -1 1 0 0 +Last moves 4-4 4-3 8-5 3-4 8-3 3-5 6-3 4-6 3-3 8-7 9-6 4-7 7-2 4-8 2-2 ground )"; expect(name,out,expected); } diff --git a/cpp/tests/testtrainingwrite.cpp b/cpp/tests/testtrainingwrite.cpp index 342d8bf8a..fe38b91fa 100644 --- a/cpp/tests/testtrainingwrite.cpp +++ b/cpp/tests/testtrainingwrite.cpp @@ -938,9 +938,9 @@ xxxxxxxx. FinishedGameData* data = gameRunner->runGame(seed, botSpec, botSpec, forkData, NULL, logger, shouldStop, shouldPause, nullptr, nullptr, nullptr); cout << data->startHist.rules << endl; cout << "Start moves size " << data->startHist.moveHistory.size() - << " Start pla " << PlayerIO::playerToString(data->startPla) + << " Start pla " << PlayerIO::playerToString(data->startPla,data->startHist.rules.isDots) << " XY " << data->startBoard.x_size << " " << data->startBoard.y_size - << " Extra black " << data->numExtraBlack + << " Extra " << PlayerIO::playerToString(P_BLACK, data->startHist.rules.isDots) << " " << data->numExtraBlack << " Draw equiv " << data->drawEquivalentWinsForWhite << " Mode " << data->mode << " BeganInEncorePhase " << data->beganInEncorePhase @@ -2846,7 +2846,7 @@ void Tests::runSekiTrainWriteTests(const string& modelFile) { vector buf; vector isAlive = PlayUtils::computeAnticipatedStatusesWithOwnership(bot,board,hist,pla,numVisits,buf); testAssert(bot->alwaysIncludeOwnerMap == false); - cout << "Search assumes " << PlayerIO::playerToString(pla) << " first" << endl; + cout << "Search assumes " << PlayerIO::playerToString(pla,hist.rules.isDots) << " first" << endl; cout << "Rules " << hist.rules << endl; cout << board << endl; for(int y = 0; y Date: Sun, 5 Oct 2025 21:17:27 +0200 Subject: [PATCH 28/42] [Python] Refine Dots NN inputs and scoring However, it should be restored later when NN is unified with Go models (22 for spatial and 19 for global) --- .../katago/train/data_processing_pytorch.py | 153 ++++++++++++------ python/katago/train/model_pytorch.py | 14 +- python/katago/train/modelconfigs.py | 51 +++++- 3 files changed, 166 insertions(+), 52 deletions(-) diff --git a/python/katago/train/data_processing_pytorch.py b/python/katago/train/data_processing_pytorch.py index 78a50daeb..754a0ef4c 100644 --- a/python/katago/train/data_processing_pytorch.py +++ b/python/katago/train/data_processing_pytorch.py @@ -1,5 +1,6 @@ import logging import os +from enum import Enum, auto import numpy as np from concurrent.futures import ThreadPoolExecutor @@ -156,71 +157,133 @@ def apply_symmetry(tensor, symm): if symm == 7: return tensor.flip(-2) +class GoSpatialFeature(Enum): + ON_BOARD = 0 + PLA_STONE = 1 + OPP_STONE = 2 + LIBERTIES_1 = 3 + LIBERTIES_2 = 4 + LIBERTIES_3 = 5 + SUPER_KO_BANNED = 6 + KO_RECAP_BLOCKED = 7 + KO_EXTRA = 8 + PREV_1_LOC = 9 + PREV_2_LOC = 10 + PREV_3_LOC = 11 + PREV_4_LOC = 12 + PREV_5_LOC = 13 + LADDER_CAPTURED = 14 + LADDER_CAPTURED_PREVIOUS_1 = 15 + LADDER_CAPTURED_PREVIOUS_2 = 16 + LADDER_WORKING_MOVES = 17 + AREA_PLA = 18 + AREA_OPP = 19 + SECOND_ENCORE_PLA = 20 + SECOND_ENCORE_OPP = 21 + +class GoGlobalFeature(Enum): + PREV_1_LOC_PASS = 0 + PREV_2_LOC_PASS = 1 + PREV_3_LOC_PASS = 2 + PREV_4_LOC_PASS = 3 + PREV_5_LOC_PASS = 4 + KOMI = 5 + KO_RULE_NOT_SIMPLE = 6 + KO_RULE_EXTRA = 7 + SUICIDE = 8 + SCORING_TERRITORY = 9 + TAX_SEKI = 10 + TAX_ALL = 11 + ENCORE_PHASE_1 = 12 + ENCORE_PHASE_2 = 13 + PASS_WOULD_END_PHASE = 14 + PLAYOUT_DOUBLING_ADVANTAGE_FLAG = 15 + PLAYOUT_DOUBLING_ADVANTAGE_VALUE = 16 + HAS_BUTTON = 17 + BOARD_SIZE_KOMI_PARITY = 18 + +class DotsSpatialFeature(Enum): + ON_BOARD = 0 + PLA_ACTIVE = auto() + OPP_ACTIVE = auto() + PLA_PLACED = auto() + OPP_PLACED = auto() + DEAD = auto() + GROUNDED = auto() + PLA_CAPTURES = auto() + OPP_CAPTURES = auto() + PLA_SURROUNDINGS = auto() + OPP_SURROUNDINGS = auto() + PREV_1_LOC = auto() + PREV_2_LOC = auto() + PREV_3_LOC = auto() + PREV_4_LOC = auto() + PREV_5_LOC = auto() + LADDER_CAPTURED = auto() + LADDER_CAPTURED_PREVIOUS_1 = auto() + LADDER_CAPTURED_PREVIOUS_2 = auto() + LADDER_WORKING_MOVES = auto() def build_history_matrices(model_config: modelconfigs.ModelConfig, device): num_bin_features = modelconfigs.get_num_bin_input_features(model_config) - assert num_bin_features == 22, "Currently this code is hardcoded for this many features" - - h_base = torch.diag( - torch.tensor( - [ - 1.0, # 0 - 1.0, # 1 - 1.0, # 2 - 1.0, # 3 - 1.0, # 4 - 1.0, # 5 - 1.0, # 6 - 1.0, # 7 - 1.0, # 8 - 0.0, # 9 Location of move 1 turn ago - 0.0, # 10 Location of move 2 turns ago - 0.0, # 11 Location of move 3 turns ago - 0.0, # 12 Location of move 4 turns ago - 0.0, # 13 Location of move 5 turns ago - 1.0, # 14 Ladder-threatened stone - 0.0, # 15 Ladder-threatened stone, 1 turn ago - 0.0, # 16 Ladder-threatened stone, 2 turns ago - 1.0, # 17 - 1.0, # 18 - 1.0, # 19 - 1.0, # 20 - 1.0, # 21 - ], - device=device, - requires_grad=False, - ) - ) + + is_go_game = not modelconfigs.is_dots_game(model_config) + + prev_1_loc = GoSpatialFeature.PREV_1_LOC.value if is_go_game else DotsSpatialFeature.PREV_1_LOC.value + prev_2_loc = GoSpatialFeature.PREV_2_LOC.value if is_go_game else DotsSpatialFeature.PREV_2_LOC.value + prev_3_loc = GoSpatialFeature.PREV_3_LOC.value if is_go_game else DotsSpatialFeature.PREV_3_LOC.value + prev_4_loc = GoSpatialFeature.PREV_4_LOC.value if is_go_game else DotsSpatialFeature.PREV_4_LOC.value + prev_5_loc = GoSpatialFeature.PREV_5_LOC.value if is_go_game else DotsSpatialFeature.PREV_5_LOC.value + + ladder_captured = GoSpatialFeature.LADDER_CAPTURED.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED.value + ladder_captured_previous_1 = GoSpatialFeature.LADDER_CAPTURED_PREVIOUS_1.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED_PREVIOUS_1.value + ladder_captured_previous_2 = GoSpatialFeature.LADDER_CAPTURED_PREVIOUS_2.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED_PREVIOUS_2.value + ladder_working_moves = GoSpatialFeature.LADDER_WORKING_MOVES.value if is_go_game else DotsSpatialFeature.LADDER_WORKING_MOVES.value + + data = [1.0 for _ in range(num_bin_features)] + + data[prev_1_loc] = 0.0 + data[prev_2_loc] = 0.0 + data[prev_3_loc] = 0.0 + data[prev_4_loc] = 0.0 + data[prev_5_loc] = 0.0 + + data[ladder_captured] = 1.0 + data[ladder_captured_previous_1] = 0.0 + data[ladder_captured_previous_2] = 0.0 + + h_base = torch.diag(torch.tensor(data, device=device, requires_grad=False)) + # Because we have ladder features that express past states rather than past diffs, # the most natural encoding when we have no history is that they were always the # same, rather than that they were all zero. So rather than zeroing them we have no # history, we add entries in the matrix to copy them over. # By default, without history, the ladder features 15 and 16 just copy over from 14. - h_base[14, 15] = 1.0 - h_base[14, 16] = 1.0 + h_base[ladder_captured, ladder_captured_previous_1] = 1.0 + h_base[ladder_captured, ladder_captured_previous_2] = 1.0 h0 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) # When have the prev move, we enable feature 9 and 15 - h0[9, 9] = 1.0 # Enable 9 -> 9 - h0[14, 15] = -1.0 # Stop copying 14 -> 15 - h0[14, 16] = -1.0 # Stop copying 14 -> 16 - h0[15, 15] = 1.0 # Enable 15 -> 15 - h0[15, 16] = 1.0 # Start copying 15 -> 16 + h0[prev_1_loc, prev_1_loc] = 1.0 # Enable 9 -> 9 + h0[ladder_captured, ladder_captured_previous_1] = -1.0 # Stop copying 14 -> 15 + h0[ladder_captured, ladder_captured_previous_2] = -1.0 # Stop copying 14 -> 16 + h0[ladder_captured_previous_1, ladder_captured_previous_1] = 1.0 # Enable 15 -> 15 + h0[ladder_captured_previous_1, ladder_captured_previous_2] = 1.0 # Start copying 15 -> 16 h1 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) # When have the prevprev move, we enable feature 10 and 16 - h1[10, 10] = 1.0 # Enable 10 -> 10 - h1[15, 16] = -1.0 # Stop copying 15 -> 16 - h1[16, 16] = 1.0 # Enable 16 -> 16 + h1[prev_2_loc, prev_2_loc] = 1.0 # Enable 10 -> 10 + h1[ladder_captured_previous_1, ladder_captured_previous_2] = -1.0 # Stop copying 15 -> 16 + h1[ladder_captured_previous_2, ladder_captured_previous_2] = 1.0 # Enable 16 -> 16 h2 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h2[11, 11] = 1.0 + h2[prev_3_loc, prev_3_loc] = 1.0 h3 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h3[12, 12] = 1.0 + h3[prev_4_loc, prev_4_loc] = 1.0 h4 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h4[13, 13] = 1.0 + h4[prev_5_loc, prev_5_loc] = 1.0 # (1, n_bin, n_bin) h_base = h_base.reshape((1, num_bin_features, num_bin_features)) diff --git a/python/katago/train/model_pytorch.py b/python/katago/train/model_pytorch.py index 4654ac934..cf1209614 100644 --- a/python/katago/train/model_pytorch.py +++ b/python/katago/train/model_pytorch.py @@ -8,6 +8,7 @@ import packaging.version from typing import List, Dict, Optional, Set +from .modelconfigs import get_num_bin_input_features, get_num_global_input_features from ..train import modelconfigs EXTRA_SCORE_DISTR_RADIUS = 60 @@ -1652,19 +1653,22 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: self.activation = "relu" if "activation" not in config else config["activation"] + spatial_features = get_num_bin_input_features(config) + global_features = get_num_global_input_features(config) + if config["initial_conv_1x1"]: - self.conv_spatial = torch.nn.Conv2d(22, self.c_trunk, kernel_size=1, padding="same", bias=False) + self.conv_spatial = torch.nn.Conv2d(spatial_features, self.c_trunk, kernel_size=1, padding="same", bias=False) else: - self.conv_spatial = torch.nn.Conv2d(22, self.c_trunk, kernel_size=3, padding="same", bias=False) - self.linear_global = torch.nn.Linear(19, self.c_trunk, bias=False) + self.conv_spatial = torch.nn.Conv2d(spatial_features, self.c_trunk, kernel_size=3, padding="same", bias=False) + self.linear_global = torch.nn.Linear(global_features, self.c_trunk, bias=False) if "metadata_encoder" in config and config["metadata_encoder"] is not None: self.metadata_encoder = MetadataEncoder(config) else: self.metadata_encoder = None - self.bin_input_shape = [22, pos_len_x, pos_len_y] - self.global_input_shape = [19] + self.bin_input_shape = [spatial_features, pos_len_x, pos_len_y] + self.global_input_shape = [global_features] self.blocks = torch.nn.ModuleList() for block_config in self.block_kind: diff --git a/python/katago/train/modelconfigs.py b/python/katago/train/modelconfigs.py index 976be7178..f5b105eaa 100644 --- a/python/katago/train/modelconfigs.py +++ b/python/katago/train/modelconfigs.py @@ -41,20 +41,35 @@ # version = 15 # V7 features, Extra nonlinearity for pass output # version = 16 # V7 features, Q value predictions in the policy head +# version = 17 # V8 features, Extra nonlinearity for pass output, Dots game + def get_version(config: ModelConfig): return config["version"] +def is_dots_game(config: ModelConfig): + version = get_version(config) + if 10 <= version <= 16: + return False + elif version == 17: + return True + else: + assert(False) + def get_num_bin_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: return 22 + elif version == 17: # Dots game + return 20 else: assert(False) def get_num_global_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: return 19 + elif version == 17: # Dots game + return 9 else: assert(False) @@ -199,6 +214,37 @@ def get_num_meta_encoder_input_features(config_or_meta_encoder_version: Union[Mo "v2_size":80, } +b10c128_dots = { + "version":17, + "norm_kind":"fixup", + "bnorm_epsilon": 1e-4, + "bnorm_running_avg_momentum": 0.001, + "initial_conv_1x1": False, + "trunk_num_channels":128, + "mid_num_channels":128, + "gpool_num_channels":32, + "use_attention_pool":False, + "num_attention_pool_heads":4, + "block_kind": [ + ["rconv1","regular"], + ["rconv2","regular"], + ["rconv3","regular"], + ["rconv4","regular"], + ["rconv5","regulargpool"], + ["rconv6","regular"], + ["rconv7","regular"], + ["rconv8","regulargpool"], + ["rconv9","regular"], + ["rconv10","regular"], + ], + "p1_num_channels":32, + "g1_num_channels":32, + "v1_num_channels":32, + "sbv2_num_channels":64, + "num_scorebeliefs":6, + "v2_size":80, +} + b5c192nbt = { "version":15, "norm_kind":"fixup", @@ -1452,6 +1498,7 @@ def get_num_meta_encoder_input_features(config_or_meta_encoder_version: Union[Mo # Small model configs, not too different in inference cost from b10c128 "b10c128": b10c128, + "b10c128_dots": b10c128_dots, "b5c192nbt": b5c192nbt, # Medium model configs, not too different in inference cost from b15c192 From ca4074ae0f045d4f8b651b83a9fa283b444945f9 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 5 Oct 2025 21:19:12 +0200 Subject: [PATCH 29/42] Move stress tests down for convenience (because they are quite slow) Run Global utils tests --- cpp/command/runtests.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cpp/command/runtests.cpp b/cpp/command/runtests.cpp index a4c17eb2c..62e2ef83b 100644 --- a/cpp/command/runtests.cpp +++ b/cpp/command/runtests.cpp @@ -30,26 +30,25 @@ int MainCmds::runtests(const vector& args) { Board::initHash(); ScoreValue::initTables(); + Global::runTests(); + BSearch::runTests(); + Rand::runTests(); + DateTime::runTests(); + FancyMath::runTests(); + ComputeElos::runTests(); + Base64::runTests(); + ThreadTest::runTests(); + Tests::runDotsFieldTests(); Tests::runDotsGroundingTests(); Tests::runDotsBoardHistoryGroundingTests(); Tests::runDotsPosHashTests(); Tests::runDotsStartPosTests(); - Tests::runDotsStressTests(); - Tests::runDotsSymmetryTests(); Tests::runDotsOwnershipTests(); Tests::runDotsCapturingTests(); - BSearch::runTests(); - Rand::runTests(); - DateTime::runTests(); - FancyMath::runTests(); - ComputeElos::runTests(); - Base64::runTests(); - ThreadTest::runTests(); - Tests::runBoardIOTests(); Tests::runBoardBasicTests(); @@ -59,7 +58,6 @@ int MainCmds::runtests(const vector& args) { Tests::runBoardUndoTest(); Tests::runBoardHandicapTest(); - Tests::runBoardStressTest(); Tests::runSgfTests(); Tests::runBasicSymmetryTests(); @@ -67,6 +65,9 @@ int MainCmds::runtests(const vector& args) { Tests::runSymmetryDifferenceTests(); Tests::runBoardReplayTest(); + Tests::runDotsStressTests(); + Tests::runBoardStressTest(); + ScoreValue::freeTables(); Tests::runInlineConfigTests(); From 44457e47d6c18e2b22485a8daed85db0a0e471c8 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 11 Oct 2025 20:56:54 +0200 Subject: [PATCH 30/42] [GTP] Support multiple moves for `play` command Fix `set_position` to recognize start position Fix start pos recognizer Remove `START_POS_CUSTOM` because now all unrecognized moves are accessible via passed `remainingMoves` --- cpp/command/gtp.cpp | 163 ++++++++++++++++++------------- cpp/dataio/sgf.cpp | 6 +- cpp/game/board.cpp | 10 +- cpp/game/common.h | 4 + cpp/game/rules.cpp | 140 ++++++++++++++------------ cpp/game/rules.h | 11 ++- cpp/neuralnet/nninputs.cpp | 3 +- cpp/tests/testdotsextra.cpp | 10 +- cpp/tests/testdotsstartposes.cpp | 143 +++++++++++++++++++++++++-- cpp/tests/testdotsutils.cpp | 14 ++- cpp/tests/testdotsutils.h | 4 + cpp/tests/testsgf.cpp | 10 +- 12 files changed, 354 insertions(+), 164 deletions(-) diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index e52443269..8a296ac3a 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -1,21 +1,22 @@ -#include "../core/global.h" +#include + +#include "../command/commandline.h" #include "../core/commandloop.h" #include "../core/config_parser.h" -#include "../core/fileutils.h" -#include "../core/timer.h" #include "../core/datetime.h" -#include "../core/makedir.h" +#include "../core/fileutils.h" +#include "../core/global.h" #include "../core/test.h" +#include "../core/timer.h" #include "../dataio/sgf.h" -#include "../search/searchnode.h" +#include "../main.h" +#include "../program/play.h" +#include "../program/playutils.h" +#include "../program/setup.h" #include "../search/asyncbot.h" #include "../search/patternbonustable.h" -#include "../program/setup.h" -#include "../program/playutils.h" -#include "../program/play.h" +#include "../search/searchnode.h" #include "../tests/tests.h" -#include "../command/commandline.h" -#include "../main.h" using namespace std; @@ -603,9 +604,19 @@ struct GTPEngine { assert(bot->getRootHist().rules == currentRules); const int newXSize = bot->getRootBoard().x_size; const int newYSize = bot->getRootBoard().y_size; + + vector startPosMoves; + bool startPosIsRandom; + vector remainingPlacementMoves; + const int startPos = + Rules::recognizeStartPos(initialStones, newXSize, newYSize, startPosMoves, startPosIsRandom, &remainingPlacementMoves); + + currentRules.startPos = startPos; + currentRules.startPosIsRandom = startPosIsRandom; + Board board(newXSize,newYSize,currentRules); - board.setStartPos(gtpRand); - if(!board.setStonesFailIfNoLibs(initialStones)) return false; + if (!board.setStonesFailIfNoLibs(startPosMoves, true)) return false; + if (!board.setStonesFailIfNoLibs(remainingPlacementMoves)) return false; //Sanity check for (const auto initialStone : initialStones) { @@ -614,7 +625,7 @@ struct GTPEngine { return false; } } - const Player pla = P_BLACK; + constexpr Player pla = P_BLACK; BoardHistory hist(board,pla,currentRules,0); hist.setInitialTurnNumber(board.numStonesOnBoard()); // Heuristic to guess at what turn this is const vector newMoveHistory; @@ -1876,6 +1887,43 @@ static GTPEngine::AnalyzeArgs parseAnalyzeCommand( return args; } +optional parseMovesSequence(const vector& pieces, const Board& board, bool passIsAllowed, vector& movesToPlay) { + optional response = std::nullopt; + + auto renderPieces = [pieces](const int& index) { + const vector subvector(pieces.begin(), pieces.begin() + index + 1); + return " (" + Global::concat(subvector, " ") + ")"; + }; + + for (int pieceInd = 0; pieceInd < pieces.size(); pieceInd += 2) { + Player pla; + Loc loc; + if(!PlayerIO::tryParsePlayer(pieces[pieceInd], pla)) { + response = "Could not parse color: '" + pieces[pieceInd] + "'" + renderPieces(pieceInd); + break; + } + + const int locIndex = pieceInd + 1; + if (locIndex >= pieces.size()) { + response = "Expected location after color: '" + pieces[pieceInd] + "'" + renderPieces(pieceInd); + break; + } + + if (!tryParseLoc(pieces[locIndex], board, loc)) { + response = "Could not parse location: '" + pieces[locIndex] + "'" + renderPieces(locIndex); + break; + } + + if (loc == Board::PASS_LOC && !passIsAllowed) { + response = Location::toString(loc, board) + " is disallowed" + renderPieces(locIndex); + break; + } + + movesToPlay.emplace_back(loc, pla); + } + + return response; +} int MainCmds::gtp(const vector& args) { Board::initHash(); @@ -2915,69 +2963,44 @@ int MainCmds::gtp(const vector& args) { } else if(command == "play") { - Player pla; - Loc loc; - if(pieces.size() != 2) { - responseIsError = true; - response = "Expected two arguments for play but got '" + Global::concat(pieces," ") + "'"; - } - else if(!PlayerIO::tryParsePlayer(pieces[0],pla)) { - responseIsError = true; - response = "Could not parse color: '" + pieces[0] + "'"; - } - else if(!tryParseLoc(pieces[1],engine->bot->getRootBoard(),loc)) { - responseIsError = true; - response = "Could not parse vertex: '" + pieces[1] + "'"; - } - else { - bool suc = engine->play(loc,pla); - if(!suc) { - responseIsError = true; - response = "illegal move"; - } - maybeStartPondering = true; - } - } + vector movesToPlay; + const Board& rootBoard = engine->bot->getRootBoard(); - else if(command == "set_position") { - if(pieces.size() % 2 != 0) { - responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "'"; - } - else { - vector initialStones; - for(int i = 0; ibot->getRootBoard(),loc)) { - responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "': "; - response += "could not parse vertex: '" + pieces[i+1] + "'"; - break; - } - else if(loc == Board::PASS_LOC) { + if (auto parseError = parseMovesSequence(pieces, rootBoard, true, movesToPlay); parseError == std::nullopt) { + for (int moveInd = 0; moveInd < movesToPlay.size(); moveInd++) { + if (Move move = movesToPlay[moveInd]; !engine->play(move.loc, move.pla)) { responseIsError = true; - response = "Expected a space-separated sequence of pairs but got '" + Global::concat(pieces," ") + "': "; - response += "could not parse vertex: '" + pieces[i+1] + "'"; + response = "Illegal move " + PlayerIO::playerToString(move.pla, rootBoard.isDots()) + " " + Location::toString(move.loc, rootBoard); + for (int rollbackMoveInd = 0; rollbackMoveInd < moveInd; rollbackMoveInd++) { + assert(engine->undo()); // Rollback already placed moves + } break; } - initialStones.emplace_back(loc,pla); } - if(!responseIsError) { - maybeSaveAvoidPatterns(false); - bool suc = engine->setPosition(initialStones); - if(!suc) { - responseIsError = true; - response = "Illegal stone placements - overlapping stones or stones with no liberties?"; - } - maybeStartPondering = false; + + if (!responseIsError) { + maybeStartPondering = true; } + } else { + responseIsError = true; + response = parseError.value(); + } + } + + else if(command == "set_position") { + vector initialStones; + const Board& rootBoard = engine->bot->getRootBoard(); + + if (auto parseError = parseMovesSequence(pieces, rootBoard, false, initialStones); parseError == std::nullopt) { + maybeSaveAvoidPatterns(false); + if(!engine->setPosition(initialStones)) { + responseIsError = true; + response = "Illegal stone placements - overlapping stones or stones with no liberties?"; + } + maybeStartPondering = false; + } else { + responseIsError = true; + response = parseError.value(); } } diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index ba1d174dc..aae231ad9 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -151,9 +151,10 @@ static Rules getRulesFromSgf(const SgfNode& rootNode, const int xSize, const int vector placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); + vector startPosMoves; bool randomized; vector remainingMoves; - rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, randomized, &remainingMoves); + rules.startPos = Rules::recognizeStartPos(placementMoves, xSize, ySize, startPosMoves, randomized, &remainingMoves); if (randomized && !rules.startPosIsRandom) { propertyFail("Defined start pos is randomized but RU says it shouldn't"); } @@ -1772,9 +1773,10 @@ Rules CompactSgf::getRulesOrWarn(const Rules& defaultRules, std::function placementMoves; rootNode.accumPlacements(placementMoves, xSize, ySize); + vector startPosMoves; bool randomized; vector remainingMoves; - rules.startPos = Rules::tryRecognizeStartPos(placementMoves, xSize, ySize, randomized, &remainingMoves); + rules.startPos = Rules::recognizeStartPos(placementMoves, xSize, ySize, startPosMoves, randomized, &remainingMoves); if (randomized && !rules.startPosIsRandom) { f("Defined start pos is randomized but RU says it shouldn't"); } diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index f0a43aaa8..31beeba95 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -799,7 +799,7 @@ bool Board::setStoneFailIfNoLibs(Loc loc, Color color, const bool startPos) { void Board::setStartPos(Rand& rand) { const vector startPos = Rules::generateStartPos(rules.startPos, rules.startPosIsRandom ? &rand : nullptr, x_size, y_size); - bool success = setStonesFailIfNoLibs(startPos, true); + const bool success = setStonesFailIfNoLibs(startPos, true); assert(success); } @@ -2525,9 +2525,7 @@ bool Board::isEqualForTesting(const Board& other, const bool checkNumCaptures, return false; } for (int i = 0; i < start_pos_moves.size(); i++) { - const Move start_pose_move = start_pos_moves[i]; - const Move other_start_pos_move = other.start_pos_moves[i]; - if (start_pose_move.loc != other_start_pos_move.loc || start_pose_move.pla != other_start_pos_move.pla) { + if (!movesEqual(start_pos_moves[i], other.start_pos_moves[i])) { return false; } } @@ -2615,11 +2613,11 @@ string Location::toStringMach(const Loc loc, const int x_size, const bool isDots return string(buf); } -string Location::toString(const Loc loc, int x_size, int y_size, bool isDots) { +string Location::toString(const Loc loc, const int x_size, const int y_size, const bool isDots) { if(loc == Board::PASS_LOC) return isDots ? "ground" : "pass"; if(loc == Board::NULL_LOC) - return string("null"); + return "null"; const int x = getX(loc, x_size); const int y = getY(loc, x_size); if(x >= x_size || x < 0 || y < 0 || y >= y_size) diff --git a/cpp/game/common.h b/cpp/game/common.h index 4daed7a5c..d6780850d 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -38,4 +38,8 @@ typedef short Loc; //Simple structure for storing moves. This is a convenient place to define it. STRUCT_NAMED_PAIR(Loc,loc,Player,pla,Move); +inline bool movesEqual(const Move& m1, const Move& m2) { + return m1.loc == m2.loc && m1.pla == m2.pla; +} + #endif diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index 1138b891e..dbd98a43a 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -160,14 +160,12 @@ set Rules::startPosStrings() { startPosIdToName[START_POS_CROSS], startPosIdToName[START_POS_CROSS_2], startPosIdToName[START_POS_CROSS_4], - startPosIdToName[START_POS_CUSTOM] }; } int Rules::getNumOfStartPosStones() const { switch (startPos) { case START_POS_EMPTY: - case START_POS_CUSTOM: return 0; case START_POS_CROSS: return 4; case START_POS_CROSS_2: return 8; @@ -820,53 +818,46 @@ void Rules::addCross(const int x, const int y, const int x_size, const bool rota Player pla; Player opp; - // The end move should always be P_WHITE + // The end move should always be P_WHITE to keep a consistent moves order if (!rotate90) { - pla = P_WHITE; - opp = P_BLACK; - } else { pla = P_BLACK; opp = P_WHITE; + } else { + pla = P_WHITE; + opp = P_BLACK; } - const auto tailMove = Move(Location::getLoc(x, y + 1, x_size), pla); + const auto tailMove = Move(Location::getLoc(x, y, x_size), pla); if (!rotate90) { moves.push_back(tailMove); } - moves.emplace_back(Location::getLoc(x + 1, y + 1, x_size), opp); - moves.emplace_back(Location::getLoc(x + 1, y, x_size), pla); - moves.emplace_back(Location::getLoc(x, y, x_size), opp); + moves.emplace_back(Location::getLoc(x + 1, y, x_size), opp); + moves.emplace_back(Location::getLoc(x + 1, y + 1, x_size), pla); + moves.emplace_back(Location::getLoc(x, y + 1, x_size), opp); if (rotate90) { moves.push_back(tailMove); } } -int Rules::tryRecognizeStartPos( +int Rules::recognizeStartPos( const vector& placementMoves, const int x_size, const int y_size, + vector& startPosMoves, bool& randomized, vector* remainingMoves) { + + startPosMoves = vector(); randomized = false; // Empty or unknown start pos is static by default if (remainingMoves != nullptr) { *remainingMoves = placementMoves; } - int result = START_POS_EMPTY; + int resultStartPos = START_POS_EMPTY; - if(placementMoves.empty()) return result; - - // If all placement moves are black, then it's a handicap game and the start pos is empty - for (const auto placementMove : placementMoves) { - if (placementMove.pla != C_BLACK) { - result = START_POS_CUSTOM; - break; - } - } - - if (result == START_POS_EMPTY) return result; + if(placementMoves.empty()) return resultStartPos; const int stride = x_size + 1; auto placement = vector(stride * (y_size + 2), C_EMPTY); @@ -875,8 +866,6 @@ int Rules::tryRecognizeStartPos( placement[move.loc] = move.pla; } - auto recognizedCrossesMoves = vector(); - for (const auto move : placementMoves) { const int x = Location::getX(move.loc, x_size); const int y = Location::getY(move.loc, x_size); @@ -899,10 +888,18 @@ int Rules::tryRecognizeStartPos( const int xyp1 = Location::getLoc(x, y + 1, x_size); if (placement[xyp1] != opp) continue; - recognizedCrossesMoves.emplace_back(xy, pla); - recognizedCrossesMoves.emplace_back(xp1y, opp); - recognizedCrossesMoves.emplace_back(xp1yp1, pla); - recognizedCrossesMoves.emplace_back(xyp1, opp); + // Match order of generator (white move is always last) + if (pla == P_BLACK) { + startPosMoves.emplace_back(xy, pla); + } + + startPosMoves.emplace_back(xp1y, opp); + startPosMoves.emplace_back(xp1yp1, pla); + startPosMoves.emplace_back(xyp1, opp); + + if (pla != P_BLACK) { + startPosMoves.emplace_back(xy, pla); + } // Clear the placement because the recognized cross is already stored placement[xy] = C_EMPTY; @@ -911,61 +908,80 @@ int Rules::tryRecognizeStartPos( placement[xyp1] = C_EMPTY; } - // Sort locs because start pos is invariant to moves order - auto sortByLoc = [&](vector& moves) { - std::sort(moves.begin(), moves.end(), [](const Move& move1, const Move& move2) { return move1.loc < move2.loc; }); - }; - - sortByLoc(recognizedCrossesMoves); - // Try to match strictly and set up randomized if failed. - auto detectRandomization = [&](const int expectedStartPos) -> void { + // Also, refine start pos and remaining moves. + auto finishRecognition = [&](const int expectedStartPos) -> void { auto staticStartPosMoves = generateStartPos(expectedStartPos, nullptr, x_size, y_size); - assert(remainingMoves != nullptr || placementMoves.size() == recognizedCrossesMoves.size()); - assert(staticStartPosMoves.size() == recognizedCrossesMoves.size()); + assert(remainingMoves != nullptr || placementMoves.size() == startPosMoves.size()); + const auto startPosMovesSize = staticStartPosMoves.size(); + assert(startPosMovesSize <= startPosMoves.size()); + + vector refinedStartPosMoves; + + // TODO: generalize (currently it works only for crosses) + for(int startPosInd = 0; startPosInd < startPosMoves.size(); startPosInd += 4) { + if (startPosInd + 3 >= startPosMoves.size()) continue; + for (int staticStartPosInd = 0; staticStartPosInd < staticStartPosMoves.size(); staticStartPosInd += 4) { + if (staticStartPosInd + 3 >= staticStartPosMoves.size()) continue; - sortByLoc(staticStartPosMoves); + if (!movesEqual(startPosMoves[startPosInd], staticStartPosMoves[staticStartPosInd])) continue; + if (!movesEqual(startPosMoves[startPosInd + 1], staticStartPosMoves[staticStartPosInd + 1])) continue; + if (!movesEqual(startPosMoves[startPosInd + 2], staticStartPosMoves[staticStartPosInd + 2])) continue; + if (!movesEqual(startPosMoves[startPosInd + 3], staticStartPosMoves[staticStartPosInd + 3])) continue; - for(size_t i = 0; i < staticStartPosMoves.size(); i++) { - if(staticStartPosMoves[i].loc != recognizedCrossesMoves[i].loc || staticStartPosMoves[i].pla != recognizedCrossesMoves[i].pla) { - randomized = true; + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 1]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 2]); + refinedStartPosMoves.emplace_back(startPosMoves[startPosInd + 3]); + + staticStartPosMoves.erase(staticStartPosMoves.begin() + staticStartPosInd, staticStartPosMoves.begin() + staticStartPosInd + 4); break; } } + for (const auto recognizedMove : startPosMoves) { + if (refinedStartPosMoves.size() == startPosMovesSize) { + break; + } + if (!std::any_of(refinedStartPosMoves.begin(), refinedStartPosMoves.end(), [&](const Move& m) { + return movesEqual(m, recognizedMove); + })) { + refinedStartPosMoves.emplace_back(recognizedMove); + } + } + + startPosMoves = refinedStartPosMoves; + if (remainingMoves != nullptr) { - for (const auto recognizedMove : recognizedCrossesMoves) { - bool removed = false; + for (const auto recognizedMove : startPosMoves) { + bool remainingMoveIsRemoved = false; for(auto it = remainingMoves->begin(); it != remainingMoves->end(); ++it) { - if (it->loc == recognizedMove.loc && it->pla == recognizedMove.pla) { + if (movesEqual(*it, recognizedMove)) { remainingMoves->erase(it); - removed = true; + remainingMoveIsRemoved = true; break; } } - assert(removed); + assert(remainingMoveIsRemoved); } } - result = expectedStartPos; + randomized = !staticStartPosMoves.empty(); + resultStartPos = expectedStartPos; }; - switch (recognizedCrossesMoves.size()) { - case 4: - detectRandomization(START_POS_CROSS); - break; - case 8: - detectRandomization(START_POS_CROSS_2); - break; - case 16: - detectRandomization(START_POS_CROSS_4); - break; - default:; - break; + if (const auto recognizedCrossesMovesSize = startPosMoves.size(); recognizedCrossesMovesSize < 4) { + finishRecognition(START_POS_EMPTY); + } else if (recognizedCrossesMovesSize < 8) { + finishRecognition(START_POS_CROSS); + } else if (recognizedCrossesMovesSize < 16) { + finishRecognition(START_POS_CROSS_2); + } else { + finishRecognition(START_POS_CROSS_4); } - return result; + return resultStartPos; } const Hash128 Rules::ZOBRIST_KO_RULE_HASH[4] = { diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 89b744dc5..6674cf34b 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -16,7 +16,6 @@ struct Rules { static constexpr int START_POS_CROSS = 1; static constexpr int START_POS_CROSS_2 = 2; static constexpr int START_POS_CROSS_4 = 3; - static constexpr int START_POS_CUSTOM = 4; int startPos; // Enables random shuffling of start pos. Currently, it works only for CROSS_4 @@ -118,14 +117,18 @@ struct Rules { * @param placementMoves placement moves that we are trying to recognize. * @param x_size size of field * @param y_size size of field + * @param startPosMoves moves for a recognized pattern. It's empty if recognition is failed. * @param randomized if we recognize a start pos, but it doesn't match the strict position, set it up to `true` - * @param remainingMoves Holds moves that remain after start pos recognition, useful for SGF handling. + * @param remainingMoves represents moves that remain after start pos recognition (@param placementMoves - @param startPosMoves), * If it's null (default), then it's assumed that all placement moves are used in the recognized start pos. + * @return recognized type of start pos. If the @param placementMoves don't match any known patter, then + * it returns empty pos with @param remainingMoves == @param placementMoves */ - static int tryRecognizeStartPos( + static int recognizeStartPos( const std::vector& placementMoves, int x_size, int y_size, + std::vector& startPosMoves, bool& randomized, std::vector* remainingMoves = nullptr); @@ -175,12 +178,10 @@ struct Rules { startPosIdToName[START_POS_CROSS] = "CROSS"; startPosIdToName[START_POS_CROSS_2] = "CROSS_2"; startPosIdToName[START_POS_CROSS_4] = "CROSS_4"; - startPosIdToName[START_POS_CUSTOM] = "CUSTOM"; startPosNameToId["EMPTY"] = START_POS_EMPTY; startPosNameToId["CROSS"] = START_POS_CROSS; startPosNameToId["CROSS_2"] = START_POS_CROSS_2; startPosNameToId["CROSS_4"] = START_POS_CROSS_4; - startPosNameToId["CUSTOM"] = START_POS_CUSTOM; } } diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 2677c9d67..67f275632 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -729,8 +729,9 @@ Board SymmetryHelpers::getSymBoard(const Board& board, int symmetry) { const int y = Location::getY(loc, board.x_size); sym_start_pos_moves.emplace_back(getSymLoc(x, y), start_pos_move.pla); } + vector startPosMoves; bool randomized; - symRules.startPos = Rules::tryRecognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, randomized); + symRules.startPos = Rules::recognizeStartPos(sym_start_pos_moves, sym_x_size, sym_y_size, startPosMoves, randomized); symRules.startPosIsRandom = randomized; } diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp index 2b97d8f8b..aa6cf63ad 100644 --- a/cpp/tests/testdotsextra.cpp +++ b/cpp/tests/testdotsextra.cpp @@ -116,10 +116,12 @@ SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); auto rulesAfterTransformation = originalRules; rulesAfterTransformation.startPosIsRandom = true; auto expectedBoard = Board(4, 5, rulesAfterTransformation); - expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 2, expectedBoard.x_size), P_WHITE, true); - expectedBoard.setStoneFailIfNoLibs(Location::getLoc(1, 3, expectedBoard.x_size), P_BLACK, true); - expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 3, expectedBoard.x_size), P_WHITE, true); - expectedBoard.setStoneFailIfNoLibs(Location::getLoc(2, 2, expectedBoard.x_size), P_BLACK, true); + expectedBoard.setStonesFailIfNoLibs({ + Move(Location::getLoc(2, 2, expectedBoard.x_size), P_BLACK), + Move(Location::getLoc(2, 3, expectedBoard.x_size), P_WHITE), + Move(Location::getLoc(1, 3, expectedBoard.x_size), P_BLACK), + Move(Location::getLoc(1, 2, expectedBoard.x_size), P_WHITE), + }, true); expectedBoard.playMoveAssumeLegal(Location::getLoc(1, 1, expectedBoard.x_size), P_BLACK); expect("Dots symmetry with start pos", Board::toStringSimple(rotatedBoard), Board::toStringSimple(expectedBoard)); diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index 4ddfa7d27..c44c02214 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -28,9 +28,7 @@ void checkStartPos(const string& description, const int startPos, const bool sta auto board = Board(x_size, y_size, Rules(startPos, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); board.setStartPos(DOTS_RANDOM); - for (const XYMove& extraMove : extraMoves) { - board.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, board.x_size), extraMove.player); - } + playXYMovesAssumeLegal(board, extraMoves); std::ostringstream oss; Board::printBoard(oss, board, Board::NULL_LOC, nullptr, false); @@ -42,7 +40,38 @@ void checkStartPos(const string& description, const int startPos, const bool sta writeToSgfAndCheckStartPosFromSgfProp(startPos, startPosIsRandom, board); } -void checkStartPosRecognition(const string& description, const int expectedStartPos, const int startPosIsRandom, const string& inputBoard) { +void checkRecognition(const vector& xyMoves, const int x_size, const int y_size, + const int expectedStartPos, + const vector& expectedStartMoves, + const bool expectedRandomized, + const vector& expectedRemainingMoves) { + + auto moves = vector(); + moves.reserve(xyMoves.size()); + for (const auto xyMove : xyMoves) { + moves.push_back(xyMove.toMove(x_size)); + } + + vector actualStartPosMoves; + bool actualRandomized; + vector actualRemainingMoves; + + testAssert(expectedStartPos == Rules::recognizeStartPos(moves, x_size, y_size, actualStartPosMoves, actualRandomized, &actualRemainingMoves)); + + testAssert(expectedStartMoves.size() == actualStartPosMoves.size()); + for (size_t i = 0; i < expectedStartMoves.size(); ++i) { + testAssert(movesEqual(expectedStartMoves[i].toMove(x_size), actualStartPosMoves[i])); + } + + testAssert(expectedRandomized == actualRandomized); + + testAssert(expectedRemainingMoves.size() == actualRemainingMoves.size()); + for (size_t i = 0; i < expectedRemainingMoves.size(); ++i) { + testAssert(movesEqual(expectedRemainingMoves[i].toMove(x_size), actualRemainingMoves[i])); + } +} + +void checkStartPosRecognition(const string& description, const int expectedStartPos, const bool startPosIsRandom, const string& inputBoard) { const Board board = parseDotsField(inputBoard, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, {}); cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; @@ -52,8 +81,9 @@ void checkStartPosRecognition(const string& description, const int expectedStart void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { const auto generatedMoves = Rules::generateStartPos(startPos, startPosIsRandom ? &DOTS_RANDOM : nullptr, 39, 32); + vector actualStartPosMoves; bool actualRandomized; - testAssert(startPos == Rules::tryRecognizeStartPos(generatedMoves, 39, 32, actualRandomized)); + testAssert(startPos == Rules::recognizeStartPos(generatedMoves, 39, 32, actualStartPosMoves, actualRandomized)); // We can't reliably check in case of randomization is not detected because random generator can // generate static poses in rare cases. if (actualRandomized) { @@ -80,7 +110,7 @@ void Tests::runDotsStartPosTests() { 1 . . X . )", {XYMove(2, 3, P_BLACK)}); - checkStartPosRecognition("Not enough dots for cross", Rules::START_POS_CUSTOM, false, R"( + checkStartPosRecognition("Empty start pos with three extra moves", Rules::START_POS_EMPTY, false, R"( .... .xo. .o.. @@ -166,6 +196,107 @@ void Tests::runDotsStartPosTests() { 1 . . . . . . . )"); + checkStartPos("Double rand cross on triple cross", Rules::START_POS_CROSS_2, false, 10, 4, R"( + 1 2 3 4 5 6 7 8 9 10 + 4 . . . . . . . . . . + 3 . . . X O O X . X O + 2 . . . O X X O . O X + 1 . . . . . . . . . . +)", {XYMove(8, 1, P_BLACK), XYMove(9, 1, P_WHITE), XYMove(9, 2, P_BLACK), XYMove(8, 2, P_WHITE)}); + + const vector expectedRemainingMovesForDoubleCross = { + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }; + + // Double cross exactly matches static double cross (not randomized) + checkRecognition({ + XYMove(3, 1, P_BLACK), + XYMove(4, 1, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(3, 2, P_WHITE), + + XYMove(5, 1, P_WHITE), + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }, 10, 4, Rules::START_POS_CROSS_2, { + XYMove(3, 1, P_BLACK), + XYMove(4, 1, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(3, 2, P_WHITE), + + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + XYMove(5, 1, P_WHITE), + }, false, expectedRemainingMovesForDoubleCross +); + + // Double cross partially matches static double cross (randomized) + checkRecognition({ + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(5, 1, P_WHITE), + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) +}, 10, 4, Rules::START_POS_CROSS_2, +{ + XYMove(6, 1, P_BLACK), + XYMove(6, 2, P_WHITE), + XYMove(5, 2, P_BLACK), + XYMove(5, 1, P_WHITE), + + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), +}, true, expectedRemainingMovesForDoubleCross); + + // Double cross do not match static cross completely (randomized) + checkRecognition({ + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(4, 1, P_WHITE), + XYMove(5, 1, P_BLACK), + XYMove(5, 2, P_WHITE), + XYMove(4, 2, P_BLACK), + + XYMove(8, 1, P_BLACK), + XYMove(9, 1, P_WHITE), + XYMove(9, 2, P_BLACK), + XYMove(8, 2, P_WHITE) + }, 10, 4, Rules::START_POS_CROSS_2, { + XYMove(2, 1, P_BLACK), + XYMove(3, 1, P_WHITE), + XYMove(3, 2, P_BLACK), + XYMove(2, 2, P_WHITE), + + XYMove(5, 1, P_BLACK), + XYMove(5, 2, P_WHITE), + XYMove(4, 2, P_BLACK), + XYMove(4, 1, P_WHITE) + }, true, expectedRemainingMovesForDoubleCross); + checkStartPos("Double cross on standard size", Rules::START_POS_CROSS_2, false, 39, 32, R"( 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 32 . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . diff --git a/cpp/tests/testdotsutils.cpp b/cpp/tests/testdotsutils.cpp index 5ac0aaa8f..f92686338 100644 --- a/cpp/tests/testdotsutils.cpp +++ b/cpp/tests/testdotsutils.cpp @@ -2,6 +2,10 @@ using namespace std; +Move XYMove::toMove(const int x_size) const { + return Move(Location::getLoc(x, y, x_size), player); +} + Board parseDotsFieldDefault(const string& input, const vector& extraMoves) { return parseDotsField(input, Rules::DEFAULT_DOTS.startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, extraMoves); } @@ -28,8 +32,12 @@ Board parseDotsField(const string& input, const bool startPosIsRandom, const boo } Board result = Board::parseBoard(xSize, ySize, input, Rules(Rules::START_POS_EMPTY, startPosIsRandom, suicide, captureEmptyBases, freeCapturedDots)); - for(const XYMove& extraMove : extraMoves) { - result.playMoveAssumeLegal(Location::getLoc(extraMove.x, extraMove.y, result.x_size), extraMove.player); - } + playXYMovesAssumeLegal(result, extraMoves); return result; +} + +void playXYMovesAssumeLegal(Board& board, const vector& moves) { + for(const XYMove& move : moves) { + board.playMoveAssumeLegal(Location::getLoc(move.x, move.y, board.x_size), move.player); + } } \ No newline at end of file diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index d6f372496..9662a0327 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -16,6 +16,8 @@ struct XYMove { [[nodiscard]] std::string toString() const { return "(" + to_string(x) + "," + to_string(y) + "," + PlayerIO::colorToChar(player) + ")"; } + + Move toMove(int x_size) const; }; struct BoardWithMoveRecords { @@ -66,3 +68,5 @@ Board parseDotsFieldDefault(const string& input, const vector& extraMove Board parseDotsField(const string& input, bool startPosIsRandom, bool suicide, bool captureEmptyBases, bool freeCapturedDots, const vector& extraMoves); +void playXYMovesAssumeLegal(Board& board, const vector& moves); + diff --git a/cpp/tests/testsgf.cpp b/cpp/tests/testsgf.cpp index fe6882ddd..a5a658674 100644 --- a/cpp/tests/testsgf.cpp +++ b/cpp/tests/testsgf.cpp @@ -30,8 +30,9 @@ void Tests::runSgfTests() { const Board& board = hist.initialBoard; bool randomized; + vector startPosMoves; vector remainingPlacementMoves; - const int recognizedStartPos = Rules::tryRecognizeStartPos(sgf->placements, board.x_size, board.y_size, randomized, &remainingPlacementMoves); + const int recognizedStartPos = Rules::recognizeStartPos(sgf->placements, board.x_size, board.y_size, startPosMoves, randomized, &remainingPlacementMoves); testAssert(recognizedStartPos == rules.startPos); testAssert(randomized == rules.startPosIsRandom); @@ -745,7 +746,6 @@ xSize 9 ySize 9 depth 2 komi 7.5 -startPos CUSTOM placements X B7 X D7 @@ -774,7 +774,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koPOSITIONALscoreAREAtaxNONEstartPosCUSTOMsui1komi7.5 +Rules koPOSITIONALscoreAREAtaxNONEsui1komi7.5 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -803,7 +803,7 @@ Encore phase 0 Turns this phase 0 Approx valid turns this phase 0 Approx consec valid turns this game 0 -Rules koPOSITIONALscoreAREAtaxNONEstartPosCUSTOMsui1komi7.5 +Rules koPOSITIONALscoreAREAtaxNONEsui1komi7.5 Ko recap block hash 00000000000000000000000000000000 White bonus score 0 White handicap bonus score 0 @@ -1179,7 +1179,7 @@ xSize 5 ySize 5 komi 12.5 hasRules true -rules koSIMPLEscoreTERRITORYtaxSEKIstartPosCUSTOMsui0komi12.5 +rules koSIMPLEscoreTERRITORYtaxSEKIsui0komi12.5 handicapValue 5 sgfWinner Black firstPlayerColor X From 20f237d39acc59d35c2a8262e5394a77cac3e7f9 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 16:43:44 +0100 Subject: [PATCH 31/42] [GTP] Introduce `get_moves` and `get_position` commands --- cpp/book/book.cpp | 2 +- cpp/command/contribute.cpp | 8 +++--- cpp/command/gtp.cpp | 26 +++++++++++++++++- cpp/dataio/sgf.cpp | 17 ++++++------ cpp/game/board.cpp | 51 +++++++++++++++++++++++++----------- cpp/game/board.h | 5 ++-- cpp/game/boardhistory.cpp | 2 +- cpp/game/common.h | 5 ++++ cpp/search/searchparams.cpp | 2 +- cpp/search/searchresults.cpp | 2 +- 10 files changed, 85 insertions(+), 35 deletions(-) diff --git a/cpp/book/book.cpp b/cpp/book/book.cpp index 668f08a1b..04dd18c2e 100644 --- a/cpp/book/book.cpp +++ b/cpp/book/book.cpp @@ -2922,7 +2922,7 @@ void Book::saveToFile(const string& fileName) const { json nodeData = json::object(); if(bookVersion >= 2) { nodeData["id"] = nodeIdx; - nodeData["pla"] = PlayerIO::playerToStringShort(node->pla); + nodeData["pla"] = PlayerIO::playerToStringShort(node->pla, initialRules.isDots); nodeData["syms"] = node->symmetries; nodeData["wl"] = roundDouble(node->thisValuesNotInBook.winLossValue, 100000000); nodeData["sM"] = roundDouble(node->thisValuesNotInBook.scoreMean, 1000000); diff --git a/cpp/command/contribute.cpp b/cpp/command/contribute.cpp index ec0c3a9a0..cadbcec70 100644 --- a/cpp/command/contribute.cpp +++ b/cpp/command/contribute.cpp @@ -242,7 +242,7 @@ static void runAndUploadSingleGame( json ret; // unique to this output ret["gameId"] = gameIdString; - ret["move"] = json::array({PlayerIO::playerToStringShort(pla), Location::toString(moveLoc, board)}); + ret["move"] = json::array({PlayerIO::playerToStringShort(pla, board.isDots()), Location::toString(moveLoc, board)}); ret["blackPlayer"] = botSpecB.botName; ret["whitePlayer"] = botSpecW.botName; @@ -253,7 +253,7 @@ static void runAndUploadSingleGame( json moves = json::array(); for(auto move: hist.moveHistory) { - moves.push_back(json::array({PlayerIO::playerToStringShort(move.pla), Location::toString(move.loc, board)})); + moves.push_back(json::array({PlayerIO::playerToStringShort(move.pla, board.isDots()), Location::toString(move.loc, board)})); } ret["moves"] = moves; @@ -264,11 +264,11 @@ static void runAndUploadSingleGame( Loc loc = Location::getLoc(x, y, initialBoard.x_size); Player locOwner = initialBoard.colors[loc]; if(locOwner != C_EMPTY) - initialStones.push_back(json::array({PlayerIO::playerToStringShort(locOwner), Location::toString(loc, initialBoard)})); + initialStones.push_back(json::array({PlayerIO::playerToStringShort(locOwner, board.isDots()), Location::toString(loc, initialBoard)})); } } ret["initialStones"] = initialStones; - ret["initialPlayer"] = PlayerIO::playerToStringShort(hist.initialPla); + ret["initialPlayer"] = PlayerIO::playerToStringShort(hist.initialPla, board.isDots()); ret["initialTurnNumber"] = hist.initialTurnNumber; // Usual analysis response fields diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 8a296ac3a..1bff816cd 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -20,6 +20,9 @@ using namespace std; +const string GET_MOVES_COMMAND = "get_moves"; +const string GET_POSITION_COMMAND = "get_position"; + static const vector knownCommands = { //Basic GTP commands "protocol_version", @@ -36,10 +39,12 @@ static const vector knownCommands = { "clear_board", "set_position", + GET_POSITION_COMMAND, "komi", //GTP extension - get KataGo's current komi setting "get_komi", "play", + GET_MOVES_COMMAND, "undo", //GTP extension - specify rules @@ -1180,7 +1185,7 @@ struct GTPEngine { response = "genmove returned null location or illegal move"; ostringstream sout; sout << "genmove null location or illegal move!?!" << "\n"; - const auto rootBoard = search->getRootBoard(); + const auto& rootBoard = search->getRootBoard(); sout << rootBoard << "\n"; sout << "Pla: " << PlayerIO::playerToString(pla,rootBoard.isDots()) << "\n"; sout << "MoveLoc: " << Location::toString(moveLoc,search->getRootBoard()) << "\n"; @@ -1925,6 +1930,14 @@ optional parseMovesSequence(const vector& pieces, const Boa return response; } +string printMoves(const vector& moves, const Board& board) { + std::ostringstream builder; + for (const auto move : moves) { + builder << PlayerIO::playerToStringShort(move.pla, board.isDots()) << " " << Location::toString(move.loc, board) << " "; + } + return builder.str(); +} + int MainCmds::gtp(const vector& args) { Board::initHash(); ScoreValue::initTables(); @@ -3004,6 +3017,17 @@ int MainCmds::gtp(const vector& args) { } } + else if (command == GET_MOVES_COMMAND) { + const BoardHistory& history = engine->bot->getRootHist(); + response = printMoves(history.moveHistory, engine->bot->getRootBoard()); + } + + else if (command == GET_POSITION_COMMAND) { + const auto& rootBoard = engine->bot->getRootBoard(); + // TODO: fix for handicap start moves + response = printMoves(rootBoard.start_pos_moves, rootBoard); + } + else if(command == "undo") { bool suc = engine->undo(); if(!suc) { diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index aae231ad9..4da0af4c4 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1135,24 +1135,25 @@ std::set Sgf::readExcludes(const vector& files) { string Sgf::PositionSample::toJsonLine(const Sgf::PositionSample& sample) { json data; - if (sample.board.rules.isDots) { + const Board& board = sample.board; + if (board.rules.isDots) { data[DOTS_KEY] = "true"; } - data["xSize"] = sample.board.x_size; - data["ySize"] = sample.board.y_size; - data["board"] = Board::toStringSimple(sample.board,'/'); - data["nextPla"] = PlayerIO::playerToStringShort(sample.nextPla); + data["xSize"] = board.x_size; + data["ySize"] = board.y_size; + data["board"] = Board::toStringSimple(board,'/'); + data["nextPla"] = PlayerIO::playerToStringShort(sample.nextPla, board.isDots()); vector moveLocs; vector movePlas; for(size_t i = 0; i 0) data["metadata"] = sample.metadata; diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 31beeba95..eac1bfbd7 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -693,24 +693,28 @@ bool Board::isNonPassAliveSelfConnection(Loc loc, Player pla, Color* passAliveAr bool Board::isStartPos() const { int startBoardNumBlackStones, startBoardNumWhiteStones; - numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, false); + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, false); return startBoardNumBlackStones == 0 && startBoardNumWhiteStones == 0; } int Board::numStonesOnBoard() const { int startBoardNumBlackStones, startBoardNumWhiteStones; - numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, true); + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, true); return startBoardNumBlackStones + startBoardNumWhiteStones; } int Board::numPlaStonesOnBoard(Player pla) const { int startBoardNumBlackStones, startBoardNumWhiteStones; - numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, true); + getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, true); return pla == C_BLACK ? startBoardNumBlackStones : startBoardNumWhiteStones; } -void Board::numStartBlackWhiteStones(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, +vector Board::getCurrentMoves( + int& startBoardNumBlackStones, + int& startBoardNumWhiteStones, const bool includeStartLocs) const { + + vector stones; startBoardNumBlackStones = 0; startBoardNumWhiteStones = 0; @@ -718,20 +722,35 @@ void Board::numStartBlackWhiteStones(int& startBoardNumBlackStones, int& startBo set startLocs; for(auto move: start_pos_moves) { startLocs.insert(move.loc); + if (includeStartLocs) { + // Fill start pos stones as in priority + stones.emplace_back(move); + } } for(int y = 0; y < y_size; y++) { for(int x = 0; x < x_size; x++) { if(const Loc loc = Location::getLoc(x, y, x_size)) { - if (includeStartLocs || startLocs.count(loc) == 0) { - if(const Color color = getPlacedColor(loc); color == C_BLACK) + bool isStartPosLoc = startLocs.count(loc) > 0; + if (includeStartLocs || !isStartPosLoc) { + if(const Color color = getPlacedColor(loc); color == C_BLACK) { startBoardNumBlackStones += 1; - else if(color == C_WHITE) + if (!isStartPosLoc) { + stones.emplace_back(loc, C_BLACK); + } + } + else if(color == C_WHITE) { startBoardNumWhiteStones += 1; + if (!isStartPosLoc) { + stones.emplace_back(loc, C_WHITE); + } + } } } } } + + return stones; } bool Board::setStone(Loc loc, Color color) @@ -2569,28 +2588,28 @@ char PlayerIO::stateToChar(const State s, const bool isDots) { string PlayerIO::playerToString(const Color c, const bool isDots) { switch(c) { - case C_BLACK: return !isDots ? "Black" : "Player1"; - case C_WHITE: return !isDots ? "White" : "Player2"; + case C_BLACK: return !isDots ? "Black" : PLAYER1; + case C_WHITE: return !isDots ? "White" : PLAYER2; case C_EMPTY: return "Empty"; default: return "Wall"; } } -string PlayerIO::playerToStringShort(Color c) +string PlayerIO::playerToStringShort(const Color p, const bool isDots) { - switch(c) { - case C_BLACK: return "B"; - case C_WHITE: return "W"; + switch(p) { + case C_BLACK: return !isDots ? "B" : PLAYER1_SHORT; + case C_WHITE: return !isDots ? "W" : PLAYER2_SHORT; case C_EMPTY: return "E"; - default: return "Wall"; + default: return ""; } } bool PlayerIO::tryParsePlayer(const string& s, Player& pla) { - if(const string str = Global::toLower(s); str == "black" || str == "b" || str == "blue" || str == "p1") { + if(const string str = Global::toUpper(s); str == "BLACK" || str == "B" || str == "BLUE" || str == PLAYER1_SHORT) { pla = P_BLACK; return true; - } else if(str == "white" || str == "w" || str == "red" || str == "r" || str == "p2") { + } else if(str == "WHITE" || str == "W" || str == "RED" || str == "R" || str == PLAYER2_SHORT) { pla = P_WHITE; return true; } diff --git a/cpp/game/board.h b/cpp/game/board.h index df14297f6..5270d12c8 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -54,7 +54,7 @@ bool isTerritory(State s); namespace PlayerIO { char colorToChar(Color c); char stateToChar(State s, bool isDots); - std::string playerToStringShort(Player p); + std::string playerToStringShort(Color p, bool isDots = false); std::string playerToString(Player p, bool isDots); bool tryParsePlayer(const std::string& s, Player& pla); Player parsePlayer(const std::string& s); @@ -298,7 +298,8 @@ struct Board //Count the number of stones on the board int numStonesOnBoard() const; int numPlaStonesOnBoard(Player pla) const; - void numStartBlackWhiteStones(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, bool includeStartLocs) const; + std::vector + getCurrentMoves(int& startBoardNumBlackStones, int& startBoardNumWhiteStones, bool includeStartLocs) const; //Get a hash that combines the position of the board with simple ko prohibition and a player to move. Hash128 getSitHashWithSimpleKo(Player pla) const; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index b772cce84..1f918c8b4 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -398,7 +398,7 @@ void BoardHistory::setOverrideNumHandicapStones(int n) { static int numHandicapStonesOnBoardHelper(const Board& board, const int blackNonPassTurnsToStart) { int startBoardNumBlackStones, startBoardNumWhiteStones; - board.numStartBlackWhiteStones(startBoardNumBlackStones, startBoardNumWhiteStones, false); + board.getCurrentMoves(startBoardNumBlackStones, startBoardNumWhiteStones, false); //If we set up in a nontrivial position, then consider it a non-handicap game. if(startBoardNumWhiteStones != 0) diff --git a/cpp/game/common.h b/cpp/game/common.h index d6780850d..562b12a52 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -16,6 +16,11 @@ const std::string START_POSES_ARE_RANDOM_KEY = "startPosesAreRandom"; const std::string BLACK_SCORE_IF_WHITE_GROUNDS_KEY = "blackScoreIfWhiteGrounds"; const std::string WHITE_SCORE_IF_BLACK_GROUNDS_KEY = "whiteScoreIfBlackGrounds"; +const std::string PLAYER1 = "Player1"; +const std::string PLAYER2 = "Player2"; +const std::string PLAYER1_SHORT = "P1"; +const std::string PLAYER2_SHORT = "P2"; + // Player typedef int8_t Player; static constexpr Player P_BLACK = 1; diff --git a/cpp/search/searchparams.cpp b/cpp/search/searchparams.cpp index 4096e1373..26206a70c 100644 --- a/cpp/search/searchparams.cpp +++ b/cpp/search/searchparams.cpp @@ -456,7 +456,7 @@ json SearchParams::changeableParametersToJson() const { // Special handling in GTP ret["playoutDoublingAdvantage"] = playoutDoublingAdvantage; - ret["playoutDoublingAdvantagePla"] = PlayerIO::playerToStringShort(playoutDoublingAdvantagePla); + ret["playoutDoublingAdvantagePla"] = PlayerIO::playerToStringShort(playoutDoublingAdvantagePla, false); // TODO: Fix for Dots // Special handling in GTP // ret["avoidRepeatedPatternUtility"] = avoidRepeatedPatternUtility; diff --git a/cpp/search/searchresults.cpp b/cpp/search/searchresults.cpp index 1c1872296..9fc5e0b8a 100644 --- a/cpp/search/searchresults.cpp +++ b/cpp/search/searchresults.cpp @@ -2155,7 +2155,7 @@ bool Search::getAnalysisJson( } rootInfo["thisHash"] = Global::uint64ToHexString(thisHash.hash1) + Global::uint64ToHexString(thisHash.hash0); rootInfo["symHash"] = Global::uint64ToHexString(symHash.hash1) + Global::uint64ToHexString(symHash.hash0); - rootInfo["currentPlayer"] = PlayerIO::playerToStringShort(rootPla); + rootInfo["currentPlayer"] = PlayerIO::playerToStringShort(rootPla, board.isDots()); ret["rootInfo"] = rootInfo; } From 689849c8caff5c8028e943438d4a544625b7f9af Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Mon, 13 Oct 2025 22:12:54 +0200 Subject: [PATCH 32/42] [GTP] Introduce `get_boardsize` command, `undo` supports multiple moves --- cpp/command/gtp.cpp | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 1bff816cd..5a2cf0fc2 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -22,6 +22,7 @@ using namespace std; const string GET_MOVES_COMMAND = "get_moves"; const string GET_POSITION_COMMAND = "get_position"; +const string GET_BOARDSIZE = "get_boardsize"; static const vector knownCommands = { //Basic GTP commands @@ -35,6 +36,7 @@ static const vector knownCommands = { //GTP extension - specify "boardsize X:Y" or "boardsize X Y" for non-square sizes //rectangular_boardsize is an alias for boardsize, intended to make it more evident that we have such support "boardsize", + GET_BOARDSIZE, "rectangular_boardsize", "clear_board", @@ -2375,6 +2377,13 @@ int MainCmds::gtp(const vector& args) { } } + else if (command == GET_BOARDSIZE) { + const auto& rootBoard = engine->bot->getRootBoard(); + response = Global::intToString(rootBoard.x_size) + (rootBoard.x_size == rootBoard.y_size + ? "" + : ":" + Global::intToString(rootBoard.y_size)); + } + else if(command == "clear_board") { maybeSaveAvoidPatterns(false); if(autoAvoidPatterns && shouldReloadAutoAvoidPatterns) { @@ -3029,10 +3038,22 @@ int MainCmds::gtp(const vector& args) { } else if(command == "undo") { - bool suc = engine->undo(); - if(!suc) { - responseIsError = true; - response = "cannot undo"; + int undoCount = 1; + if (pieces.size() > 0) { + if (!Global::tryStringToInt(pieces[0], undoCount) || undoCount < 0) { + responseIsError = true; + response = "Expected nonnegative integer for undo count"; + } + } + + if (!responseIsError) { + for (int i = 0; i < undoCount; i++) { + if(!engine->undo()) { + responseIsError = true; + response = "cannot undo"; + break; + } + } } } From 58cc24f72c5b7af6503408d8f3f9907c10450b21 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Mon, 13 Oct 2025 22:17:01 +0200 Subject: [PATCH 33/42] Fix the compilation error with a localpattern --- cpp/search/localpattern.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/search/localpattern.cpp b/cpp/search/localpattern.cpp index 5124182bf..d7bf0672d 100644 --- a/cpp/search/localpattern.cpp +++ b/cpp/search/localpattern.cpp @@ -52,10 +52,11 @@ Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) con Hash128 hash = zobristPla[pla]; if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - vector captures; vector bases; if (board.isDots()) { - // TODO: implement more faster version of `board.calculateOneMoveCaptureAndBasePositionsForDots(true, captures, bases);` + vector captures; + // TODO: implement fast version for Dots + // board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); } const int dxi = 1; @@ -91,10 +92,11 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p Hash128 hash = zobristPla[symPla]; if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { - vector captures; vector bases; if (board.isDots()) { - board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); + vector captures; + // TODO: implement fast version for Dots + // board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); } const int dxi = 1; @@ -158,8 +160,7 @@ void LocalPatternHasher::updateHash( bool addAtariHash = false; if(board.isDots()) { - //addAtariHash = bases[loc] != C_EMPTY; - addAtariHash = false; // TODO: implement for Dots + addAtariHash = !bases.empty() && bases[loc] != C_EMPTY; } else { addAtariHash = (colorAtLoc == P_BLACK || colorAtLoc == P_WHITE) && board.getNumLiberties(loc) == 1; } From 647d523bd302b7dfb84c748a11c243bb48a78dd4 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Mon, 13 Oct 2025 22:22:56 +0200 Subject: [PATCH 34/42] Add configs and fix scripts to make it robust for seltraining loop It's temporary because the changes are actually dirty --- cpp/CMakeLists.txt | 2 +- cpp/configs/gtp_example_dots.cfg | 744 ++++++++++++++++++++++ cpp/configs/training/gatekeeper1_dots.cfg | 117 ++++ cpp/configs/training/selfplay1_dots.cfg | 196 ++++++ python/selfplay/synchronous_loop.sh | 14 +- python/selfplay/train.sh | 2 +- 6 files changed, 1066 insertions(+), 9 deletions(-) create mode 100644 cpp/configs/gtp_example_dots.cfg create mode 100644 cpp/configs/training/gatekeeper1_dots.cfg create mode 100644 cpp/configs/training/selfplay1_dots.cfg diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bfc885219..be0f6f1e0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -458,7 +458,7 @@ endif() if(DOTS_GAME) target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_X=39) - target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_Y=36) + target_compile_definitions(katago PRIVATE COMPILE_MAX_BOARD_LEN_Y=39) if(USE_BIGGER_BOARDS_EXPENSIVE) message(SEND_ERROR "USE_BIGGER_BOARDS_EXPENSIVE is not yet supported for Dots Game") endif() diff --git a/cpp/configs/gtp_example_dots.cfg b/cpp/configs/gtp_example_dots.cfg new file mode 100644 index 000000000..be47f7387 --- /dev/null +++ b/cpp/configs/gtp_example_dots.cfg @@ -0,0 +1,744 @@ +# Configuration for KataGo C++ GTP engine + +# Run the program using: `./katago.exe gtp` + +# In this example config, when a parameter is given as a commented out value, +# that value also is the default value, unless described otherwise. You can +# uncomment it (remove the pound sign) and change it if you want. + +# =========================================================================== +# Running on an online server or in a real tournament or match +# =========================================================================== +# If you plan to run online or in a tournament, read through the "Rules" +# section below for proper handling of komi, handicaps, end-of-game cleanup, +# and other details. + +# =========================================================================== +# Notes about performance and memory usage +# =========================================================================== +# Important: For good performance, you will very likely want to tune the +# "numSearchThreads" parameter in the Search limits section below! Run +# "./katago benchmark" to test KataGo and to suggest a reasonable value +# of this parameter. + +# For multi-GPU systems, read "OpenCL GPU settings" or "CUDA GPU settings". +# +# When using OpenCL, verify that KataGo picks the correct device! Some systems +# may have both an Intel CPU OpenCL and GPU OpenCL. If # KataGo picks the wrong +# one, correct this by specifying "openclGpuToUse". +# +# Consider adjusting "maxVisits", "ponderingEnabled", "resignThreshold", and +# other parameters depending on your intended usage. + +# =========================================================================== +# Command-line usage +# =========================================================================== +# All of the below values may be set or overridden via command-line arguments: +# +# -override-config KEY=VALUE,KEY=VALUE,... + +# =========================================================================== +# Logs and files +# =========================================================================== +# This section defines where and what logging information is produced. + +# Each run of KataGo will log to a separate file in this dir. +# This is the default. +logDir = gtp_logs +# Uncomment and specify this instead of logDir to write separate dated subdirs +# logDirDated = gtp_logs +# Uncomment and specify this instead of logDir to log to only a single file +# logFile = gtp.log + +# Logging options +logAllGTPCommunication = true +logSearchInfo = true +logSearchInfoForChosenMove = false +logToStderr = false + +# KataGo will display some info to stderr on GTP startup +# Uncomment the next line and set it to false to suppress that and remain silent +# startupPrintMessageToStderr = true + +# Write information to stderr, for use in things like malkovich chat to OGS. +# ogsChatToStderr = false + +# Uncomment and set this to a directory to override where openCLTuner files +# and other cached data is written. By default it saves into a subdir of the +# current directory on windows, and a subdir of ~/.katago on Linux. +# homeDataDir = PATH_TO_DIRECTORY + +# =========================================================================== +# Analysis +# =========================================================================== +# This section configures analysis settings. +# +# The maximum number of moves after the first move displayed in variations +# from analysis commands like kata-analyze or lz-analyze. +# analysisPVLen = 15 + +# Report winrates for chat and analysis as (BLACK|WHITE|SIDETOMOVE). +# Most GUIs and analysis tools will expect SIDETOMOVE. +# reportAnalysisWinratesAs = SIDETOMOVE + +# Extra noise for wider exploration. Large values will force KataGo to +# analyze a greater variety of moves than it normally would. +# An extreme value like 1 distributes playouts across every move on the board, +# even very bad moves. +# Affects analysis only, does not affect play. +# analysisWideRootNoise = 0.04 + +# Try to limit the effect of possible bad or bogus move sequences in the +# history leading to this position from affecting KataGo's move predictions. +# analysisIgnorePreRootHistory = true + +# =========================================================================== +# Rules +# =========================================================================== +# This section configures the scoring and playing rules. Rules can also be +# changed mid-run by issuing custom GTP commands. +# +# See https://lightvector.github.io/KataGo/rules.html for rules details. +# +# See https://github.com/lightvector/KataGo/blob/master/docs/GTP_Extensions.md +# for GTP commands. + +# Specify the rules as a string. +# Some legal values include: +# chinese, japanese, korean, aga, chinese-ogs, new-zealand, stone-scoring, +# ancient-territory, bga, aga-button +# +# For some human rulesets that require complex adjudication in tricky cases +# (e.g. japanese, korean) KataGo may not precisely match the ruleset in such +# cases but will do its best. + +dots = true +rules = dots +defaultBoardXSize = 14 +defaultBoardYSize = 14 + +# By default, the "rules" parameter is used, but if you comment it out and +# uncomment one option in each of the sections below, you can specify an +# arbitrary combination of individual rules. + +# koRule = SIMPLE # Simple ko rules (triple ko = no result) +# koRule = POSITIONAL # Positional superko +# koRule = SITUATIONAL # Situational superko + +# scoringRule = AREA # Area scoring +# scoringRule = TERRITORY # Territory scoring (special computer-friendly territory rules) + +# taxRule = NONE # All surrounded empty points are scored +# taxRule = SEKI # Eyes in seki do NOT count as points +# taxRule = ALL # All groups are taxed up to 2 points for the two eyes needed to live + +# Is multiple-stone suicide legal? (Single-stone suicide is always illegal). +# multiStoneSuicideLegal = false +# multiStoneSuicideLegal = true + +# "Button go" - the first pass when area scoring awards 0.5 points and does +# not count for ending the game. +# Allows area scoring rulesets that have far simpler rules to achieve the same +# final scoring precision and reward for precise play as territory scoring. +# hasButton = false +# hasButton = true + +# Is this a human ruleset where it's okay to pass before having physically +# captured and removed all dead stones? +# friendlyPassOk = false +# friendlyPassOk = true + +# How handicap stones in handicap games are compensated +# whiteHandicapBonus = 0 # White gets no compensation for black's handicap stones (Tromp-taylor, NZ, JP) +# whiteHandicapBonus = N-1 # White gets N-1 points for black's N handicap stones (AGA) +# whiteHandicapBonus = N # White gets N points for black's N handicap stones (Chinese) + +# ------------------------------ +# Other rules hacks +# ------------------------------ +# Uncomment and change to adjust what board size KataGo uses upon startup +# by default when GTP doesn't specify. +# defaultBoardSize = 19 + +# By default, Katago will use the komi that the GUI or GTP controller tries to set. +# Uncomment and set this to have KataGo ignore the controller and always use this komi. +# ignoreGTPAndForceKomi = 7 + +# =========================================================================== +# Bot behavior +# =========================================================================== + +# ------------------------------ +# Resignation +# ------------------------------ + +# Resignation occurs if for at least resignConsecTurns in a row, the +# winLossUtility (on a [-1,1] scale) is below resignThreshold. +allowResignation = true +resignThreshold = -0.90 +resignConsecTurns = 3 + +# By default, KataGo may resign games that it is confidently losing even if they +# are very close in score. Uncomment and set this to avoid resigning games +# if the estimated difference is points is less than or equal to this. +# resignMinScoreDifference = 10 + +# Disallow resignation if turn number < resignMinMovesPerBoardArea * area of board. +# e.g 0.25 would prohibit resignation on 19x19 until after turn 361 * 0.25 ~= 90. +# resignMinMovesPerBoardArea = 0.00 + +# ------------------------------ +# Handicap +# ------------------------------ +# Assume that if black makes many moves in a row right at the start of the +# game, then the game is a handicap game. This is necessary on some servers +# and for some GUIs and also when initializing from many SGF files, which may +# set up a handicap game using repeated GTP "play" commands for black rather +# than GTP "place_free_handicap" commands; however, it may also lead to +# incorrect understanding of komi if whiteHandicapBonus is used and a server +# does not have such a practice. Uncomment and set to false to disable. +# assumeMultipleStartingBlackMovesAreHandicap = true + +# Makes katago dynamically adjust in handicap or altered-komi games to assume +# based on those game settings that it must be stronger or weaker than the +# opponent and to play accordingly. Greatly improves handicap strength by +# biasing winrates and scores to favor appropriate safe/aggressive play. +# Does NOT affect analysis (lz-analyze, kata-analyze, used by programs like +# Lizzie) so analysis remains unbiased. Uncomment and set this to 0 to disable +# this and make KataGo play the same always. +# dynamicPlayoutDoublingAdvantageCapPerOppLead = 0.045 + +# Instead of "dynamicPlayoutDoublingAdvantageCapPerOppLead", you can comment +# that out and uncomment and set "playoutDoublingAdvantage" to a fixed value +# from -3.0 to 3.0 that will not change dynamically. +# ALSO affects analysis tools (lz-analyze, kata-analyze, used by e.g. Lizzie). +# Negative makes KataGo behave as if it is much weaker than the opponent. +# Positive makes KataGo behave as if it is much stronger than the opponent. +# KataGo will adjust to favor safe/aggressive play as appropriate based on +# the combination of who is ahead and how much stronger/weaker it thinks it is, +# and report winrates and scores taking the strength difference into account. +# +# If this and "dynamicPlayoutDoublingAdvantageCapPerOppLead" are both set +# then dynamic will be used for all games and this fixed value will be used +# for analysis tools. +# playoutDoublingAdvantage = 0.0 + +# Uncomment one of these when using "playoutDoublingAdvantage" to enforce +# that it will only apply when KataGo plays as the specified color and will be +# negated when playing as the opposite color. +# playoutDoublingAdvantagePla = BLACK +# playoutDoublingAdvantagePla = WHITE + +# ------------------------------ +# Passing and cleanup +# ------------------------------ +# Make the bot never assume that its pass will end the game, even if passing +# would end and "win" under Tromp-Taylor rules. Usually this is a good idea +# when using it for analysis or playing on servers where scoring may be +# implemented non-tromp-taylorly. Uncomment and set to false to disable. +# conservativePass = true + +# When using territory scoring, self-play games continue beyond two passes +# with special cleanup rules that may be confusing for human players. This +# option prevents the special cleanup phases from being reachable when using +# the bot for GTP play. Uncomment and set to false to enable entering special +# cleanup. For example, if you are testing it against itself, or against +# another bot that has precisely implemented the rules documented at +# https://lightvector.github.io/KataGo/rules.html +# preventCleanupPhase = true + +# ------------------------------ +# Miscellaneous behavior +# ------------------------------ +# If the board is symmetric, search only one copy of each equivalent move. +# Attempts to also account for ko/superko, will not theoretically perfect for +# superko. Uncomment and set to false to disable. +# rootSymmetryPruning = true + +# Uncomment and set to true to avoid a particular joseki that some networks +# misevaluate, and also to improve opening diversity versus some particular +# other bots that like to play it all the time. +# avoidMYTDaggerHack = false + +# Prefer to avoid playing the same joseki in every corner of the board. +# Uncomment to set to a specific value. See "Avoid SGF patterns" section. +# By default: 0 (even games), 0.005 (handicap games) +# avoidRepeatedPatternUtility = 0.0 + +# Experimental logic to fight against mirror Go even with unfavorable komi. +# Uncomment to set to a specific value to use for both playing and analysis. +# By default: true when playing via GTP, but false when analyzing. +# antiMirror = true + +# Enable some hacks that mitigate rare instances when passing messes up deeper searches. +# enablePassingHacks = true + +# Uncomment and set this to true to prevent bad or bogus move sequences +# in the history leading to this position from affecting KataGo's move choices. +# Same as analysisIgnorePreRootHistory (see above) but applies to actual play. +# You can enable this if KataGo is being asked to play from positions that it did not +# choose the moves to reach. +# ignorePreRootHistory = false + +# =========================================================================== +# Search limits +# =========================================================================== + +# Terminology: +# "Playouts" is the number of new playouts of search performed each turn. +# "Visits" is the same as "Playouts" but also counts search performed on +# previous turns that is still applicable to this turn. +# "Time" is the time in seconds. + +# For example, if KataGo searched 200 nodes on the previous turn, and then +# after the opponent's reply, 50 nodes of its search tree was still valid, +# then a visit limit of 200 would allow KataGo to search 150 new nodes +# (for a final tree size of 200 nodes), whereas a playout limit of of 200 +# would allow KataGo to search 200 nodes (for a final tree size of 250 nodes). + +# Additionally, KataGo may also move before than the limit in order to +# obey time controls (e.g. byo-yomi, etc) if the GTP controller has +# told KataGo that the game has is being played with a given time control. + +# Limits for search on the current turn. +# If commented out or unspecified, the default is to have no limit. +maxVisits = 500 +# maxPlayouts = 300 +# maxTime = 10.0 + +# Ponder on the opponent's turn? +ponderingEnabled = false + +# Limits for search when pondering on the opponent's turn. +# If commented out or unspecified, the default is to have no limit. +# Limiting the maximum time is recommended so that KataGo won't burn CPU/GPU +# forever and/or run out of RAM if left unattended while pondering is enabled. +# maxVisitsPondering = 5000 +# maxPlayoutsPondering = 3000 +maxTimePondering = 60.0 + + +# ------------------------------ +# Other search limits and behavior +# ------------------------------ + +# Approx number of seconds to buffer for lag for GTP time controls - will +# move a bit faster assuming there is this much lag per move. +lagBuffer = 1.0 + +# YOU PROBABLY WANT TO TUNE THIS PARAMETER! +# The number of threads to use when searching. On powerful GPUs the optimal +# threads may be much higher than the number of CPU cores you have because +# many threads are needed to feed efficient large batches to the GPU. +# +# Run "./katago benchmark" to tune this parameter and test the effect +# of changes to any of other parameters. +numSearchThreads = 6 + +# Play a little faster if the opponent is passing, for human-friendliness. +# Comment these out to disable them, such as if running a controlled match +# where you are testing KataGo with fixed compute per move vs other bots. +searchFactorAfterOnePass = 0.50 +searchFactorAfterTwoPass = 0.25 + +# Play a little faster if super-winning, for human-friendliness. +# Comment these out to disable them, such as if running a controlled match +# where you are testing KataGo with fixed compute per move vs other bots. +searchFactorWhenWinning = 0.40 +searchFactorWhenWinningThreshold = 0.95 + +# =========================================================================== +# GPU settings +# =========================================================================== +# This section configures GPU settings. +# +# Maximum number of positions to send to a single GPU at once. The default +# value is roughly equal to numSearchThreads, but can be specified manually +# if running out of memory, or using multiple GPUs that expect to share work. +# nnMaxBatchSize = + +# Controls the neural network cache size, which is the primary RAM/memory use. +# KataGo will cache up to (2 ** nnCacheSizePowerOfTwo) many neural net +# evaluations in case of transpositions in the tree. +# Increase this to improve performance for searches with tens of thousands +# of visits or more. Decrease this to limit memory usage. +# If you're happy to do some math - each neural net entry takes roughly +# 1.5KB, except when using whole-board ownership/territory +# visualizations, where each entry will take roughly 3KB. The number of +# entries is (2 ** nnCacheSizePowerOfTwo). (E.g. 2 ** 18 = 262144.) +# You can compute roughly how much memory the cache will use based on this. +# nnCacheSizePowerOfTwo = 20 + +# Size of mutex pool for nnCache is (2 ** this). +# nnMutexPoolSizePowerOfTwo = 16 + +# Randomize board orientation when running neural net evals? Uncomment and +# set to false to disable. +# nnRandomize = true + +# If provided, force usage of a specific seed for nnRandomize. +# The default is to use a randomly generated seed. +# nnRandSeed = abcdefg + +# Uncomment and set to true to force GTP to use the maximum board size for +# internal buffers for the neural net. This will make KataGo slower when +# evaluating small boards, but will avoid a lengthy initialization time on every +# change of board size due to having to re-size the neural net buffers on the GPU. +# This can be useful for example, for OGS's persistent bot mode that uses a single +# bot instance to handle multiple games and may thrash between different board sizes +# if there are concurrent games of multiple sizes. +# gtpForceMaxNNSize = false + +# ------------------------------ +# Multiple GPUs +# ------------------------------ +# Set this to the number of GPUs to use or that are available. +# IMPORTANT: If more than 1, also uncomment the appropriate TensorRT +# or CUDA or OpenCL section. +# numNNServerThreadsPerModel = 1 + +# ------------------------------ +# TENSORRT GPU settings +# ------------------------------ +# These only apply when using the TENSORRT version of KataGo. + +# For one GPU: optionally uncomment this option and change if the GPU to +# use is not device 0. +# trtDeviceToUse = 0 + +# For two GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1. +# trtDeviceToUseThread0 = 0 +# trtDeviceToUseThread1 = 1 + +# For three GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1 and 2. +# trtDeviceToUseThread0 = 0 +# trtDeviceToUseThread1 = 1 +# trtDeviceToUseThread2 = 2 + +# The pattern continues for additional GPUs. + +# ------------------------------ +# CUDA GPU settings +# ------------------------------ +# These only apply when using the CUDA version of KataGo. + +# For one GPU: optionally uncomment and change this if the GPU you want to +# use is not device 0 +# cudaDeviceToUse = 0 + +# For two GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1. +# cudaDeviceToUseThread0 = 0 +# cudaDeviceToUseThread1 = 1 + +# For three GPUs: Uncomment these options, AND set numNNServerThreadsPerModel above. +# Also, change their values if the devices you want to use are not 0 and 1 and 2. +# cudaDeviceToUseThread0 = 0 +# cudaDeviceToUseThread1 = 1 +# cudaDeviceToUseThread2 = 2 + +# The pattern continues for additional GPUs. + +# KataGo will automatically use FP16 or not based on the compute capability +# of your NVIDIA GPU. If you want to try to force a particular behavior +# you can uncomment these lines and change them to "true" or "false". +# cudaUseFP16 = auto +# cudaUseNHWC = auto + +# ------------------------------ +# Metal GPU settings +# ------------------------------ +# These only apply when using the METAL version of KataGo. + +# For one Metal instance: KataGo will automatically use the default device. +# metalDeviceToUse = 0 + +# For two Metal instance: Uncomment these options, AND set numNNServerThreadsPerModel = 2 above. +# This will create two Metal instances, best overlapping the GPU and CPU execution. +# metalDeviceToUseThread0 = 0 +# metalDeviceToUseThread1 = 1 + +# The pattern continues for additional Metal instances. + +# ------------------------------ +# OpenCL GPU settings +# ------------------------------ +# These only apply when using the OpenCL version of KataGo. + +# Uncomment and set to true to tune OpenCL for every board size separately, +# rather than only the largest possible size. +# openclReTunePerBoardSize = false + +# For one GPU: optionally uncomment and change this if the best device to use is guessed incorrectly. +# The default behavior tries to guess the 'best' GPU or device on your system to use, usually it will be a good guess. +# openclDeviceToUse = 0 + +# For two GPUs: Uncomment these two lines and replace X and Y with the device ids of the devices you want to use. +# It might NOT be 0 and 1, some computers will have many OpenCL devices. You can see what the devices are when +# KataGo starts up - it should print or log all the devices it finds. +# (AND also set numNNServerThreadsPerModel above) +# openclDeviceToUseThread0 = X +# openclDeviceToUseThread1 = Y + +# For three GPUs: Uncomment these three lines and replace X and Y and Z with the device ids of the devices you want to use. +# It might NOT be 0 and 1 and 2, some computers will have many OpenCL devices. You can see what the devices are when +# KataGo starts up - it should print or log all the devices it finds. +# (AND also set numNNServerThreadsPerModel above) +# openclDeviceToUseThread0 = X +# openclDeviceToUseThread1 = Y +# openclDeviceToUseThread2 = Z + +# The pattern continues for additional GPUs. + +# KataGo will automatically use FP16 or not based on testing your GPU during +# tuning. If you want to try to force a particular behavior though you can +# uncomment this option and change it to "true" or "false". This is a fairly +# blunt setting - more detailed settings are testable by rerunning the tuner +# with various arguments (./katago tuner). +# openclUseFP16 = auto + +# ------------------------------ +# Eigen-specific settings +# ------------------------------ +# These only apply when using the Eigen (pure CPU) version of KataGo. + +# Number of CPU threads for evaluating the neural net on the Eigen backend. +# +# Default: numSearchThreads +# numEigenThreadsPerModel = X + +# =========================================================================== +# Root move selection and biases +# =========================================================================== +# Uncomment and edit any of the below values to change them from their default. + +# If provided, force usage of a specific seed for various random things in +# the search. The default is to use a random seed. +# searchRandSeed = hijklmn + +# Temperature for the early game, randomize between chosen moves with +# this temperature +# chosenMoveTemperatureEarly = 0.5 + +# Decay temperature for the early game by 0.5 every this many moves, +# scaled with board size. +# chosenMoveTemperatureHalflife = 19 + +# At the end of search after the early game, randomize between chosen +# moves with this temperature +# chosenMoveTemperature = 0.10 + +# Subtract this many visits from each move prior to applying +# chosenMoveTemperature (unless all moves have too few visits) to downweight +# unlikely moves +# chosenMoveSubtract = 0 + +# The same as chosenMoveSubtract but only prunes moves that fall below +# the threshold. This setting does not affect chosenMoveSubtract. +# chosenMovePrune = 1 + +# Number of symmetries to sample (without replacement) and average at the root +# rootNumSymmetriesToSample = 1 + +# Using LCB for move selection? +# useLcbForSelection = true + +# How many stdevs a move needs to be better than another for LCB selection +# lcbStdevs = 5.0 + +# Only use LCB override when a move has this proportion of visits as the +# top move. +# minVisitPropForLCB = 0.15 + +# =========================================================================== +# Internal params +# =========================================================================== +# Uncomment and edit any of the below values to change them from their default. + +# Scales the utility of winning/losing +# winLossUtilityFactor = 1.0 + +# Scales the utility for trying to maximize score +# staticScoreUtilityFactor = 0.10 +# dynamicScoreUtilityFactor = 0.30 + +# Adjust dynamic score center this proportion of the way towards zero, +# capped at a reasonable amount. +# dynamicScoreCenterZeroWeight = 0.20 +# dynamicScoreCenterScale = 0.75 + +# The utility of getting a "no result" due to triple ko or other long cycle +# in non-superko rulesets (-1 to 1) +# noResultUtilityForWhite = 0.0 + +# The number of wins that a draw counts as, for white. (0 to 1) +# drawEquivalentWinsForWhite = 0.5 + +# Exploration constant for mcts +# cpuctExploration = 1.0 +# cpuctExplorationLog = 0.45 + +# Parameters that control exploring more in volatile positions, exploring +# less in stable positions. +# cpuctUtilityStdevPrior = 0.40 +# cpuctUtilityStdevPriorWeight = 2.0 +# cpuctUtilityStdevScale = 0.85 + +# FPU reduction constant for mcts +# fpuReductionMax = 0.2 +# rootFpuReductionMax = 0.1 +# fpuParentWeightByVisitedPolicy = true + +# Parameters that control weighting of evals based on the net's own +# self-reported uncertainty. +# useUncertainty = true +# uncertaintyExponent = 1.0 +# uncertaintyCoeff = 0.25 + +# Explore using optimistic policy +# rootPolicyOptimism = 0.2 +# policyOptimism = 1.0 + +# Amount to apply a downweighting of children with very bad values relative +# to good ones. +# valueWeightExponent = 0.25 + +# Slight incentive for the bot to behave human-like with regard to passing at +# the end, filling the dame, not wasting time playing in its own territory, +# etc., and not play moves that are equivalent in terms of points but a bit +# more unfriendly to humans. +# rootEndingBonusPoints = 0.5 + +# Make the bot prune useless moves that are just prolonging the game to +# avoid losing yet. +# rootPruneUselessMoves = true + +# Apply bias correction based on local pattern keys +# subtreeValueBiasFactor = 0.45 +# subtreeValueBiasWeightExponent = 0.85 + +# Use graph search rather than tree search - identify and share search for +# transpositions. +# useGraphSearch = true + +# How much to shard the node table for search synchronization +# nodeTableShardsPowerOfTwo = 16 + +# How many virtual losses to add when a thread descends through a node +# numVirtualLossesPerThread = 1 + +# Improve the quality of evals under heavy multithreading +# useNoisePruning = true + +# =========================================================================== +# Automatic avoid patterns +# =========================================================================== +# The parameters in this section provide a way to bias away from moves that match +# patterns that this instance of KataGo has played in previous games, by auto-saving +# moves to a directory and then auto-loading and biasing against them each new game. +# Uncomment them to use them. When using this feature, all parameters must be specified. + +# Different board sizes are tracked separately, but all board sizes share the same +# parameters by default. Every parameter *except* for autoAvoidRepeatDir and +# autoAvoidRepeatSaveChunkSize can be overridden per board size, +# e.g. "autoAvoidRepeatMinTurnNumber13x13". You must ALSO specify the +# defaults even if you have specified board-size-specific values. + +# Directory to auto-save moves KataGo plays, to avoid them in future games. +# You can create a new empty directory and put its path here. +# If you run parallel instances of KataGo, use different directories if you +# want them not to share their biases, use the same directory if you want +# all of them to bias away from past moves that any of them have played. +# KataGo will also automatically DELETE old data in this directory, so it is +# recommended that if you do share the same directory between parallel instances, +# that they all use the same settings that affect data saving/deletion. +# autoAvoidRepeatDir = PATH_TO_NEW_DIRECTORY + +# Penalize this much utility per matching move. +# Values that are too large may lead to bad play. The value of 0.004 is fairly large +# and might be large enough to result in some early weird/bad moves when trying +# to avoid past games' moves if enough games begin the same way. You can experiment with it. +# autoAvoidRepeatUtility = 0.004 + +# Per each new move saved, exponentially decay prior saved moves by this factor. +# This way, the bias against moves from many games ago is gradually phased out. +# For example, 0.9995 = 1 - 1/2000, so would be roughly 2000 prior moves worth of +# penalty weight remembered, in steady state. Depending on what turn number range +# you are saving, this might equate to a different number of games For example +# saving the first 50 moves per game would make this roughly 2000 / 50 = 40 games +# worth of memory. +# autoAvoidRepeatLambda = 0.9995 + +# Affects data saving/deletion. +# When the number of saved moves exceeds this, outright delete them to avoid too many +# files and disk space building up. Also may affect the speed of saving/loading on start +# of each game if this is set large and a lot of data builds up. +# autoAvoidRepeatMaxPoses = 10000 + +# Affects data saving/deletion. +# Only save data for moves within this turn number range of those games. +# E.g. setting autoAvoidRepeatMinTurnNumber to a number like 4 or 5 would tend to make +# KataGo not develop a bias against the initial 3-4 and 4-4 corner moves in almost every game. +# autoAvoidRepeatMinTurnNumber = 0 +# autoAvoidRepeatMaxTurnNumber = 50 + +# Affects data saving/deletion. +# Within a single run of a program, wait to accumulate this many samples +# (possibly across multiple clear_boards/games) before saving the data. +# Can help to avoid writing too many small files to disk, especially when GTP is used +# in a way that clears the board very frequently (e.g. gtp2ogs pooled manager). +# autoAvoidRepeatSaveChunkSize = 200 + +# =========================================================================== +# Avoid SGF patterns +# =========================================================================== +# The parameters in this section provide a way to avoid moves that follow +# specific patterns based on a set of SGF files loaded upon startup. +# This is basically the same as the above "Automatic avoid patterns" section +# above except you supply your own SGF files to avoid moves from. +# Uncomment them to use this feature. Additionally, if the SGF file +# contains the string %SKIP% in a comment on a move, that move will be +# ignored for this purpose. + +# Load SGF files from this directory when the engine is started +# (only on startup, will not reload unless engine is restarted) +# avoidSgfPatternDirs = path/to/directory/with/sgfs/ +# You can also surround the file path in double quotes if the file path contains trailing spaces or hash signs. +# Within double quotes, backslashes are escape characters. +# avoidSgfPatternDirs = "path/to/directory/with/sgfs/" + +# Penalize this much utility per matching move. +# Set this negative if you instead want to favor SGF patterns instead of +# penalizing them. This number does not need to be large, even 0.001 will +# make a difference. Values that are too large may lead to bad play. +# avoidSgfPatternUtility = 0.001 + +# Optional - load only the newest this many files +# avoidSgfPatternMaxFiles = 20 + +# Optional - Penalty is multiplied by this per each older SGF file, so that +# old SGF files matter less than newer ones. +# avoidSgfPatternLambda = 0.90 + +# Optional - pay attention only to moves made by players with this name. +# For example, set it to the name that your bot's past games will show up +# as in the SGF, so that the bot will only avoid repeating moves that itself +# made in past games, not the moves that its opponents made. +# avoidSgfPatternAllowedNames = my-ogs-bot-name1,my-ogs-bot-name2 + +# Optional - Ignore moves in SGF files that occurred before this turn number. +# avoidSgfPatternMinTurnNumber = 0 + +# For more avoid patterns: +# You can also specify a second set of parameters, and a third, fourth, +# etc. by numbering 2,3,4,... +# +# avoidSgf2PatternDirs = ... +# avoidSgf2PatternUtility = ... +# avoidSgf2PatternMaxFiles = ... +# avoidSgf2PatternLambda = ... +# avoidSgf2PatternAllowedNames = ... +# avoidSgf2PatternMinTurnNumber = ... + diff --git a/cpp/configs/training/gatekeeper1_dots.cfg b/cpp/configs/training/gatekeeper1_dots.cfg new file mode 100644 index 000000000..dc9eee536 --- /dev/null +++ b/cpp/configs/training/gatekeeper1_dots.cfg @@ -0,0 +1,117 @@ + +# Logs------------------------------------------------------------------------------------ + +logSearchInfo = false +logMoves = false +logGamesEvery = 10 +logToStdout = true + +# Fancy game selfplay settings-------------------------------------------------------------------- + +# startPosesProb = 0.0 # Play this proportion of games starting from SGF positions +# startPosesFromSgfDir = DIRECTORYPATH # Load SGFs from this dir +# startPosesLoadProb = 1.0 # Only load each position from each SGF with this chance (save memory) +# startPosesTurnWeightLambda = 0 # 0 = equal weight 0.01 = decrease probability by 1% per turn -0.01 = increase probability by 1% per turn. + +# Match----------------------------------------------------------------------------------- + +numGameThreads = 16 +maxMovesPerGame = 1600 +numGamesPerGating = 100 + +allowResignation = true +resignThreshold = -0.90 +resignConsecTurns = 5 + +# Disabled, since we're not using any root noise and such +# Could have a slight weirdness on rootEndingBonusPoints, but shouldn't be a big deal. +# clearBotBeforeSearch = true + +# Rules------------------------------------------------------------------------------------ + +dots = true +multiStoneSuicideLegals = false,true + +#bSizesXY=20-20,36-36,39-32,24-24,30-30 +#bSizeRelProbs=10,2,1 + +bSizesXY=10-10,12-12,14-14 +bSizeRelProbs=1,1,2 + +startPoses=CROSS,CROSS,CROSS,CROSS,CROSS,CROSS,CROSS_2 +startPosesAreRandom=false,true +dotsCaptureEmptyBases=false,false,false,false,false,true +#dotsFreeCapturedDots=true,true,true,false + +#komiAuto = True # Automatically adjust komi to what the neural nets think are fair +komiMean = 0.0 # Specify explicit komi +komiStdev = 0.5 # Standard deviation of random variation to komi. +handicapProb = 0.0 # Probability of handicap game +handicapCompensateKomiProb = 0.0 # Probability of compensating komi to fair during handicap game +# numExtraBlackFixed = 3 # When playing handicap games, always use exactly this many extra black moves + +# Search limits----------------------------------------------------------------------------------- +maxVisits = 150 +numSearchThreads = 16 + +# GPU Settings------------------------------------------------------------------------------- + +nnMaxBatchSize = 128 +nnCacheSizePowerOfTwo = 21 +nnMutexPoolSizePowerOfTwo = 15 +numNNServerThreadsPerModel = 1 +nnRandomize = true + +# CUDA GPU settings-------------------------------------- +# cudaDeviceToUse = 0 #use device 0 for all server threads (numNNServerThreadsPerModel) unless otherwise specified per-model or per-thread-per-model +# cudaDeviceToUseModel0 = 3 #use device 3 for model 0 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel1 = 2 #use device 2 for model 1 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel0Thread0 = 3 #use device 3 for model 0, server thread 0 +# cudaDeviceToUseModel0Thread1 = 2 #use device 2 for model 0, server thread 1 + +cudaUseFP16 = auto +cudaUseNHWC = auto + +# Root move selection and biases------------------------------------------------------------------------------ + +chosenMoveTemperatureEarly = 0.5 +chosenMoveTemperatureHalflife = 19 +chosenMoveTemperature = 0.2 +chosenMoveSubtract = 0 +chosenMovePrune = 1 + +useLcbForSelection = true +lcbStdevs = 5.0 +minVisitPropForLCB = 0.15 + +# Internal params------------------------------------------------------------------------------ + +winLossUtilityFactor = 1.0 +staticScoreUtilityFactor = 0.00 +dynamicScoreUtilityFactor = 0.25 +dynamicScoreCenterZeroWeight = 0.25 +dynamicScoreCenterScale = 0.50 +noResultUtilityForWhite = 0.0 +drawEquivalentWinsForWhite = 0.5 + +rootEndingBonusPoints = 0.5 +rootPruneUselessMoves = true + +cpuctExploration = 1.1 +cpuctExplorationLog = 0.0 +fpuReductionMax = 0.2 +rootFpuReductionMax = 0.1 +valueWeightExponent = 0.5 +policyOptimism = 1.0 +rootPolicyOptimism = 0.0 + +subtreeValueBiasFactor = 0.35 +subtreeValueBiasWeightExponent = 0.8 +useNonBuggyLcb = true +useGraphSearch = true + +useUncertainty = true +uncertaintyExponent = 1.0 +uncertaintyCoeff = 0.25 + +numVirtualLossesPerThread = 1 diff --git a/cpp/configs/training/selfplay1_dots.cfg b/cpp/configs/training/selfplay1_dots.cfg new file mode 100644 index 000000000..77f535726 --- /dev/null +++ b/cpp/configs/training/selfplay1_dots.cfg @@ -0,0 +1,196 @@ + +# Logs------------------------------------------------------------------------------------ + +logSearchInfo = false +logMoves = false +logGamesEvery = 1 +logToStdout = true + +# Data writing----------------------------------------------------------------------------------- + +# Spatial size of tensors of data, must match -pos-len in python/train.py and be at least as large as the +# largest boardsize in the data. If you train only on smaller board sizes, you can decrease this and +# pass in a smaller -pos-len to python/train.py (e.g. modify "-pos-len 19" in selfplay/train.sh script), +# to train faster, although it may be awkward if you ever want to increase it again later since data with +# different values cannot be shuffled together. +dataBoardLen = 39 +dataBoardLenY = 39 # TODO: fix handling of rectangular size + +maxDataQueueSize = 2000 +maxRowsPerTrainFile = 10000 +firstFileRandMinProp = 0.15 + +# Fancy game selfplay settings-------------------------------------------------------------------- + +# These take dominance - if a game is forked for any of these reasons, that will be the next game played. +# For forked games, randomization of rules and "initGamesWithPolicy" is disabled, komi is made fair with prob "forkCompensateKomiProb". +earlyForkGameProb = 0.04 # Fork to try alternative opening variety with this probability +earlyForkGameExpectedMoveProp = 0.025 # Fork roughly within the first (boardArea * this) many moves +forkGameProb = 0.01 # Fork to try alternative crazy move anywhere in game with this probability if not early forking +forkGameMinChoices = 3 # Choose from the best of at least this many random choices +earlyForkGameMaxChoices = 12 # Choose from the best of at most this many random choices +forkGameMaxChoices = 36 # Choose from the best of at most this many random choices +# Otherwise, with this probability, learn a bit more about the differing evaluation of seki in different rulesets. +sekiForkHackProb = 0.02 + +# Otherwise, play some proportion of games starting from SGF positions, with randomized rules (ignoring the sgf rules) +# On SGF positions, high temperature policy init is allowed +# startPosesProb = 0.0 # Play this proportion of games starting from SGF positions +# startPosesFromSgfDir = DIRECTORYPATH # Load SGFs from this dir +# startPosesLoadProb = 1.0 # Only load each position from each SGF with this chance (save memory) +# startPosesTurnWeightLambda = 0 # 0 = equal weight 0.01 = decrease probability by 1% per turn -0.01 = increase probability by 1% per turn. +# startPosesPolicyInitAreaProp = 0.0 # Same as policyInitAreaProp but for SGF positions + +# Otherwise, play some proportion of games starting from hint positions (generated using "dataminesgfs" command), with randomized rules. +# On hint positions, "initGamesWithPolicy" does not apply. +# hintPosesProb = 0.0 +# hintPosesDir = DIRECTORYPATH + +# Otherwise we are playing a "normal" game, potentially with handicap stones, depending on "handicapProb", and +# potentially with komi randomization, depending on things like "komiStdev", and potentially with different +# board sizes, etc. + +# Most of the remaining parameters here below apply regardless of the initialization, although a few of them +# vary depending on handicap vs normal game, and some are explicitly disabled (e.g. initGamesWithPolicy on hint positions). + +initGamesWithPolicy = true # Play the first few moves of a game high-temperaturely from policy +policyInitAreaProp = 0.04 # The avg number of moves to play +compensateAfterPolicyInitProb = 0.2 # Additionally make komi fair this often after the high-temperature moves. +sidePositionProb = 0.020 # With this probability, train on refuting bad alternative moves. + +cheapSearchProb = 0.75 # Do cheap searches with this probaiblity +cheapSearchVisits = 100 # Number of visits for cheap search +cheapSearchTargetWeight = 0.0 # Training weight for cheap search + +reduceVisits = true # Reduce visits when one side is winning +reduceVisitsThreshold = 0.9 # How winning a side needs to be (winrate) +reduceVisitsThresholdLookback = 3 # How many consecutive turns needed to be that winning +reducedVisitsMin = 100 # Minimum number of visits (never reduce below this) +reducedVisitsWeight = 0.1 # Minimum training weight + +handicapAsymmetricPlayoutProb = 0.5 # In handicap games, play with unbalanced players with this probablity +normalAsymmetricPlayoutProb = 0.01 # In regular games, play with unbalanced players with this probability +maxAsymmetricRatio = 8.0 # Max ratio to unbalance visits between players +minAsymmetricCompensateKomiProb = 0.4 # Compensate komi with at least this probability for unbalanced players + +policySurpriseDataWeight = 0.5 # This proportion of training weight should be concentrated on surprising moves +valueSurpriseDataWeight = 0.1 # This proportion of training weight should be concentrated on surprising position results + +estimateLeadProb = 0.05 # Train lead, rather than just scoremean. Consumes a decent number of extra visits, can be quite slow using low visits to set too high. +switchNetsMidGame = true # When a new neural net is loaded, switch to it immediately instead of waiting for new game +# fancyKomiVarying = true # In non-compensated handicap and fork games, vary komi to better learn komi and large score differences that would never happen in even games. + +# Match----------------------------------------------------------------------------------- + +numGameThreads = 16 +maxMovesPerGame = 1600 + +# Rules------------------------------------------------------------------------------------ + +dots = true +multiStoneSuicideLegals = false,true + +#bSizesXY=20-20,36-36,39-32,24-24,30-30 +#bSizeRelProbs=10,2,1 + +#bSizesXY=20-20,16-16,10-10,24-24 +#bSizeRelProbs=3,2,2,1 + +#bSizesXY=8-8,10-10,9-9,8-10,10-9,12-12 +#bSizeRelProbs=4,4,2,1,1,1 + +bSizesXY=10-10,12-12,14-14 +bSizeRelProbs=1,2,1 + +startPoses=CROSS,CROSS,CROSS,CROSS_2,CROSS_4,CROSS_4 +startPosesAreRandom=false +dotsCaptureEmptyBases=false,false,false,false,false,false,false,true +#dotsFreeCapturedDots=true,true,true,false + +# komiAuto = True # Automatically adjust komi to what the neural nets think are fair based on the empty board, but still apply komiStdev. +komiMean = 0.0 # Specify explicit komi +komiStdev = 0.5 # Standard deviation of random variation to komi. +komiBigStdevProb = 0.0 # Probability of applying komiBigStdev +komiBigStdev = 0.0 # Standard deviation of random big variation to komi + +handicapProb = 0.1 # Probability of handicap game +handicapCompensateKomiProb = 0.00 # In handicap games, adjust komi to fair with this probability based on the handicap placement +forkCompensateKomiProb = 0.00 # For forks, adjust komi to fair with this probability based on the forked position +sgfCompensateKomiProb = 0.00 # For sgfs, adjust komi to fair with this probability based on the specific starting position + +drawRandRadius = 0.5 +noResultStdev = 0.166666666 + +# Search limits----------------------------------------------------------------------------------- + +maxVisits = 600 +numSearchThreads = 16 + +# GPU Settings------------------------------------------------------------------------------- + +nnMaxBatchSize = 128 +nnCacheSizePowerOfTwo = 21 +nnMutexPoolSizePowerOfTwo = 15 +numNNServerThreadsPerModel = 1 +nnRandomize = true + +# CUDA GPU settings-------------------------------------- +# cudaDeviceToUse = 0 #use device 0 for all server threads (numNNServerThreadsPerModel) unless otherwise specified per-model or per-thread-per-model +# cudaDeviceToUseModel0 = 3 #use device 3 for model 0 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel1 = 2 #use device 2 for model 1 for all threads unless otherwise specified per-thread for this model +# cudaDeviceToUseModel0Thread0 = 3 #use device 3 for model 0, server thread 0 +# cudaDeviceToUseModel0Thread1 = 2 #use device 2 for model 0, server thread 1 + +cudaUseFP16 = auto +cudaUseNHWC = auto + +# Root move selection and biases------------------------------------------------------------------------------ + +chosenMoveTemperatureEarly = 0.75 +chosenMoveTemperatureHalflife = 19 +chosenMoveTemperature = 0.15 +chosenMoveSubtract = 0 +chosenMovePrune = 1 + +rootNoiseEnabled = true +rootDirichletNoiseTotalConcentration = 10.83 +rootDirichletNoiseWeight = 0.25 + +rootDesiredPerChildVisitsCoeff = 2 +rootNumSymmetriesToSample = 4 + +useLcbForSelection = true +lcbStdevs = 5.0 +minVisitPropForLCB = 0.15 + +# Internal params------------------------------------------------------------------------------ + +winLossUtilityFactor = 1.0 +staticScoreUtilityFactor = 0.00 +dynamicScoreUtilityFactor = 0.40 +dynamicScoreCenterZeroWeight = 0.25 +dynamicScoreCenterScale = 0.50 +noResultUtilityForWhite = 0.0 +drawEquivalentWinsForWhite = 0.5 + +rootEndingBonusPoints = 0.5 +rootPruneUselessMoves = true + +rootPolicyTemperatureEarly = 1.25 +rootPolicyTemperature = 1.1 + +cpuctExploration = 1.1 +cpuctExplorationLog = 0.0 +fpuReductionMax = 0.2 +rootFpuReductionMax = 0.0 + +numVirtualLossesPerThread = 1 + +# These parameters didn't exist historically during early KataGo runs +valueWeightExponent = 0.5 +subtreeValueBiasFactor = 0.30 +subtreeValueBiasWeightExponent = 0.8 +useNonBuggyLcb = true +useGraphSearch = true +fpuParentWeightByVisitedPolicy = true +fpuParentWeightByVisitedPolicyPow = 2.0 diff --git a/python/selfplay/synchronous_loop.sh b/python/selfplay/synchronous_loop.sh index 255862aba..a894deab9 100755 --- a/python/selfplay/synchronous_loop.sh +++ b/python/selfplay/synchronous_loop.sh @@ -54,21 +54,21 @@ mkdir -p "$BASEDIR"/gatekeepersgf # you have strong hardware or are later into a run you may want to reduce the overhead by scaling # these numbers up and doing more games and training per cycle, exporting models less frequently, etc. -NUM_GAMES_PER_CYCLE=500 # Every cycle, play this many games -NUM_THREADS_FOR_SHUFFLING=8 -NUM_TRAIN_SAMPLES_PER_EPOCH=100000 # Training will proceed in chunks of this many rows, subject to MAX_TRAIN_PER_DATA. +NUM_GAMES_PER_CYCLE=400 # Every cycle, play this many games +NUM_THREADS_FOR_SHUFFLING=16 +NUM_TRAIN_SAMPLES_PER_EPOCH=50000 # Training will proceed in chunks of this many rows, subject to MAX_TRAIN_PER_DATA. MAX_TRAIN_PER_DATA=8 # On average, train only this many times on each data row. Larger numbers may cause overfitting. NUM_TRAIN_SAMPLES_PER_SWA=80000 # Stochastic weight averaging frequency. BATCHSIZE=128 # For lower-end GPUs 64 or smaller may be needed to avoid running out of GPU memory. -SHUFFLE_MINROWS=100000 # Require this many rows at the very start before beginning training. +SHUFFLE_MINROWS=50000 # Require this many rows at the very start before beginning training. MAX_TRAIN_SAMPLES_PER_CYCLE=500000 # Each cycle will do at most this many training steps. -TAPER_WINDOW_SCALE=50000 # Parameter setting the scale at which the shuffler will make the training window grow sublinearly. +TAPER_WINDOW_SCALE=25000 # Parameter setting the scale at which the shuffler will make the training window grow sublinearly. SHUFFLE_KEEPROWS=600000 # Needs to be larger than MAX_TRAIN_SAMPLES_PER_CYCLE, so the shuffler samples enough rows each cycle for the training to use. # Paths to the selfplay and gatekeeper configs that contain board sizes, rules, search parameters, etc. # See cpp/configs/training/README.md for some notes on other selfplay configs. -SELFPLAY_CONFIG="$GITROOTDIR"/cpp/configs/training/selfplay1.cfg -GATING_CONFIG="$GITROOTDIR"/cpp/configs/training/gatekeeper1.cfg +SELFPLAY_CONFIG="$GITROOTDIR"/cpp/configs/training/selfplay1_dots.cfg +GATING_CONFIG="$GITROOTDIR"/cpp/configs/training/gatekeeper1_dots.cfg # Copy all the relevant scripts and configs and the katago executable to a dated directory. # For archival and logging purposes - you can look back and see exactly the python code on a particular date diff --git a/python/selfplay/train.sh b/python/selfplay/train.sh index 40bf72279..7e8da268e 100755 --- a/python/selfplay/train.sh +++ b/python/selfplay/train.sh @@ -77,7 +77,7 @@ time python3 ./train.py \ -latestdatadir "$BASEDIR"/shuffleddata/ \ -exportdir "$BASEDIR"/"$EXPORT_SUBDIR" \ -exportprefix "$TRAININGNAME" \ - -pos-len 19 \ + -pos-len 39 \ -batch-size "$BATCHSIZE" \ -model-kind "$MODELKIND" \ $EXTRAFLAG \ From 7c923c13ab4f3fc876b4b5e700ee2b2f403175aa Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 7 Nov 2025 14:16:21 +0100 Subject: [PATCH 35/42] Support for `RESIGN_LOC` The main reason is to support it in GTP protocol Don't generate the resignation move in utils because normally it makes little sense --- cpp/command/genbook.cpp | 2 +- cpp/dataio/sgf.cpp | 5 ++++- cpp/game/board.cpp | 37 ++++++++++++++++++++++---------- cpp/game/board.h | 6 ++++-- cpp/game/boardhistory.cpp | 23 +++++++++++++------- cpp/game/common.h | 2 ++ cpp/game/dotsfield.cpp | 10 +++++++-- cpp/neuralnet/nninputs.cpp | 6 +++--- cpp/program/playutils.cpp | 4 ++-- cpp/search/localpattern.cpp | 4 ++-- cpp/search/patternbonustable.cpp | 4 ++-- cpp/tests/testboardbasic.cpp | 37 +++++++++++++++++++++----------- 12 files changed, 93 insertions(+), 47 deletions(-) diff --git a/cpp/command/genbook.cpp b/cpp/command/genbook.cpp index d49b75faa..5d18e0fe7 100644 --- a/cpp/command/genbook.cpp +++ b/cpp/command/genbook.cpp @@ -570,7 +570,7 @@ int MainCmds::genbook(const vector& args) { Board board = hist.getRecentBoard(0); bool hasAtLeastOneLegalNewMove = false; for(Loc moveLoc = 0; moveLoc < Board::MAX_ARR_SIZE; moveLoc++) { - if(hist.isLegal(board,moveLoc,pla)) { + if(moveLoc != Board::RESIGN_LOC && hist.isLegal(board,moveLoc,pla)) { if(!isReExpansion && constNode.isMoveInBook(moveLoc)) avoidMoveUntilByLoc[moveLoc] = 1; else diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 4da0af4c4..7216d4696 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -127,8 +127,11 @@ static Loc parseSgfLocOrPass(const string& s, int xSize, int ySize) { static void writeSgfLoc(ostream& out, Loc loc, int xSize, int ySize) { if(xSize >= 53 || ySize >= 53) throw StringError("Writing coordinates for SGF files for board sizes >= 53 is not implemented"); - if(loc == Board::PASS_LOC || loc == Board::NULL_LOC) + if (loc == Board::PASS_LOC || loc == Board::NULL_LOC) return; + if (loc == Board::RESIGN_LOC) { + out << RESIGN_STR; + } int x = Location::getX(loc,xSize); int y = Location::getY(loc,xSize); const char* chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index eac1bfbd7..55ba016b7 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -72,7 +72,7 @@ bool Location::isAdjacent(Loc loc0, Loc loc1, int x_size) } Loc Location::getMirrorLoc(Loc loc, int x_size, int y_size) { - if(loc == Board::NULL_LOC || loc == Board::PASS_LOC) + if(loc == Board::NULL_LOC || loc == Board::PASS_LOC || loc == Board::RESIGN_LOC) return loc; return getLoc(x_size-1-getX(loc,x_size),y_size-1-getY(loc,x_size),x_size); } @@ -314,8 +314,10 @@ int Board::getNumLiberties(Loc loc) const //Check if moving here would be a self-capture bool Board::isSuicide(Loc loc, Player pla) const { - if(loc == PASS_LOC) + if (loc == PASS_LOC) return false; + if (loc == RESIGN_LOC) + return true; // It's seems reasonable to treat resigning as suicide if (rules.isDots) { return isSuicideDots(loc, pla); @@ -511,7 +513,10 @@ bool Board::isLegal(Loc loc, Player pla, bool isMultiStoneSuicideLegal, const bo { if(pla != P_BLACK && pla != P_WHITE) return false; - return loc == PASS_LOC || ( + if (loc == RESIGN_LOC) { + return true; + } + return loc == PASS_LOC || loc == RESIGN_LOC || ( loc >= 0 && loc < MAX_ARR_SIZE && getColor(loc) == C_EMPTY && @@ -864,7 +869,7 @@ Board::MoveRecord Board::playMoveRecorded(const Loc loc, const Player pla) { uint8_t capDirs = 0; - if(loc != PASS_LOC) { + if(loc != PASS_LOC && loc != RESIGN_LOC) { Player opp = getOpp(pla); { int adj = loc + ADJ0; @@ -904,7 +909,7 @@ void Board::undo(MoveRecord& record) ko_loc = record.ko_loc; Loc loc = record.loc; - if(loc == PASS_LOC) + if(loc == PASS_LOC || loc == RESIGN_LOC) return; //Re-fill stones in all captured directions @@ -1020,7 +1025,7 @@ void Board::undo(MoveRecord& record) } Hash128 Board::getPosHashAfterMove(Loc loc, Player pla) const { - if(loc == PASS_LOC) + if(loc == PASS_LOC || loc == RESIGN_LOC) return pos_hash; assert(loc != NULL_LOC); @@ -1106,8 +1111,8 @@ void Board::playMoveAssumeLegal(Loc loc, Player pla) { return; } - //Pass? - if(loc == PASS_LOC) + // Pass or resign? + if(loc == PASS_LOC || loc == RESIGN_LOC) { ko_loc = NULL_LOC; return; @@ -2625,6 +2630,8 @@ Player PlayerIO::parsePlayer(const string& s) { string Location::toStringMach(const Loc loc, const int x_size, const bool isDots) { if(loc == Board::PASS_LOC) return isDots ? "ground" : "pass"; + if (loc == Board::RESIGN_LOC) + return RESIGN_STR; if(loc == Board::NULL_LOC) return string("null"); char buf[128]; @@ -2633,9 +2640,11 @@ string Location::toStringMach(const Loc loc, const int x_size, const bool isDots } string Location::toString(const Loc loc, const int x_size, const int y_size, const bool isDots) { - if(loc == Board::PASS_LOC) + if (loc == Board::PASS_LOC) return isDots ? "ground" : "pass"; - if(loc == Board::NULL_LOC) + if (loc == Board::RESIGN_LOC) + return RESIGN_STR; + if (loc == Board::NULL_LOC) return "null"; const int x = getX(loc, x_size); const int y = getY(loc, x_size); @@ -2663,6 +2672,12 @@ bool Location::tryOfString(const string& str, const int x_size, const int y_size result = Board::PASS_LOC; return true; } + + if (Global::isEqualCaseInsensitive(s, RESIGN_STR)) { + result = Board::RESIGN_LOC; + return true; + } + if(s[0] == '(') { if(s[s.length() - 1] != ')') return false; @@ -2984,7 +2999,7 @@ bool Board::countEmptyHelper(bool* emptyCounted, Loc initialLoc, int& count, int } bool Board::simpleRepetitionBoundGt(Loc loc, int bound) const { - if(loc == NULL_LOC || loc == PASS_LOC) + if(loc == NULL_LOC || loc == PASS_LOC || loc == RESIGN_LOC) return false; if (rules.isDots) { diff --git a/cpp/game/board.h b/cpp/game/board.h index 5270d12c8..311b0735e 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -144,10 +144,12 @@ struct Board static constexpr int MAX_PLAY_SIZE = MAX_LEN_X * MAX_LEN_Y; //Maximum number of playable spaces static constexpr int MAX_ARR_SIZE = getMaxArrSize(MAX_LEN_X, MAX_LEN_Y); //Maximum size of arrays needed - //Location used to indicate an invalid spot on the board. + // Location used to indicate an invalid spot on the board. static constexpr Loc NULL_LOC = 0; - //Location used to indicate a pass or grounding (Dots game) move is desired. + // Location used to indicate a pass or grounding (Dots game) move is desired. static constexpr Loc PASS_LOC = 1; + // Location used to indicate resigning move. + static constexpr Loc RESIGN_LOC = 2; //Zobrist Hashing------------------------------ static bool IS_ZOBRIST_INITALIZED; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index 1f918c8b4..fd4ea6dc8 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -441,10 +441,12 @@ int BoardHistory::computeNumHandicapStones() const { //moves are interleaved with white passes. Ignore it and continue. if(moveLoc == Board::PASS_LOC) continue; + if (moveLoc == Board::RESIGN_LOC) + continue; // Actually shouldn't be here //Otherwise quit out, we have a normal white move. break; } - if(moveLoc != Board::PASS_LOC && moveLoc != Board::NULL_LOC) + if(moveLoc != Board::PASS_LOC && moveLoc != Board::NULL_LOC && moveLoc != Board::RESIGN_LOC) blackNonPassTurnsToStart += 1; } } @@ -486,6 +488,9 @@ void BoardHistory::printBasicInfo(ostream& out, const Board& board) const { out << firstPlayerName << " score: " << board.numWhiteCaptures << endl; out << secondPlayerName << " score: " << board.numBlackCaptures << endl; } + if (isGameFinished) { + out << "Game is finished, winner: " << PlayerIO::playerToString(winner, isDots) << ", score: " << finalWhiteMinusBlackScore << ", resign: " << boolalpha << isResignation << endl; + } } void BoardHistory::printDebugInfo(ostream& out, const Board& board) const { @@ -860,7 +865,7 @@ bool BoardHistory::isLegal(const Board& board, Loc moveLoc, Player movePla) cons //Ko-moves in the encore that are recapture blocked are interpreted as pass-for-ko, so they are legal if(encorePhase > 0) { - if(moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { + if(moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) { if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc] && board.getChainSize(moveLoc) == 1 && board.getNumLiberties(moveLoc) == 1) return true; Loc koCaptureLoc = board.getKoCaptureLoc(moveLoc,movePla); @@ -886,7 +891,7 @@ bool BoardHistory::isPassForKo(const Board& board, Loc moveLoc, Player movePla) assert(rules.isDots == board.isDots()); if (rules.isDots) return false; - if(encorePhase > 0 && moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC) { + if(encorePhase > 0 && moveLoc >= 0 && moveLoc < Board::MAX_ARR_SIZE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) { if(board.colors[moveLoc] == getOpp(movePla) && koRecapBlocked[moveLoc] && board.getChainSize(moveLoc) == 1 && board.getNumLiberties(moveLoc) == 1) return true; @@ -985,7 +990,7 @@ bool BoardHistory::isLegalTolerant(const Board& board, Loc moveLoc, Player moveP // Allow either side to move during tolerant play, but still check that a player is specified if(movePla != P_BLACK && movePla != P_WHITE) return false; - bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules + constexpr bool multiStoneSuicideLegal = true; // Tolerate suicide and ko regardless of rules constexpr bool ignoreKo = true; if(!isPassForKo(board, moveLoc, movePla) && !board.isLegal(moveLoc,movePla,multiStoneSuicideLegal,ignoreKo)) return false; @@ -1041,7 +1046,9 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo bool isSpightlikeEndingPass = false; bool wasPassForKo = false; - if (rules.isDots) { + if (moveLoc == Board::RESIGN_LOC) { + setWinnerByResignation(getOpp(movePla)); + } else if (rules.isDots) { //Dots game board.playMoveAssumeLegal(moveLoc, movePla); if (moveLoc == Board::PASS_LOC) { @@ -1195,7 +1202,7 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo //Territory scoring - chill 1 point per move in main phase and first encore - if(rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase <= 1 && moveLoc != Board::PASS_LOC && !wasPassForKo) { + if(rules.scoringRule == Rules::SCORING_TERRITORY && encorePhase <= 1 && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC && !wasPassForKo) { if(movePla == P_BLACK) whiteBonusScore += 1.0f; else if(movePla == P_WHITE) @@ -1205,7 +1212,7 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo } //Handicap bonus score - if(movePla == P_WHITE && moveLoc != Board::PASS_LOC) + if(movePla == P_WHITE && moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC) whiteHasMoved = true; if(assumeMultipleStartingBlackMovesAreHandicap && !whiteHasMoved && movePla == P_BLACK && rules.whiteHandicapBonusRule != Rules::WHB_ZERO) { whiteHandicapBonusScore = (float)computeWhiteHandicapBonus(); @@ -1255,7 +1262,7 @@ void BoardHistory::makeBoardMoveAssumeLegal(Board& board, Loc moveLoc, Player mo } //Break long cycles with no-result - if(moveLoc != Board::PASS_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { + if(moveLoc != Board::PASS_LOC && moveLoc != Board::RESIGN_LOC && (encorePhase > 0 || rules.koRule == Rules::KO_SIMPLE)) { if(numberOfKoHashOccurrencesInHistory(koHashHistory[koHashHistory.size()-1], rootKoHashTable) >= 3) { isNoResult = true; isGameFinished = true; diff --git a/cpp/game/common.h b/cpp/game/common.h index 562b12a52..833e3c0cd 100644 --- a/cpp/game/common.h +++ b/cpp/game/common.h @@ -21,6 +21,8 @@ const std::string PLAYER2 = "Player2"; const std::string PLAYER1_SHORT = "P1"; const std::string PLAYER2_SHORT = "P2"; +const std::string RESIGN_STR = "resign"; + // Player typedef int8_t Player; static constexpr Player P_BLACK = 1; diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 0916242f2..7062de431 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -276,7 +276,8 @@ Board::MoveRecord Board::playMoveRecordedDots(const Loc loc, const Player pla) { void Board::playMoveAssumeLegalDots(const Loc loc, const Player pla) { const State originalState = getState(loc); - if (loc == PASS_LOC) { + if (loc == RESIGN_LOC) { + } else if (loc == PASS_LOC) { auto initEmptyBaseInvalidateLocations = vector(); auto bases = vector(); ground(pla, initEmptyBaseInvalidateLocations, bases); @@ -335,7 +336,8 @@ Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool vector initEmptyBaseInvalidateLocations; vector newGroundingLocations; - if (loc == PASS_LOC) { + if (loc == RESIGN_LOC) { + } else if (loc == PASS_LOC) { ground(pla, initEmptyBaseInvalidateLocations, bases); } else { colors[loc] = static_cast(pla | pla << PLACED_PLAYER_SHIFT); @@ -394,6 +396,10 @@ Board::MoveRecord Board::tryPlayMoveRecordedDots(Loc loc, Player pla, const bool } void Board::undoDots(MoveRecord& moveRecord) { + if (moveRecord.loc == RESIGN_LOC) { + return; // Resin doesn't really change the state + } + const bool isGroundingMove = moveRecord.loc == PASS_LOC; for (const Loc& loc : moveRecord.groundingLocations) { diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 67f275632..089d6fbeb 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -8,7 +8,7 @@ int NNPos::xyToPos(int x, int y, int nnXLen) { int NNPos::locToPos(Loc loc, int boardXSize, int nnXLen, int nnYLen) { if(loc == Board::PASS_LOC) return nnXLen * nnYLen; - else if(loc == Board::NULL_LOC) + if(loc == Board::NULL_LOC || loc == Board::RESIGN_LOC) // NN normally shouldn't return a resigning move return nnXLen * (nnYLen + 1); return Location::getY(loc,boardXSize) * nnXLen + Location::getX(loc,boardXSize); } @@ -692,13 +692,13 @@ Loc SymmetryHelpers::getSymLoc(int x, int y, const Board& board, int symmetry) { } Loc SymmetryHelpers::getSymLoc(Loc loc, const Board& board, int symmetry) { - if(loc == Board::NULL_LOC || loc == Board::PASS_LOC) + if(loc == Board::NULL_LOC || loc == Board::PASS_LOC || loc == Board::RESIGN_LOC) return loc; return getSymLoc(Location::getX(loc,board.x_size), Location::getY(loc,board.x_size), board, symmetry); } Loc SymmetryHelpers::getSymLoc(Loc loc, int xSize, int ySize, int symmetry) { - if(loc == Board::NULL_LOC || loc == Board::PASS_LOC) + if(loc == Board::NULL_LOC || loc == Board::PASS_LOC || loc == Board::RESIGN_LOC) return loc; return getSymLoc(Location::getX(loc,xSize), Location::getY(loc,xSize), xSize, ySize, symmetry); } diff --git a/cpp/program/playutils.cpp b/cpp/program/playutils.cpp index f3267c4be..3944d91f7 100644 --- a/cpp/program/playutils.cpp +++ b/cpp/program/playutils.cpp @@ -111,7 +111,7 @@ Loc PlayUtils::chooseRandomLegalMove(const Board& board, const BoardHistory& his Loc locs[Board::MAX_ARR_SIZE]; testAssert(pla == hist.presumedNextMovePla); for(Loc loc = 0; loc < Board::MAX_ARR_SIZE; loc++) { - if(hist.isLegal(board,loc,pla) && loc != banMove) { + if(loc != Board::RESIGN_LOC && hist.isLegal(board,loc,pla) && loc != banMove) { locs[numLegalMoves] = loc; numLegalMoves += 1; } @@ -128,7 +128,7 @@ int PlayUtils::chooseRandomLegalMoves(const Board& board, const BoardHistory& hi Loc locs[Board::MAX_ARR_SIZE]; testAssert(pla == hist.presumedNextMovePla); for(Loc loc = 0; loc < Board::MAX_ARR_SIZE; loc++) { - if(hist.isLegal(board,loc,pla)) { + if(loc != Board::RESIGN_LOC && hist.isLegal(board,loc,pla)) { locs[numLegalMoves] = loc; numLegalMoves += 1; } diff --git a/cpp/search/localpattern.cpp b/cpp/search/localpattern.cpp index d7bf0672d..9081fdf37 100644 --- a/cpp/search/localpattern.cpp +++ b/cpp/search/localpattern.cpp @@ -51,7 +51,7 @@ LocalPatternHasher::~LocalPatternHasher() { Hash128 LocalPatternHasher::getHash(const Board& board, Loc loc, Player pla) const { Hash128 hash = zobristPla[pla]; - if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { + if(loc != Board::PASS_LOC && loc != Board::NULL_LOC && loc != Board::RESIGN_LOC) { vector bases; if (board.isDots()) { vector captures; @@ -91,7 +91,7 @@ Hash128 LocalPatternHasher::getHashWithSym(const Board& board, Loc loc, Player p Player symPla = flipColors ? getOpp(pla) : pla; Hash128 hash = zobristPla[symPla]; - if(loc != Board::PASS_LOC && loc != Board::NULL_LOC) { + if(loc != Board::PASS_LOC && loc != Board::NULL_LOC && loc != Board::RESIGN_LOC) { vector bases; if (board.isDots()) { vector captures; diff --git a/cpp/search/patternbonustable.cpp b/cpp/search/patternbonustable.cpp index 4ce7ae5b3..6fa18d668 100644 --- a/cpp/search/patternbonustable.cpp +++ b/cpp/search/patternbonustable.cpp @@ -52,7 +52,7 @@ Hash128 PatternBonusTable::getHash(Player pla, Loc moveLoc, const Board& board) //We don't want to over-trigger this on a ko that repeats the same pattern over and over //So we just disallow this on ko fight //Also no bonuses for passing. - if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || board.wouldBeKoCapture(moveLoc,pla)) + if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || moveLoc == Board::RESIGN_LOC || board.wouldBeKoCapture(moveLoc,pla)) return Hash128(); Hash128 hash = patternHasher.getHash(board,moveLoc,pla); @@ -87,7 +87,7 @@ void PatternBonusTable::addBonus(Player pla, Loc moveLoc, const Board& board, do //We don't want to over-trigger this on a ko that repeats the same pattern over and over //So we just disallow this on ko fight //Also no bonuses for passing. - if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || board.wouldBeKoCapture(moveLoc,pla)) + if(moveLoc == Board::NULL_LOC || moveLoc == Board::PASS_LOC || moveLoc == Board::RESIGN_LOC || board.wouldBeKoCapture(moveLoc,pla)) return; Hash128 hash = patternHasher.getHashWithSym(board,moveLoc,pla,symmetry,flipColors); diff --git a/cpp/tests/testboardbasic.cpp b/cpp/tests/testboardbasic.cpp index 68b88a516..06535e8ee 100644 --- a/cpp/tests/testboardbasic.cpp +++ b/cpp/tests/testboardbasic.cpp @@ -1912,6 +1912,7 @@ void Tests::runBoardUndoTest() { int suicideCount = 0; int koCaptureCount = 0; int passCount = 0; + int resignCount = 0; int regularMoveCount = 0; auto run = [&](const Board& startBoard, bool multiStoneSuicideLegal) { static const int steps = 1000; @@ -1936,6 +1937,8 @@ void Tests::runBoardUndoTest() { if(loc == Board::PASS_LOC) passCount++; + else if(loc == Board::RESIGN_LOC) + resignCount++; else if(boards[n-1].isSuicide(loc,pla)) suicideCount++; else { @@ -1961,15 +1964,17 @@ void Tests::runBoardUndoTest() { out << endl; out << "regularMoveCount " << regularMoveCount << endl; out << "passCount " << passCount << endl; + out << "resignCount " << resignCount << endl; out << "koCaptureCount " << koCaptureCount << endl; out << "suicideCount " << suicideCount << endl; string expected = R"%%( -regularMoveCount 2446 -passCount 475 -koCaptureCount 24 -suicideCount 79 +regularMoveCount 2116 +passCount 376 +resignCount 443 +koCaptureCount 19 +suicideCount 65 )%%"; expect("Board undo test move counts",out,expected); @@ -2106,10 +2111,10 @@ void Tests::runBoardStressTest() { static const int numBoards = 4; vector boards; Rules rules = Rules::DEFAULT_GO; - boards.push_back(Board(Board::DEFAULT_LEN_X, Board::DEFAULT_LEN_Y, rules)); - boards.push_back(Board(9,16,rules)); - boards.push_back(Board(13,7,rules)); - boards.push_back(Board(4,4,rules)); + boards.emplace_back(Board::DEFAULT_LEN_X, Board::DEFAULT_LEN_Y, rules); + boards.emplace_back(9,16,rules); + boards.emplace_back(13,7,rules); + boards.emplace_back(4,4,rules); bool multiStoneSuicideLegal[4] = {false,false,true,false}; vector copies; Player pla = C_BLACK; @@ -2117,6 +2122,7 @@ void Tests::runBoardStressTest() { int koBanCount = 0; int koCaptureCount = 0; int passCount = 0; + int resignCount = 0; int regularMoveCount = 0; for(int n = 0; n < 20000; n++) { Loc locs[numBoards]; @@ -2179,10 +2185,13 @@ void Tests::runBoardStressTest() { } } else { - if(loc == Board::PASS_LOC) { + if (loc == Board::PASS_LOC || loc == Board::RESIGN_LOC) { testAssert(boardsSeemEqual(copy,board)); testAssert(board.ko_loc == Board::NULL_LOC); - passCount++; + if (loc == Board::PASS_LOC) + passCount++; + else + resignCount++; } else if(copy.isSuicide(loc,pla)) { testAssert(board.colors[loc] == C_EMPTY); @@ -2220,6 +2229,7 @@ void Tests::runBoardStressTest() { out << endl; out << "regularMoveCount " << regularMoveCount << endl; out << "passCount " << passCount << endl; + out << "resignCount " << resignCount << endl; out << "koCaptureCount " << koCaptureCount << endl; out << "koBanCount " << koBanCount << endl; out << "suicideCount " << suicideCount << endl; @@ -2230,8 +2240,9 @@ void Tests::runBoardStressTest() { regularMoveCount 38017 passCount 273 +resignCount 289 koCaptureCount 212 -koBanCount 45 +koBanCount 44 suicideCount 440 Caps 4753 5024 Caps 4821 4733 @@ -2250,7 +2261,7 @@ Caps 4420 4335 for(int y = 0; y Date: Fri, 7 Nov 2025 21:10:07 +0100 Subject: [PATCH 36/42] [GTP] Support for number of moves in `undo` Dottify `final_score` --- cpp/command/gtp.cpp | 74 ++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 5a2cf0fc2..99e56632c 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -656,30 +656,46 @@ struct GTPEngine { } bool play(Loc loc, Player pla) { - assert(bot->getRootHist().rules == currentRules); - bool suc = bot->makeMove(loc,pla,preventEncore); - if(suc) - moveHistory.push_back(Move(loc,pla)); - return suc; + auto& hist = bot->getRootHist(); + assert(hist.rules == currentRules); + if (hist.isGameFinished) { + return false; + } + + if(bot->makeMove(loc,pla,preventEncore)) { + moveHistory.emplace_back(loc,pla); + return true; + } + return false; } - bool undo() { - if(moveHistory.size() <= 0) + bool undo(const int movesCount) { + auto& hist = bot->getRootHist(); + assert(hist.rules == currentRules); + + if (movesCount == 0) { + return true; + } + + const int newMovesCount = static_cast(moveHistory.size() - movesCount); + if (newMovesCount < 0) { return false; - assert(bot->getRootHist().rules == currentRules); + } - vector moveHistoryCopy = moveHistory; + if(moveHistory.empty()) + return false; - Board undoneBoard = initialBoard; + const vector moveHistoryCopy = moveHistory; + + const Board undoneBoard = initialBoard; BoardHistory undoneHist(undoneBoard,initialPla,currentRules,0); - undoneHist.setInitialTurnNumber(bot->getRootHist().initialTurnNumber); - vector emptyMoveHistory; + undoneHist.setInitialTurnNumber(hist.initialTurnNumber); + const vector emptyMoveHistory; setPositionAndRules(initialPla,undoneBoard,undoneHist,initialBoard,initialPla,emptyMoveHistory); - for(int i = 0; i& args) { if (Move move = movesToPlay[moveInd]; !engine->play(move.loc, move.pla)) { responseIsError = true; response = "Illegal move " + PlayerIO::playerToString(move.pla, rootBoard.isDots()) + " " + Location::toString(move.loc, rootBoard); - for (int rollbackMoveInd = 0; rollbackMoveInd < moveInd; rollbackMoveInd++) { - assert(engine->undo()); // Rollback already placed moves - } + assert(engine->undo(moveInd)); // Rollback already placed moves break; } } @@ -3039,7 +3053,7 @@ int MainCmds::gtp(const vector& args) { else if(command == "undo") { int undoCount = 1; - if (pieces.size() > 0) { + if (!pieces.empty()) { if (!Global::tryStringToInt(pieces[0], undoCount) || undoCount < 0) { responseIsError = true; response = "Expected nonnegative integer for undo count"; @@ -3047,12 +3061,10 @@ int MainCmds::gtp(const vector& args) { } if (!responseIsError) { - for (int i = 0; i < undoCount; i++) { - if(!engine->undo()) { - responseIsError = true; - response = "cannot undo"; - break; - } + if (!engine->undo(undoCount)) { + responseIsError = true; + response = "cannot undo"; + break; } } } @@ -3280,12 +3292,12 @@ int MainCmds::gtp(const vector& args) { double finalWhiteMinusBlackScore = 0.0; engine->computeAnticipatedWinnerAndScore(winner,finalWhiteMinusBlackScore); - if(winner == C_EMPTY) + if (winner == C_EMPTY) response = "0"; - else if(winner == C_BLACK) - response = "B+" + Global::strprintf("%.1f",-finalWhiteMinusBlackScore); - else if(winner == C_WHITE) - response = "W+" + Global::strprintf("%.1f",finalWhiteMinusBlackScore); + else if (winner == C_BLACK || winner == C_WHITE) { + double finalScore = winner == C_BLACK ? -finalWhiteMinusBlackScore : finalWhiteMinusBlackScore; + response = PlayerIO::playerToStringShort(winner, engine->currentRules.isDots) + "+" + Global::strprintf("%.1f", finalScore); + } else ASSERT_UNREACHABLE; } From 29f9edd38ecbd09e2bdc1a7e3aec0861b3032d66 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Fri, 14 Nov 2025 00:56:59 +0100 Subject: [PATCH 37/42] Revert "[Python] Refine Dots NN inputs and scoring" It's needed for the next commit. This reverts commit 29e8e69b703eebc5e6359b4853b44cd500d77897. --- .../katago/train/data_processing_pytorch.py | 153 ++++++------------ python/katago/train/model_pytorch.py | 14 +- python/katago/train/modelconfigs.py | 51 +----- 3 files changed, 52 insertions(+), 166 deletions(-) diff --git a/python/katago/train/data_processing_pytorch.py b/python/katago/train/data_processing_pytorch.py index 754a0ef4c..78a50daeb 100644 --- a/python/katago/train/data_processing_pytorch.py +++ b/python/katago/train/data_processing_pytorch.py @@ -1,6 +1,5 @@ import logging import os -from enum import Enum, auto import numpy as np from concurrent.futures import ThreadPoolExecutor @@ -157,133 +156,71 @@ def apply_symmetry(tensor, symm): if symm == 7: return tensor.flip(-2) -class GoSpatialFeature(Enum): - ON_BOARD = 0 - PLA_STONE = 1 - OPP_STONE = 2 - LIBERTIES_1 = 3 - LIBERTIES_2 = 4 - LIBERTIES_3 = 5 - SUPER_KO_BANNED = 6 - KO_RECAP_BLOCKED = 7 - KO_EXTRA = 8 - PREV_1_LOC = 9 - PREV_2_LOC = 10 - PREV_3_LOC = 11 - PREV_4_LOC = 12 - PREV_5_LOC = 13 - LADDER_CAPTURED = 14 - LADDER_CAPTURED_PREVIOUS_1 = 15 - LADDER_CAPTURED_PREVIOUS_2 = 16 - LADDER_WORKING_MOVES = 17 - AREA_PLA = 18 - AREA_OPP = 19 - SECOND_ENCORE_PLA = 20 - SECOND_ENCORE_OPP = 21 - -class GoGlobalFeature(Enum): - PREV_1_LOC_PASS = 0 - PREV_2_LOC_PASS = 1 - PREV_3_LOC_PASS = 2 - PREV_4_LOC_PASS = 3 - PREV_5_LOC_PASS = 4 - KOMI = 5 - KO_RULE_NOT_SIMPLE = 6 - KO_RULE_EXTRA = 7 - SUICIDE = 8 - SCORING_TERRITORY = 9 - TAX_SEKI = 10 - TAX_ALL = 11 - ENCORE_PHASE_1 = 12 - ENCORE_PHASE_2 = 13 - PASS_WOULD_END_PHASE = 14 - PLAYOUT_DOUBLING_ADVANTAGE_FLAG = 15 - PLAYOUT_DOUBLING_ADVANTAGE_VALUE = 16 - HAS_BUTTON = 17 - BOARD_SIZE_KOMI_PARITY = 18 - -class DotsSpatialFeature(Enum): - ON_BOARD = 0 - PLA_ACTIVE = auto() - OPP_ACTIVE = auto() - PLA_PLACED = auto() - OPP_PLACED = auto() - DEAD = auto() - GROUNDED = auto() - PLA_CAPTURES = auto() - OPP_CAPTURES = auto() - PLA_SURROUNDINGS = auto() - OPP_SURROUNDINGS = auto() - PREV_1_LOC = auto() - PREV_2_LOC = auto() - PREV_3_LOC = auto() - PREV_4_LOC = auto() - PREV_5_LOC = auto() - LADDER_CAPTURED = auto() - LADDER_CAPTURED_PREVIOUS_1 = auto() - LADDER_CAPTURED_PREVIOUS_2 = auto() - LADDER_WORKING_MOVES = auto() def build_history_matrices(model_config: modelconfigs.ModelConfig, device): num_bin_features = modelconfigs.get_num_bin_input_features(model_config) - - is_go_game = not modelconfigs.is_dots_game(model_config) - - prev_1_loc = GoSpatialFeature.PREV_1_LOC.value if is_go_game else DotsSpatialFeature.PREV_1_LOC.value - prev_2_loc = GoSpatialFeature.PREV_2_LOC.value if is_go_game else DotsSpatialFeature.PREV_2_LOC.value - prev_3_loc = GoSpatialFeature.PREV_3_LOC.value if is_go_game else DotsSpatialFeature.PREV_3_LOC.value - prev_4_loc = GoSpatialFeature.PREV_4_LOC.value if is_go_game else DotsSpatialFeature.PREV_4_LOC.value - prev_5_loc = GoSpatialFeature.PREV_5_LOC.value if is_go_game else DotsSpatialFeature.PREV_5_LOC.value - - ladder_captured = GoSpatialFeature.LADDER_CAPTURED.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED.value - ladder_captured_previous_1 = GoSpatialFeature.LADDER_CAPTURED_PREVIOUS_1.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED_PREVIOUS_1.value - ladder_captured_previous_2 = GoSpatialFeature.LADDER_CAPTURED_PREVIOUS_2.value if is_go_game else DotsSpatialFeature.LADDER_CAPTURED_PREVIOUS_2.value - ladder_working_moves = GoSpatialFeature.LADDER_WORKING_MOVES.value if is_go_game else DotsSpatialFeature.LADDER_WORKING_MOVES.value - - data = [1.0 for _ in range(num_bin_features)] - - data[prev_1_loc] = 0.0 - data[prev_2_loc] = 0.0 - data[prev_3_loc] = 0.0 - data[prev_4_loc] = 0.0 - data[prev_5_loc] = 0.0 - - data[ladder_captured] = 1.0 - data[ladder_captured_previous_1] = 0.0 - data[ladder_captured_previous_2] = 0.0 - - h_base = torch.diag(torch.tensor(data, device=device, requires_grad=False)) - + assert num_bin_features == 22, "Currently this code is hardcoded for this many features" + + h_base = torch.diag( + torch.tensor( + [ + 1.0, # 0 + 1.0, # 1 + 1.0, # 2 + 1.0, # 3 + 1.0, # 4 + 1.0, # 5 + 1.0, # 6 + 1.0, # 7 + 1.0, # 8 + 0.0, # 9 Location of move 1 turn ago + 0.0, # 10 Location of move 2 turns ago + 0.0, # 11 Location of move 3 turns ago + 0.0, # 12 Location of move 4 turns ago + 0.0, # 13 Location of move 5 turns ago + 1.0, # 14 Ladder-threatened stone + 0.0, # 15 Ladder-threatened stone, 1 turn ago + 0.0, # 16 Ladder-threatened stone, 2 turns ago + 1.0, # 17 + 1.0, # 18 + 1.0, # 19 + 1.0, # 20 + 1.0, # 21 + ], + device=device, + requires_grad=False, + ) + ) # Because we have ladder features that express past states rather than past diffs, # the most natural encoding when we have no history is that they were always the # same, rather than that they were all zero. So rather than zeroing them we have no # history, we add entries in the matrix to copy them over. # By default, without history, the ladder features 15 and 16 just copy over from 14. - h_base[ladder_captured, ladder_captured_previous_1] = 1.0 - h_base[ladder_captured, ladder_captured_previous_2] = 1.0 + h_base[14, 15] = 1.0 + h_base[14, 16] = 1.0 h0 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) # When have the prev move, we enable feature 9 and 15 - h0[prev_1_loc, prev_1_loc] = 1.0 # Enable 9 -> 9 - h0[ladder_captured, ladder_captured_previous_1] = -1.0 # Stop copying 14 -> 15 - h0[ladder_captured, ladder_captured_previous_2] = -1.0 # Stop copying 14 -> 16 - h0[ladder_captured_previous_1, ladder_captured_previous_1] = 1.0 # Enable 15 -> 15 - h0[ladder_captured_previous_1, ladder_captured_previous_2] = 1.0 # Start copying 15 -> 16 + h0[9, 9] = 1.0 # Enable 9 -> 9 + h0[14, 15] = -1.0 # Stop copying 14 -> 15 + h0[14, 16] = -1.0 # Stop copying 14 -> 16 + h0[15, 15] = 1.0 # Enable 15 -> 15 + h0[15, 16] = 1.0 # Start copying 15 -> 16 h1 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) # When have the prevprev move, we enable feature 10 and 16 - h1[prev_2_loc, prev_2_loc] = 1.0 # Enable 10 -> 10 - h1[ladder_captured_previous_1, ladder_captured_previous_2] = -1.0 # Stop copying 15 -> 16 - h1[ladder_captured_previous_2, ladder_captured_previous_2] = 1.0 # Enable 16 -> 16 + h1[10, 10] = 1.0 # Enable 10 -> 10 + h1[15, 16] = -1.0 # Stop copying 15 -> 16 + h1[16, 16] = 1.0 # Enable 16 -> 16 h2 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h2[prev_3_loc, prev_3_loc] = 1.0 + h2[11, 11] = 1.0 h3 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h3[prev_4_loc, prev_4_loc] = 1.0 + h3[12, 12] = 1.0 h4 = torch.zeros(num_bin_features, num_bin_features, device=device, requires_grad=False) - h4[prev_5_loc, prev_5_loc] = 1.0 + h4[13, 13] = 1.0 # (1, n_bin, n_bin) h_base = h_base.reshape((1, num_bin_features, num_bin_features)) diff --git a/python/katago/train/model_pytorch.py b/python/katago/train/model_pytorch.py index cf1209614..4654ac934 100644 --- a/python/katago/train/model_pytorch.py +++ b/python/katago/train/model_pytorch.py @@ -8,7 +8,6 @@ import packaging.version from typing import List, Dict, Optional, Set -from .modelconfigs import get_num_bin_input_features, get_num_global_input_features from ..train import modelconfigs EXTRA_SCORE_DISTR_RADIUS = 60 @@ -1653,22 +1652,19 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: self.activation = "relu" if "activation" not in config else config["activation"] - spatial_features = get_num_bin_input_features(config) - global_features = get_num_global_input_features(config) - if config["initial_conv_1x1"]: - self.conv_spatial = torch.nn.Conv2d(spatial_features, self.c_trunk, kernel_size=1, padding="same", bias=False) + self.conv_spatial = torch.nn.Conv2d(22, self.c_trunk, kernel_size=1, padding="same", bias=False) else: - self.conv_spatial = torch.nn.Conv2d(spatial_features, self.c_trunk, kernel_size=3, padding="same", bias=False) - self.linear_global = torch.nn.Linear(global_features, self.c_trunk, bias=False) + self.conv_spatial = torch.nn.Conv2d(22, self.c_trunk, kernel_size=3, padding="same", bias=False) + self.linear_global = torch.nn.Linear(19, self.c_trunk, bias=False) if "metadata_encoder" in config and config["metadata_encoder"] is not None: self.metadata_encoder = MetadataEncoder(config) else: self.metadata_encoder = None - self.bin_input_shape = [spatial_features, pos_len_x, pos_len_y] - self.global_input_shape = [global_features] + self.bin_input_shape = [22, pos_len_x, pos_len_y] + self.global_input_shape = [19] self.blocks = torch.nn.ModuleList() for block_config in self.block_kind: diff --git a/python/katago/train/modelconfigs.py b/python/katago/train/modelconfigs.py index f5b105eaa..976be7178 100644 --- a/python/katago/train/modelconfigs.py +++ b/python/katago/train/modelconfigs.py @@ -41,35 +41,20 @@ # version = 15 # V7 features, Extra nonlinearity for pass output # version = 16 # V7 features, Q value predictions in the policy head -# version = 17 # V8 features, Extra nonlinearity for pass output, Dots game - def get_version(config: ModelConfig): return config["version"] -def is_dots_game(config: ModelConfig): - version = get_version(config) - if 10 <= version <= 16: - return False - elif version == 17: - return True - else: - assert(False) - def get_num_bin_input_features(config: ModelConfig): version = get_version(config) - if 10 <= version <= 16: + if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: return 22 - elif version == 17: # Dots game - return 20 else: assert(False) def get_num_global_input_features(config: ModelConfig): version = get_version(config) - if 10 <= version <= 16: + if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: return 19 - elif version == 17: # Dots game - return 9 else: assert(False) @@ -214,37 +199,6 @@ def get_num_meta_encoder_input_features(config_or_meta_encoder_version: Union[Mo "v2_size":80, } -b10c128_dots = { - "version":17, - "norm_kind":"fixup", - "bnorm_epsilon": 1e-4, - "bnorm_running_avg_momentum": 0.001, - "initial_conv_1x1": False, - "trunk_num_channels":128, - "mid_num_channels":128, - "gpool_num_channels":32, - "use_attention_pool":False, - "num_attention_pool_heads":4, - "block_kind": [ - ["rconv1","regular"], - ["rconv2","regular"], - ["rconv3","regular"], - ["rconv4","regular"], - ["rconv5","regulargpool"], - ["rconv6","regular"], - ["rconv7","regular"], - ["rconv8","regulargpool"], - ["rconv9","regular"], - ["rconv10","regular"], - ], - "p1_num_channels":32, - "g1_num_channels":32, - "v1_num_channels":32, - "sbv2_num_channels":64, - "num_scorebeliefs":6, - "v2_size":80, -} - b5c192nbt = { "version":15, "norm_kind":"fixup", @@ -1498,7 +1452,6 @@ def get_num_meta_encoder_input_features(config_or_meta_encoder_version: Union[Mo # Small model configs, not too different in inference cost from b10c128 "b10c128": b10c128, - "b10c128_dots": b10c128_dots, "b5c192nbt": b5c192nbt, # Medium model configs, not too different in inference cost from b15c192 From f93c19be94e3fb4838fec5c2e2d2f58c5187ccf6 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sat, 15 Nov 2025 23:09:20 +0100 Subject: [PATCH 38/42] Make number of spatial and global features identical to KataGo last model version 7 (22 and 19) Embed info about Dots Game into model fix #8 --- cpp/command/evalsgf.cpp | 11 +- cpp/command/selfplay.cpp | 6 +- cpp/command/writetrainingdata.cpp | 2 +- cpp/dataio/trainingwrite.cpp | 19 ++-- cpp/dataio/trainingwrite.h | 4 +- cpp/neuralnet/desc.cpp | 35 +++++-- cpp/neuralnet/desc.h | 7 +- cpp/neuralnet/modelversion.cpp | 147 ++++++++++++++++++--------- cpp/neuralnet/modelversion.h | 13 +-- cpp/neuralnet/nneval.cpp | 14 +-- cpp/neuralnet/nneval.h | 1 + cpp/neuralnet/nninputs.cpp | 52 +++++----- cpp/neuralnet/nninputs.h | 103 +++++++++++-------- cpp/neuralnet/nninputsdots.cpp | 54 ++++------ cpp/neuralnet/nninterface.h | 3 +- cpp/neuralnet/openclbackend.cpp | 19 ++-- cpp/neuralnet/opencltuner.cpp | 13 +-- cpp/neuralnet/trtbackend.cpp | 17 ++-- cpp/tests/testnninputs.cpp | 10 +- python/export_model_pytorch.py | 6 +- python/katago/train/load_model.py | 7 +- python/katago/train/model_pytorch.py | 18 +++- python/katago/train/modelconfigs.py | 4 +- python/selfplay/train.sh | 1 + python/test.py | 8 +- python/train.py | 11 +- 26 files changed, 351 insertions(+), 234 deletions(-) diff --git a/cpp/command/evalsgf.cpp b/cpp/command/evalsgf.cpp index 9de38853c..3aa11a65d 100644 --- a/cpp/command/evalsgf.cpp +++ b/cpp/command/evalsgf.cpp @@ -665,8 +665,15 @@ int MainCmds::evalsgf(const vector& args) { int nnXLen = nnEval->getNNXLen(); int nnYLen = nnEval->getNNYLen(); int modelVersion = nnEval->getModelVersion(); - int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + + bool nnEvalIsInDotsMode = nnEval->getDotsGame(); + assert(nnEvalIsInDotsMode == sgf->isDots); + if (nnEvalIsInDotsMode != sgf->isDots) { + cout << "SGF and model mismatch (Go and Dots games)"; + } + + int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, nnEvalIsInDotsMode); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, nnEvalIsInDotsMode); NumpyBuffer binaryInputNCHW(std::vector({1,numSpatialFeatures,nnXLen,nnYLen})); NumpyBuffer globalInputNC(std::vector({1,numGlobalFeatures})); diff --git a/cpp/command/selfplay.cpp b/cpp/command/selfplay.cpp index 03551077d..0d771c4d8 100644 --- a/cpp/command/selfplay.cpp +++ b/cpp/command/selfplay.cpp @@ -100,7 +100,7 @@ int MainCmds::selfplay(const vector& args) { const int inputsVersion = cfg.contains("inputsVersion") ? cfg.getInt("inputsVersion",0,10000) : - NNModelVersion::getInputsVersion(dotsGame ? NNModelVersion::defaultModelVersionForDots : NNModelVersion::defaultModelVersion); + NNModelVersion::getInputsVersion(NNModelVersion::defaultModelVersion, dotsGame); //Max number of games that we will allow to be queued up and not written out const int maxDataQueueSize = cfg.getInt("maxDataQueueSize",1,1000000); const int maxRowsPerTrainFile = cfg.getInt("maxRowsPerTrainFile",1,100000000); @@ -141,7 +141,7 @@ int MainCmds::selfplay(const vector& args) { auto loadLatestNeuralNetIntoManager = [inputsVersion,&manager,maxRowsPerTrainFile,firstFileRandMinProp,dataBoardLenX,dataBoardLenY, &modelsDir,&outputDir,&logger,&cfg,numGameThreads, - minBoardXSizeUsed,maxBoardXSizeUsed,minBoardYSizeUsed,maxBoardYSizeUsed](const string* lastNetName) -> bool { + minBoardXSizeUsed,maxBoardXSizeUsed,minBoardYSizeUsed,maxBoardYSizeUsed,dotsGame](const string* lastNetName) -> bool { string modelName; string modelFile; @@ -217,7 +217,7 @@ int MainCmds::selfplay(const vector& args) { //Note that this inputsVersion passed here is NOT necessarily the same as the one used in the neural net self play, it //simply controls the input feature version for the written data auto tdataWriter = new TrainingDataWriter( - tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLenX, dataBoardLenY, Global::uint64ToHexString(rand.nextUInt64())); + tdataOutputDir, nullptr, inputsVersion, maxRowsPerTrainFile, firstFileRandMinProp, dataBoardLenX, dataBoardLenY, Global::uint64ToHexString(rand.nextUInt64()), 1, dotsGame); ofstream* sgfOut = nullptr; if(sgfOutputDir.length() > 0) { sgfOut = new ofstream(); diff --git a/cpp/command/writetrainingdata.cpp b/cpp/command/writetrainingdata.cpp index a0527b9ca..ce988fa38 100644 --- a/cpp/command/writetrainingdata.cpp +++ b/cpp/command/writetrainingdata.cpp @@ -691,7 +691,7 @@ int MainCmds::writetrainingdata(const vector& args) { if(dataBoardLen > Board::MAX_LEN) throw StringError("dataBoardLen > maximum board len, must recompile to increase"); - static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); + static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); const int inputsVersion = 7; const int numBinaryChannels = NNInputs::NUM_FEATURES_SPATIAL_V7; const int numGlobalChannels = NNInputs::NUM_FEATURES_GLOBAL_V7; diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index 823e4337b..5f26c9411 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -483,8 +483,8 @@ void TrainingWriteBuffers::addRow( SGFMetadata* sgfMeta, Rand& rand ) { - static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); - if(inputsVersion < 3 || inputsVersion > 8) + static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); + if(inputsVersion < 3 || inputsVersion > 7) throw StringError("Training write buffers: Does not support input version: " + Global::intToString(inputsVersion)); int posArea = dataXLen*dataYLen; @@ -505,8 +505,8 @@ void TrainingWriteBuffers::addRow( float* rowBin = binaryInputNCHWUnpacked; float* rowGlobal = globalInputNC.data + curRows * numGlobalChannels; - assert(NNInputs::getNumberOfSpatialFeatures(inputsVersion) == numBinaryChannels); - assert(NNInputs::getNumberOfGlobalFeatures(inputsVersion) == numGlobalChannels); + assert(NNInputs::getNumberOfSpatialFeatures(inputsVersion, hist.rules.isDots) == numBinaryChannels); + assert(NNInputs::getNumberOfGlobalFeatures(inputsVersion, hist.rules.isDots) == numGlobalChannels); NNInputs::fillRowVN(inputsVersion, board, hist, nextPlayer, nnInputParams, dataXLen, dataYLen, inputsUseNHWC, rowBin, rowGlobal); @@ -519,6 +519,8 @@ void TrainingWriteBuffers::addRow( //Vector for global targets and metadata float* rowGlobal = globalTargetsNC.data + curRows * GLOBAL_TARGET_NUM_CHANNELS; + rowGlobal[23] = hist.rules.isDots ? 1.0f : 0.0f; // Previously unused -> engage for Dots game + //Target weight for the whole row rowGlobal[25] = targetWeight; @@ -592,8 +594,6 @@ void TrainingWriteBuffers::addRow( rowGlobal[22] = (float)sum; } - //Unused - rowGlobal[23] = 0.0f; rowGlobal[24] = (float)(1.0f - tdValueTargetWeight); rowGlobal[30] = (float)policySurprise; rowGlobal[31] = (float)policyEntropy; @@ -960,13 +960,14 @@ TrainingDataWriter::TrainingDataWriter(const string& outputDir, ostream* debugOu const int dataXLen, const int dataYLen, const string& randSeed, - const int onlyWriteEvery) + const int onlyWriteEvery, + const bool dotsGame) :outputDir(outputDir),inputsVersion(inputsVersion),rand(randSeed),writeBuffers(nullptr),debugOut(debugOut),debugOnlyWriteEvery(onlyWriteEvery),rowCount(0) { //Note that this inputsVersion is for data writing, it might be different than the inputsVersion used // to feed into a model during selfplay - const int numBinaryChannels = NNInputs::getNumberOfSpatialFeatures(inputsVersion); - const int numGlobalChannels = NNInputs::getNumberOfGlobalFeatures(inputsVersion); + const int numBinaryChannels = NNInputs::getNumberOfSpatialFeatures(inputsVersion, dotsGame); + const int numGlobalChannels = NNInputs::getNumberOfGlobalFeatures(inputsVersion, dotsGame); constexpr bool hasMetadataInput = false; writeBuffers = new TrainingWriteBuffers( diff --git a/cpp/dataio/trainingwrite.h b/cpp/dataio/trainingwrite.h index 2034032fd..ffd0b32b5 100644 --- a/cpp/dataio/trainingwrite.h +++ b/cpp/dataio/trainingwrite.h @@ -173,7 +173,7 @@ struct TrainingWriteBuffers { //C20: Actual final score, from the perspective of the player to move, adjusted for draw utility, zero if C27 is zero. //C21: Lead in points, number of points to make the game fair, zero if C29 is zero. //C22: Expected arrival time of WL variance. - //C23: Unused + //C23: 1.0 if Dots game //C24: 1.0 minus weight assigned to td value targets //C25 Weight multiplier for row as a whole @@ -311,7 +311,7 @@ struct TrainingWriteBuffers { class TrainingDataWriter { public: - TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1); + TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1, bool dotsGame = false); ~TrainingDataWriter(); void writeGame(const FinishedGameData& data); diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 5ae003cec..fa40382e5 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -1557,6 +1557,9 @@ ModelPostProcessParams::~ModelPostProcessParams() ModelDesc::ModelDesc() : modelVersion(-1), + maxLenX(Board::MAX_LEN_X), + maxLenY(Board::MAX_LEN_Y), + isDotsGame(false), numInputChannels(0), numInputGlobalChannels(0), numInputMetaChannels(0), @@ -1564,14 +1567,31 @@ ModelDesc::ModelDesc() numValueChannels(0), numScoreValueChannels(0), numOwnershipChannels(0), - metaEncoderVersion(0), - postProcessParams() -{} + metaEncoderVersion(0) {} -ModelDesc::ModelDesc(istream& in, const string& sha256_, bool binaryFloats) { +ModelDesc::ModelDesc(istream& in, const string& sha256_, const bool binaryFloats) { in >> name; sha256 = sha256_; - in >> modelVersion; + + string str; + in >> str; + + if (const auto pieces = Global::split(str, ','); pieces.size() > 1) { + modelVersion = Global::stringToInt(pieces[0]); + maxLenX = Global::stringToInt(pieces[1]); + maxLenY = Global::stringToInt(pieces[2]); + const auto gamePieces = Global::split(pieces[3], ';'); + if (gamePieces.size() > 1) { + throw StringError("KataGo currently supports only single-game mode. The `" + name + "` is a mixed model for games: " + Global::concat(gamePieces, ", ")); + } + isDotsGame = Global::isEqualCaseInsensitive(gamePieces[0], DOTS_KEY); + } else { + modelVersion = Global::stringToInt(str); + maxLenX = -1; // It's unclear the exact training size of old modules. Typically, it's 19, but not always + maxLenY = -1; + isDotsGame = false; + } + if(in.fail()) throw StringError("Model failed to parse name or version. Is this a valid model file? You probably specified the wrong file."); @@ -1722,10 +1742,13 @@ ModelDesc::ModelDesc(ModelDesc&& other) { *this = std::move(other); } -ModelDesc& ModelDesc::operator=(ModelDesc&& other) { +ModelDesc& ModelDesc::operator=(ModelDesc&& other) noexcept { name = std::move(other.name); sha256 = std::move(other.sha256); modelVersion = other.modelVersion; + maxLenX = other.maxLenX; + maxLenY = other.maxLenY; + isDotsGame = other.isDotsGame; numInputChannels = other.numInputChannels; numInputGlobalChannels = other.numInputGlobalChannels; numInputMetaChannels = other.numInputMetaChannels; diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index b44d00c59..993b39557 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -347,6 +347,9 @@ struct ModelDesc { std::string name; std::string sha256; int modelVersion; + int maxLenX; // Currently unused, but initialize it just in case for the future + int maxLenY; + bool isDotsGame; int numInputChannels; int numInputGlobalChannels; int numInputMetaChannels; @@ -365,13 +368,13 @@ struct ModelDesc { ModelDesc(); ~ModelDesc(); - ModelDesc(std::istream& in, const std::string& sha256, bool binaryFloats); + ModelDesc(std::istream& in, const std::string& sha256_, bool binaryFloats); ModelDesc(ModelDesc&& other); ModelDesc(const ModelDesc&) = delete; ModelDesc& operator=(const ModelDesc&) = delete; - ModelDesc& operator=(ModelDesc&& other); + ModelDesc& operator=(ModelDesc&& other) noexcept ; void iterConvLayers(std::function f) const; int maxConvChannels(int convXSize, int convYSize) const; diff --git a/cpp/neuralnet/modelversion.cpp b/cpp/neuralnet/modelversion.cpp index a829e743a..05e54f4a6 100644 --- a/cpp/neuralnet/modelversion.cpp +++ b/cpp/neuralnet/modelversion.cpp @@ -22,68 +22,115 @@ //15 = V7 features, Extra nonlinearity for pass output //16 = V7 features, Q value predictions in the policy head -//17 = V8 features (Dots game) - -static void fail(int modelVersion) { - throw StringError("NNModelVersion: Model version not currently implemented or supported: " + Global::intToString(modelVersion)); +static void fail(const int modelVersion, const bool dotsGame) { + throw StringError("NNModelVersion: Model version not currently implemented or supported: " + Global::intToString(modelVersion) + + (dotsGame ? " (Dots game)" : "")); } -static_assert(NNModelVersion::oldestModelVersionImplemented == 3, ""); -static_assert(NNModelVersion::oldestInputsVersionImplemented == 3, ""); -static_assert(NNModelVersion::latestModelVersionImplemented == 17, ""); -static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); +static_assert(NNModelVersion::oldestModelVersionImplemented == 3); +static_assert(NNModelVersion::oldestInputsVersionImplemented == 3); +static_assert(NNModelVersion::latestModelVersionImplemented == 16); +static_assert(NNModelVersion::latestInputsVersionImplemented == 7); -int NNModelVersion::getInputsVersion(int modelVersion) { - if (modelVersion == defaultModelVersionForDots) - return dotsInputsVersion; - if(modelVersion >= 8 && modelVersion <= 16) - return 7; - else if(modelVersion == 7) - return 6; - else if(modelVersion == 6) - return 5; - else if(modelVersion == 5) - return 4; - else if(modelVersion == 3 || modelVersion == 4) - return 3; +int NNModelVersion::getInputsVersion(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return 3; + } + break; + case 5: + if (!dotsGame) { + return 4; + } + break; + case 6: + if (!dotsGame) { + return 5; + } + break; + case 7: + if (!dotsGame) { + return 6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return 7; + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } -int NNModelVersion::getNumSpatialFeatures(int modelVersion) { - if(modelVersion == defaultModelVersionForDots) - return NNInputs::NUM_FEATURES_SPATIAL_V_DOTS; - if(modelVersion >= 8 && modelVersion <= 16) - return NNInputs::NUM_FEATURES_SPATIAL_V7; - if(modelVersion == 7) - return NNInputs::NUM_FEATURES_SPATIAL_V6; - if(modelVersion == 6) - return NNInputs::NUM_FEATURES_SPATIAL_V5; - if(modelVersion == 5) - return NNInputs::NUM_FEATURES_SPATIAL_V4; - if(modelVersion == 3 || modelVersion == 4) - return NNInputs::NUM_FEATURES_SPATIAL_V3; +int NNModelVersion::getNumSpatialFeatures(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V3; + } + break; + case 5: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V4; + } + break; + case 6: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V5; + } + break; + case 7: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_SPATIAL_V6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return dotsGame ? NNInputs::NUM_FEATURES_SPATIAL_V7_DOTS : NNInputs::NUM_FEATURES_SPATIAL_V7; + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } -int NNModelVersion::getNumGlobalFeatures(int modelVersion) { - if(modelVersion == defaultModelVersionForDots) - return NNInputs::NUM_FEATURES_GLOBAL_V_DOTS; - if(modelVersion >= 8 && modelVersion <= 16) - return NNInputs::NUM_FEATURES_GLOBAL_V7; - else if(modelVersion == 7) - return NNInputs::NUM_FEATURES_GLOBAL_V6; - else if(modelVersion == 6) - return NNInputs::NUM_FEATURES_GLOBAL_V5; - else if(modelVersion == 5) - return NNInputs::NUM_FEATURES_GLOBAL_V4; - else if(modelVersion == 3 || modelVersion == 4) - return NNInputs::NUM_FEATURES_GLOBAL_V3; +int NNModelVersion::getNumGlobalFeatures(const int modelVersion, const bool dotsGame) { + switch(modelVersion) { + case 3: + case 4: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V3; + } + break; + case 5: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V4; + } + break; + case 6: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V5; + } + break; + case 7: + if (!dotsGame) { + return NNInputs::NUM_FEATURES_GLOBAL_V6; + } + break; + default: + if (modelVersion <= latestModelVersionImplemented) { + return dotsGame ? NNInputs::NUM_FEATURES_GLOBAL_V7_DOTS : NNInputs::NUM_FEATURES_GLOBAL_V7; + } + break; + } - fail(modelVersion); + fail(modelVersion, dotsGame); return -1; } diff --git a/cpp/neuralnet/modelversion.h b/cpp/neuralnet/modelversion.h index 6eeac8c7a..3fe754089 100644 --- a/cpp/neuralnet/modelversion.h +++ b/cpp/neuralnet/modelversion.h @@ -4,23 +4,20 @@ // Model versions namespace NNModelVersion { - constexpr int latestModelVersionImplemented = 17; - constexpr int latestInputsVersionImplemented = 8; - constexpr int latestGoInputsVersion = 7; - constexpr int dotsInputsVersion = latestInputsVersionImplemented; + constexpr int latestModelVersionImplemented = 16; + constexpr int latestInputsVersionImplemented = 7; constexpr int defaultModelVersion = 16; - constexpr int defaultModelVersionForDots = 17; constexpr int oldestModelVersionImplemented = 3; constexpr int oldestInputsVersionImplemented = 3; // Which V* feature version from NNInputs does a given model version consume? - int getInputsVersion(int modelVersion); + int getInputsVersion(int modelVersion, bool dotsGame); // Convenience functions, feeds forward the number of features and the size of // the row vector that the net takes as input - int getNumSpatialFeatures(int modelVersion); - int getNumGlobalFeatures(int modelVersion); + int getNumSpatialFeatures(int modelVersion, bool dotsGame); + int getNumGlobalFeatures(int modelVersion, bool dotsGame); // SGF metadata encoder input versions int getNumInputMetaChannels(int metaEncoderVersion); diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index cd1082f2f..b5bf2ccad 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -139,7 +139,6 @@ NNEvaluator::NNEvaluator( const ModelDesc& desc = NeuralNet::getModelDesc(loadedModel); internalModelName = desc.name; modelVersion = desc.modelVersion; - inputsVersion = NNModelVersion::getInputsVersion(modelVersion); numInputMetaChannels = desc.numInputMetaChannels; postProcessParams = desc.postProcessParams; computeContext = NeuralNet::createComputeContext( @@ -150,9 +149,9 @@ NNEvaluator::NNEvaluator( } else { internalModelName = "random"; - modelVersion = dotsGame ? NNModelVersion::defaultModelVersionForDots : NNModelVersion::defaultModelVersion; - inputsVersion = NNModelVersion::getInputsVersion(modelVersion); + modelVersion = NNModelVersion::defaultModelVersion; } + inputsVersion = NNModelVersion::getInputsVersion(modelVersion, dotsGame); //Reserve a decent amount above the batch size so that allocation is unlikely. queryQueue.reserve(maxBatchSize * 4 * gpuIdxByServerThread.size()); @@ -286,6 +285,9 @@ int NNEvaluator::getNNYLen() const { int NNEvaluator::getModelVersion() const { return modelVersion; } +bool NNEvaluator::getDotsGame() const { + return dotsGame; +} double NNEvaluator::getTrunkSpatialConvDepth() const { return NeuralNet::getModelDesc(loadedModel).getTrunkSpatialConvDepth(); } @@ -568,7 +570,7 @@ void NNEvaluator::serve( } } - NeuralNet::getOutput(gpuHandle, buf.inputBuffers, numRows, resultBufs.data(), outputBuf); + NeuralNet::getOutput(gpuHandle, buf.inputBuffers, numRows, resultBufs.data(), outputBuf, dotsGame); assert(outputBuf.size() == numRows); m_numRowsProcessed.fetch_add(numRows, std::memory_order_relaxed); @@ -778,10 +780,10 @@ void NNEvaluator::evaluate( buf.boardYSizeForServer = board.y_size; if(!debugSkipNeuralNet) { - const int rowSpatialLen = NNModelVersion::getNumSpatialFeatures(modelVersion) * nnXLen * nnYLen; + const int rowSpatialLen = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame) * nnXLen * nnYLen; if(buf.rowSpatialBuf.size() < rowSpatialLen) buf.rowSpatialBuf.resize(rowSpatialLen); - const int rowGlobalLen = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int rowGlobalLen = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); if(buf.rowGlobalBuf.size() < rowGlobalLen) buf.rowGlobalBuf.resize(rowGlobalLen); const int rowMetaLen = numInputMetaChannels; diff --git a/cpp/neuralnet/nneval.h b/cpp/neuralnet/nneval.h index 6e87f8f0f..a13969c63 100644 --- a/cpp/neuralnet/nneval.h +++ b/cpp/neuralnet/nneval.h @@ -125,6 +125,7 @@ class NNEvaluator { int getNNXLen() const; int getNNYLen() const; int getModelVersion() const; + bool getDotsGame() const; double getTrunkSpatialConvDepth() const; enabled_t getUsingFP16Mode() const; enabled_t getUsingNHWCMode() const; diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 089d6fbeb..0e5db110b 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -1061,37 +1061,40 @@ Hash128 NNInputs::getHash( return hash; } -int NNInputs::getNumberOfSpatialFeatures(int version) { +int NNInputs::getNumberOfSpatialFeatures(const int version, const bool isDots) { switch(version) { - case 3: return NUM_FEATURES_SPATIAL_V3; - case 4: return NUM_FEATURES_SPATIAL_V4; - case 5: return NUM_FEATURES_SPATIAL_V5; - case 6: return NUM_FEATURES_SPATIAL_V6; - case 7: return NUM_FEATURES_SPATIAL_V7; - case 8: return NUM_FEATURES_SPATIAL_V_DOTS; - default: throw std::range_error("Invalid input version: " + to_string(version)); - } + case 3: if (!isDots) return NUM_FEATURES_SPATIAL_V3; break; + case 4: if (!isDots) return NUM_FEATURES_SPATIAL_V4; break; + case 5: if (!isDots) return NUM_FEATURES_SPATIAL_V5; break; + case 6: if (!isDots) return NUM_FEATURES_SPATIAL_V6; break; + case 7: return isDots ? NUM_FEATURES_SPATIAL_V7_DOTS : NUM_FEATURES_SPATIAL_V7; + default: break; + } + throw std::range_error("Invalid input version: " + to_string(version) + (isDots ? " (Dots game)" : "")); } -int NNInputs::getNumberOfGlobalFeatures(int version) { +int NNInputs::getNumberOfGlobalFeatures(const int version, const bool isDots) { switch(version) { - case 3: return NUM_FEATURES_GLOBAL_V3; - case 4: return NUM_FEATURES_GLOBAL_V4; - case 5: return NUM_FEATURES_GLOBAL_V5; - case 6: return NUM_FEATURES_GLOBAL_V6; - case 7: return NUM_FEATURES_GLOBAL_V7; - case 8: return NUM_FEATURES_GLOBAL_V_DOTS; - default: throw std::range_error("Invalid input version: " + to_string(version)); - } + case 3: if (!isDots) return NUM_FEATURES_GLOBAL_V3; break; + case 4: if (!isDots) return NUM_FEATURES_GLOBAL_V4; break; + case 5: if (!isDots) return NUM_FEATURES_GLOBAL_V5; break; + case 6: if (!isDots) return NUM_FEATURES_GLOBAL_V6; break; + case 7: return isDots ? NUM_FEATURES_GLOBAL_V7_DOTS : NUM_FEATURES_GLOBAL_V7; + default: break; + } + throw std::range_error("Invalid input version: " + to_string(version) + (isDots ? " (Dots game)" : "")); } // Generic filler void NNInputs::fillRowVN( - int version, + const int version, const Board& board, const BoardHistory& hist, Player nextPlayer, const MiscNNInputParams& nnInputParams, - int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal + const int nnXLen, + const int nnYLen, + const bool useNHWC, + float* rowBin, float* rowGlobal ) { switch(version) { case 3: @@ -1107,10 +1110,11 @@ void NNInputs::fillRowVN( fillRowV6(board, hist, nextPlayer, nnInputParams, nnXLen, nnYLen, useNHWC, rowBin, rowGlobal); break; case 7: - fillRowV7(board, hist, nextPlayer, nnInputParams, nnXLen, nnYLen, useNHWC, rowBin, rowGlobal); - break; - case 8: - fillRowVDots(board, hist, nextPlayer, nnInputParams, nnXLen, nnYLen, useNHWC,rowBin,rowGlobal); + if (!hist.rules.isDots) { + fillRowV7(board, hist, nextPlayer, nnInputParams, nnXLen, nnYLen, useNHWC, rowBin, rowGlobal); + } else { + fillRowV7Dots(board, hist, nextPlayer, nnInputParams, nnXLen, nnYLen, useNHWC,rowBin,rowGlobal); + } break; default: throw std::range_error("Invalid input version: " + to_string(version)); diff --git a/cpp/neuralnet/nninputs.h b/cpp/neuralnet/nninputs.h index 046677c6d..16c9266e3 100644 --- a/cpp/neuralnet/nninputs.h +++ b/cpp/neuralnet/nninputs.h @@ -60,6 +60,60 @@ struct MiscNNInputParams { }; namespace NNInputs { + enum class DotsSpatialFeature { + OnBoard_0, // 0: On board + PlayerActive_1, // 1: Pla stone + PlayerOppActive_2, // 2: Opp stone + PlayerPlaced_3, // 3: 1 libs + PlayerOppPlaced_4, // 4: 2 libs + DeadDots_5, // 5: 3 libs + Reserved_6, // 6: superKoBanned (in the encore, no-second-ko-capture locations, encore ko prohibitions where we have to pass for ko) + Reserved_7, // 7: koRecapBlocked (in the encore, no-second-ko-capture locations, encore ko prohibitions where we have to pass for ko) + Grounded_8, // 8: unused? (in the encore, no-second-ko-capture locations, encore ko prohibitions where we have to pass for ko) + + Prev1Loc_9, // 9: prev 1 Loc history + Prev2Loc_10, // 10: prev 2 Loc history + Prev3Loc_11, // 11: prev 3 Loc history + Prev4Loc_12, // 12: prev 4 Loc history + Prev5Loc_13, // 13: prev 5 Loc history + + LadderCaptured_14, // 14: ladder captured + LadderCapturedPrevious_15, // 15: ladder captured prev + LadderCapturedPrevious2_16, // 16: ladder captured prev 2 + LadderWorkingMoves_17, // 17: ladder working moves + + PlayerCaptures_18, // 18: pla current territory, not counting group tax + PlayerOppCaptures_19, // 19: opp current territory, not counting group tax + PlayerSurroundings_20, // 20: pla second encore starting stones + PlayerOppSurroundings_21, // 21: opp second encore starting stones + + COUNT, // = 22 + }; + + enum class DotsGlobalFeature { + Reserved_0, // 0: prev loc 1 PASS + Reserved_1, // 1: prev loc 2 PASS + Reserved_2, // 2: prev loc 3 PASS + Reserved_3, // 3: prev loc 4 PASS + Reserved_4, // 4: prev loc 5 PASS + Komi_5, // 5: Komi + Reserved_6, // 6: Ko rule (KO_POSITIONAL, KO_SPIGHT) + Reserved_7, // 7: Ko rule (KO_SITUATIONAL) + Suicide_8, // 8: Suicide + Reserved_9, // 9: Scoring + TaxIsEnabled_10, // 10: TAX is enabled + TaxAll_11, // 11: TAX_ALL + Reserved_12, // 12: encore phase > 0 + Reserved_13, // 13: encore phase > 1 + WinByGrounding_14, // 14: does a pass end the current phase given the ruleset and history? + PlayoutDoublingAdvantageIsEnabled_15, // 15: playoutDoublingAdvantage is enabled (handicap play) + PlayoutDoublingAdvantage_16, // 16: playoutDoublingAdvantage + CaptureEmpty_17, // 17: button + FieldSizeKomiParity_18, // 18: board size komi parity + + COUNT, // = 19 + }; + const int NUM_FEATURES_SPATIAL_V3 = 22; const int NUM_FEATURES_GLOBAL_V3 = 14; @@ -75,54 +129,19 @@ namespace NNInputs { const int NUM_FEATURES_SPATIAL_V7 = 22; const int NUM_FEATURES_GLOBAL_V7 = 19; - // TODO: normalize with Go spatial and global features (e.g. 22 and 19) - enum class DotsSpatialFeature { - OnBoard, // 0 - PlayerActive, // 1 - PlayerOppActive, // 2 - PlayerPlaced, // 1 - PlayerOppPlaced, // 2 - DeadDots, // Actual scoring - Grounded, // Analogue of territory (18,19) - PlayerCaptures, // (3,4,5) - PlayerOppCaptures, // (3,4,5) - PlayerSurroundings, // (3,4,5) - PlayerOppSurroundings, // (3,4,5) - Prev1Loc, // 9 - Prev2Loc, // 10 - Prev3Loc, // 11 - Prev4Loc, // 12 - Prev5Loc, // 13 - LadderCaptured, // 14, - LadderCapturedPrevious, // 15, - LadderCapturedPrevious2, // 16, - LadderWorkingMoves, // 17, - COUNT, // = 20 - }; - - enum class DotsGlobalFeature { - Komi, // 5 - Suicide, // 8 - CaptureEmpty, // 9, 10, 11 - StartPosCross, // (15) - StartPosCross2, // (15) - StartPosCross4, // (15) - StartPosIsRandom, // (15) - WinByGrounding, // Train grounding - FieldSizeKomiParity, // Not sure what this is (18) - COUNT, // = 9 - }; + constexpr int NUM_FEATURES_SPATIAL_V7_DOTS = static_cast(DotsSpatialFeature::COUNT); + constexpr int NUM_FEATURES_GLOBAL_V7_DOTS = static_cast(DotsGlobalFeature::COUNT); - constexpr int NUM_FEATURES_SPATIAL_V_DOTS = static_cast(DotsSpatialFeature::COUNT); - constexpr int NUM_FEATURES_GLOBAL_V_DOTS = static_cast(DotsGlobalFeature::COUNT); + static_assert(NUM_FEATURES_SPATIAL_V7_DOTS == NUM_FEATURES_SPATIAL_V7); // Might be changed later if needed + static_assert(NUM_FEATURES_GLOBAL_V7_DOTS == NUM_FEATURES_GLOBAL_V7); Hash128 getHash( const Board& board, const BoardHistory& boardHistory, Player nextPlayer, const MiscNNInputParams& nnInputParams ); - int getNumberOfSpatialFeatures(int version); - int getNumberOfGlobalFeatures(int version); + int getNumberOfSpatialFeatures(int version, bool isDots); + int getNumberOfGlobalFeatures(int version, bool isDots); void fillRowVN( int version, @@ -150,7 +169,7 @@ namespace NNInputs { const Board& board, const BoardHistory& boardHistory, Player nextPlayer, const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal ); - void fillRowVDots( + void fillRowV7Dots( const Board& board, const BoardHistory& hist, Player nextPlayer, const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index 6e517bc11..b08e173e1 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -2,7 +2,7 @@ using namespace std; -void NNInputs::fillRowVDots( +void NNInputs::fillRowV7Dots( const Board& board, const BoardHistory& hist, Player nextPlayer, const MiscNNInputParams& nnInputParams, int nnXLen, int nnYLen, bool useNHWC, float* rowBin, float* rowGlobal @@ -11,8 +11,8 @@ void NNInputs::fillRowVDots( assert(nnYLen <= NNPos::MAX_BOARD_LEN_Y); assert(board.x_size <= nnXLen); assert(board.y_size <= nnYLen); - std::fill_n(rowBin, NUM_FEATURES_SPATIAL_V_DOTS * nnXLen * nnYLen,false); - std::fill_n(rowGlobal, NUM_FEATURES_GLOBAL_V_DOTS, 0.0f); + std::fill_n(rowBin, NUM_FEATURES_SPATIAL_V7_DOTS * nnXLen * nnYLen,false); + std::fill_n(rowGlobal, NUM_FEATURES_GLOBAL_V7_DOTS, 0.0f); const Player pla = nextPlayer; const Player opp = getOpp(pla); @@ -23,7 +23,7 @@ void NNInputs::fillRowVDots( int posStride; if(useNHWC) { featureStride = 1; - posStride = NUM_FEATURES_SPATIAL_V_DOTS; + posStride = NUM_FEATURES_SPATIAL_V7_DOTS; } else { featureStride = nnXLen * nnYLen; @@ -50,50 +50,50 @@ void NNInputs::fillRowVDots( const int pos = NNPos::xyToPos(x,y,nnXLen); const Loc loc = Location::getLoc(x,y,xSize); - setSpatial(pos, DotsSpatialFeature::OnBoard); + setSpatial(pos, DotsSpatialFeature::OnBoard_0); const State state = board.getState(loc); const Color activeColor = getActiveColor(state); const Color placedColor = getPlacedDotColor(state); if (activeColor == pla) - setSpatial(pos, DotsSpatialFeature::PlayerActive); + setSpatial(pos, DotsSpatialFeature::PlayerActive_1); else if (activeColor == opp) - setSpatial(pos, DotsSpatialFeature::PlayerOppActive); + setSpatial(pos, DotsSpatialFeature::PlayerOppActive_2); else assert(C_EMPTY == activeColor); if (placedColor == pla) - setSpatial(pos, DotsSpatialFeature::PlayerPlaced); + setSpatial(pos, DotsSpatialFeature::PlayerPlaced_3); else if (placedColor == opp) - setSpatial(pos, DotsSpatialFeature::PlayerOppPlaced); + setSpatial(pos, DotsSpatialFeature::PlayerOppPlaced_4); else assert(C_EMPTY == placedColor); if (activeColor != C_EMPTY && placedColor != C_EMPTY && placedColor != activeColor) { // Needed for more correct score calculation, but probably it's redundant considering placed dots - setSpatial(pos, DotsSpatialFeature::DeadDots); + setSpatial(pos, DotsSpatialFeature::DeadDots_5); deadDotsCount++; } if (isGrounded(state)) { - setSpatial(pos, DotsSpatialFeature::Grounded); + setSpatial(pos, DotsSpatialFeature::Grounded_8); } const Color captureColor = captures[loc]; if ((pla & captureColor) != 0) { - setSpatial(pos, DotsSpatialFeature::PlayerCaptures); + setSpatial(pos, DotsSpatialFeature::PlayerCaptures_18); } if ((opp & captureColor) != 0) { - setSpatial(pos, DotsSpatialFeature::PlayerOppCaptures); + setSpatial(pos, DotsSpatialFeature::PlayerOppCaptures_19); } const Color baseColor = bases[loc]; if ((pla & baseColor) != 0) { - setSpatial(pos, DotsSpatialFeature::PlayerSurroundings); + setSpatial(pos, DotsSpatialFeature::PlayerSurroundings_20); } if ((opp & baseColor) != 0) { - setSpatial(pos, DotsSpatialFeature::PlayerOppSurroundings); + setSpatial(pos, DotsSpatialFeature::PlayerOppSurroundings_21); } // TODO: Set up history and ladder features @@ -110,34 +110,20 @@ void NNInputs::fillRowVDots( selfKomi = bArea+NNPos::KOMI_CLIP_RADIUS; if(selfKomi < -bArea-NNPos::KOMI_CLIP_RADIUS) selfKomi = -bArea-NNPos::KOMI_CLIP_RADIUS; - setGlobal(DotsGlobalFeature::Komi, selfKomi / NNPos::KOMI_CLIP_RADIUS); + setGlobal(DotsGlobalFeature::Komi_5, selfKomi / NNPos::KOMI_CLIP_RADIUS); if (rules.multiStoneSuicideLegal) { - setGlobal(DotsGlobalFeature::Suicide); + setGlobal(DotsGlobalFeature::Suicide_8); } if (rules.dotsCaptureEmptyBases) { - setGlobal(DotsGlobalFeature::CaptureEmpty); - } - - if (const int startPos = rules.startPos; startPos >= Rules::START_POS_CROSS) { - setGlobal(DotsGlobalFeature::StartPosCross); - if (startPos >= Rules::START_POS_CROSS_2) { - setGlobal(DotsGlobalFeature::StartPosCross2); - if (startPos >= Rules::START_POS_CROSS_4) { - setGlobal(DotsGlobalFeature::StartPosCross4); - } - } - } - - if (rules.startPosIsRandom) { - setGlobal(DotsGlobalFeature::StartPosIsRandom); + setGlobal(DotsGlobalFeature::CaptureEmpty_17); } if (hist.winOrEffectiveDrawByGrounding(board, pla)) { // Train to better understand grounding - setGlobal(DotsGlobalFeature::WinByGrounding); + setGlobal(DotsGlobalFeature::WinByGrounding_14); } - setGlobal(DotsGlobalFeature::FieldSizeKomiParity, 0.0f); // TODO: implement later + setGlobal(DotsGlobalFeature::FieldSizeKomiParity_18, 0.0f); // TODO: implement later } \ No newline at end of file diff --git a/cpp/neuralnet/nninterface.h b/cpp/neuralnet/nninterface.h index e5932a6d5..e44de9885 100644 --- a/cpp/neuralnet/nninterface.h +++ b/cpp/neuralnet/nninterface.h @@ -112,7 +112,8 @@ namespace NeuralNet { InputBuffers* buffers, int numBatchEltsFilled, NNResultBuf** inputBufs, - std::vector& outputs + const std::vector& outputs, + bool dotsGame ); diff --git a/cpp/neuralnet/openclbackend.cpp b/cpp/neuralnet/openclbackend.cpp index b8922bf2c..aee14e030 100644 --- a/cpp/neuralnet/openclbackend.cpp +++ b/cpp/neuralnet/openclbackend.cpp @@ -2499,12 +2499,12 @@ struct Model { numScoreValueChannels = desc->numScoreValueChannels; numOwnershipChannels = desc->numOwnershipChannels; - int numFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + int numFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, desc->isDotsGame); if(numInputChannels != numFeatures) throw StringError(Global::strprintf("Neural net numInputChannels (%d) was not the expected number based on version (%d)", numInputChannels, numFeatures )); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, desc->isDotsGame); if(numInputGlobalChannels != numGlobalFeatures) throw StringError(Global::strprintf("Neural net numInputGlobalChannels (%d) was not the expected number based on version (%d)", numInputGlobalChannels, numGlobalFeatures @@ -2890,8 +2890,8 @@ struct InputBuffers { singleScoreValueResultElts = (size_t)m.numScoreValueChannels; singleOwnershipResultElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; - assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); - assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion, m.isDotsGame) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion, m.isDotsGame) == m.numInputGlobalChannels); if(m.numInputMetaChannels > 0) { assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); } @@ -2952,14 +2952,13 @@ void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) { delete inputBuffers; } - void NeuralNet::getOutput( ComputeHandle* gpuHandle, InputBuffers* inputBuffers, - int numBatchEltsFilled, + const int numBatchEltsFilled, NNResultBuf** inputBufs, - vector& outputs -) { + const vector& outputs, + const bool dotsGame) { assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); assert(numBatchEltsFilled > 0); const int batchSize = numBatchEltsFilled; @@ -2967,8 +2966,8 @@ void NeuralNet::getOutput( const int nnYLen = gpuHandle->nnYLen; const int modelVersion = gpuHandle->model->modelVersion; - const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); const int numMetaFeatures = inputBuffers->singleInputMetaElts; assert(numSpatialFeatures == gpuHandle->model->numInputChannels); assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); diff --git a/cpp/neuralnet/opencltuner.cpp b/cpp/neuralnet/opencltuner.cpp index 8c177a986..e629faadd 100644 --- a/cpp/neuralnet/opencltuner.cpp +++ b/cpp/neuralnet/opencltuner.cpp @@ -3350,7 +3350,7 @@ void OpenCLTuner::autoTuneEverything( string gpuName = allDeviceInfos[gpuIdxForTuning].name; //Just hardcodedly tune all the models that KataGo's main run uses. - static_assert(NNModelVersion::latestModelVersionImplemented == 17, ""); + static_assert(NNModelVersion::latestModelVersionImplemented == 16, ""); vector modelInfos; { ModelInfoForTuning modelInfo; @@ -3462,17 +3462,6 @@ void OpenCLTuner::autoTuneEverything( modelInfo.modelVersion = 16; modelInfos.push_back(modelInfo); } - { - ModelInfoForTuning modelInfo; - modelInfo.maxConvChannels1x1 = 512; - modelInfo.maxConvChannels3x3 = 512; - modelInfo.trunkNumChannels = 512; - modelInfo.midNumChannels = 256; - modelInfo.regularNumChannels = 192; - modelInfo.gpoolNumChannels = 64; - modelInfo.modelVersion = 17; - modelInfos.push_back(modelInfo); - } for(ModelInfoForTuning modelInfo : modelInfos) { int nnXLen = NNPos::MAX_BOARD_LEN_X; diff --git a/cpp/neuralnet/trtbackend.cpp b/cpp/neuralnet/trtbackend.cpp index c6df9f251..b8472a818 100644 --- a/cpp/neuralnet/trtbackend.cpp +++ b/cpp/neuralnet/trtbackend.cpp @@ -115,6 +115,7 @@ struct TRTModel { vector> extraWeights; int modelVersion; + bool dotsGame; uint8_t tuneHash[32]; IOptimizationProfile* profile; unique_ptr network; @@ -196,6 +197,7 @@ struct ModelParser { } model->modelVersion = modelDesc->modelVersion; + model->dotsGame = modelDesc->isDotsGame; network->setName(modelDesc->name.c_str()); initInputs(); @@ -247,13 +249,13 @@ struct ModelParser { int numInputGlobalChannels = modelDesc->numInputGlobalChannels; int numInputMetaChannels = modelDesc->numInputMetaChannels; - int numFeatures = NNModelVersion::getNumSpatialFeatures(model->modelVersion); + int numFeatures = NNModelVersion::getNumSpatialFeatures(model->modelVersion, model->dotsGame); if(numInputChannels != numFeatures) throw StringError(Global::strprintf( "Neural net numInputChannels (%d) was not the expected number based on version (%d)", numInputChannels, numFeatures)); - int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(model->modelVersion); + int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(model->modelVersion, model->dotsGame); if(numInputGlobalChannels != numGlobalFeatures) throw StringError(Global::strprintf( "Neural net numInputGlobalChannels (%d) was not the expected number based on version (%d)", @@ -1586,8 +1588,8 @@ struct InputBuffers { singleOwnershipResultElts = m.numOwnershipChannels * nnXLen * nnYLen; singleOwnershipResultBytes = singleOwnershipResultElts * sizeof(float); - assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); - assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion, m.isDotsGame) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion, m.isDotsGame) == m.numInputGlobalChannels); if(m.numInputMetaChannels > 0) { assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); } @@ -1631,7 +1633,8 @@ void NeuralNet::getOutput( InputBuffers* inputBuffers, int numBatchEltsFilled, NNResultBuf** inputBufs, - vector& outputs) { + const vector& outputs, + const bool dotsGame) { assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); assert(numBatchEltsFilled > 0); @@ -1640,8 +1643,8 @@ void NeuralNet::getOutput( const int nnYLen = gpuHandle->ctx->nnYLen; const int modelVersion = gpuHandle->modelVersion; - const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); - const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion, dotsGame); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion, dotsGame); const int numMetaFeatures = inputBuffers->singleInputMetaElts; assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts); diff --git a/cpp/tests/testnninputs.cpp b/cpp/tests/testnninputs.cpp index 118cec1e3..a1c7e7c79 100644 --- a/cpp/tests/testnninputs.cpp +++ b/cpp/tests/testnninputs.cpp @@ -14,7 +14,7 @@ static void printNNInputHWAndBoard( ostream& out, int inputsVersion, const Board& board, const BoardHistory& hist, int nnXLen, int nnYLen, bool inputsUseNHWC, T* row, int c ) { - const int numFeatures = NNInputs::getNumberOfSpatialFeatures(inputsVersion); + const int numFeatures = NNInputs::getNumberOfSpatialFeatures(inputsVersion, false); out << "Channel: " << c << endl; @@ -54,7 +54,7 @@ static void printNNInputHWAndBoard( template static void printNNInputGlobal(ostream& out, int inputsVersion, T* row, int c) { - const int numFeatures = NNInputs::getNumberOfGlobalFeatures(inputsVersion); + const int numFeatures = NNInputs::getNumberOfGlobalFeatures(inputsVersion, false); (void)numFeatures; out << "Channel: " << c; @@ -106,13 +106,13 @@ void Tests::runNNInputsV3V4Tests() { out << std::setprecision(5); auto allocateRows = [](int version, int nnXLen, int nnYLen, int& numFeaturesBin, int& numFeaturesGlobal, float*& rowBin, float*& rowGlobal) { - numFeaturesBin = NNInputs::getNumberOfSpatialFeatures(version); - numFeaturesGlobal = NNInputs::getNumberOfGlobalFeatures(version); + numFeaturesBin = NNInputs::getNumberOfSpatialFeatures(version, false); + numFeaturesGlobal = NNInputs::getNumberOfGlobalFeatures(version, false); rowBin = new float[numFeaturesBin * nnXLen * nnYLen]; rowGlobal = new float[numFeaturesGlobal]; }; - static_assert(NNModelVersion::latestInputsVersionImplemented == 8, ""); + static_assert(NNModelVersion::latestInputsVersionImplemented == 7, ""); int minVersion = 3; int maxVersion = 7; diff --git a/python/export_model_pytorch.py b/python/export_model_pytorch.py index c006085f2..7b642f4ef 100644 --- a/python/export_model_pytorch.py +++ b/python/export_model_pytorch.py @@ -83,7 +83,11 @@ def writestr(s): version = 15 writeln(model_name) - writeln(version) + writestr(str(version)) + writestr("," + str(model.pos_len_x)) + writestr("," + str(model.pos_len_y)) + writestr("," + ";".join(c.name for c in model.games)) + writeln("") writeln(modelconfigs.get_num_bin_input_features(model_config)) writeln(modelconfigs.get_num_global_input_features(model_config)) diff --git a/python/katago/train/load_model.py b/python/katago/train/load_model.py index cbf582c9a..66fd3cba9 100644 --- a/python/katago/train/load_model.py +++ b/python/katago/train/load_model.py @@ -9,6 +9,8 @@ from collections import defaultdict +from katago.train.model_pytorch import Game, parse_game + # defaultdict and float constructor are used in the ckpt for running metrics if packaging.version.parse(torch.__version__) > packaging.version.parse("2.4.0"): torch.serialization.add_safe_globals([defaultdict]) @@ -54,7 +56,10 @@ def load_model(checkpoint_file, use_swa, device, pos_len_x=19, pos_len_y=19, ver model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config, pos_len_x, pos_len_y) + effective_pos_len_x = state_dict.get("pos_len_x", pos_len_x) + effective_pos_len_y = state_dict.get("pos_len_y", pos_len_y) + games = [parse_game(s) for s in state_dict.get("games", [Game.GO.name])] + model = Model(model_config, effective_pos_len_x, effective_pos_len_y, games) model.initialize() # Strip off any "module." from when the model was saved with DDP or other things diff --git a/python/katago/train/model_pytorch.py b/python/katago/train/model_pytorch.py index 4654ac934..d233752ec 100644 --- a/python/katago/train/model_pytorch.py +++ b/python/katago/train/model_pytorch.py @@ -6,6 +6,8 @@ import torch.nn.init import packaging import packaging.version +from argparse import ArgumentTypeError +from enum import Enum from typing import List, Dict, Optional, Set from ..train import modelconfigs @@ -1603,8 +1605,21 @@ def forward(self, input_meta, extra_outputs: Optional[ExtraOutputs]): return self.out_scale * self.linear_output_to_trunk(x) +class Game(Enum): + GO = 0 + DOTS = 1 + +def parse_game(value: str) -> Game: + value = value.upper() + if value == "GO": + return Game.GO + elif value == "DOTS": + return Game.DOTS + else: + raise ArgumentTypeError(f"Game must be 'go' or 'dots', got '{value}'") + class Model(torch.nn.Module): - def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int): + def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: int, games=None): super(Model, self).__init__() self.config = config @@ -1623,6 +1638,7 @@ def __init__(self, config: modelconfigs.ModelConfig, pos_len_x: int, pos_len_y: self.num_total_blocks = len(self.block_kind) self.pos_len_x = pos_len_x self.pos_len_y = pos_len_y + self.games = games or [Game.GO] if config["version"] <= 12: self.td_score_multiplier = 20.0 diff --git a/python/katago/train/modelconfigs.py b/python/katago/train/modelconfigs.py index 976be7178..e24efdb8c 100644 --- a/python/katago/train/modelconfigs.py +++ b/python/katago/train/modelconfigs.py @@ -46,14 +46,14 @@ def get_version(config: ModelConfig): def get_num_bin_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: # Dots game uses the same number of spatial features return 22 else: assert(False) def get_num_global_input_features(config: ModelConfig): version = get_version(config) - if version == 10 or version == 11 or version == 12 or version == 13 or version == 14 or version == 15 or version == 16: + if 10 <= version <= 16: # Dots game uses the same number of global features return 19 else: assert(False) diff --git a/python/selfplay/train.sh b/python/selfplay/train.sh index 7e8da268e..79e92fcce 100755 --- a/python/selfplay/train.sh +++ b/python/selfplay/train.sh @@ -78,6 +78,7 @@ time python3 ./train.py \ -exportdir "$BASEDIR"/"$EXPORT_SUBDIR" \ -exportprefix "$TRAININGNAME" \ -pos-len 39 \ + -games DOTS \ -batch-size "$BATCHSIZE" \ -model-kind "$MODELKIND" \ $EXTRAFLAG \ diff --git a/python/test.py b/python/test.py index 929180529..9a3a93ecf 100644 --- a/python/test.py +++ b/python/test.py @@ -17,7 +17,7 @@ from torch.optim.swa_utils import AveragedModel from katago.train import modelconfigs -from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder +from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder, Game, parse_game from katago.train.metrics_pytorch import Metrics from katago.train import data_processing_pytorch from katago.train.load_model import load_model @@ -36,12 +36,15 @@ parser.add_argument('-config', help='Path to model.config.json', required=False) parser.add_argument('-checkpoint', help='Checkpoint to test', required=False) parser.add_argument('-pos-len', help='Spatial length of expected training data', type=int, required=True) + parser.add_argument('-pos-len-x', help='Spatial width of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + parser.add_argument('-pos-len-y', help='Spatial height of expected training data. If undefined, `-pos-len` is used', type=int, required=False) parser.add_argument('-batch-size', help='Batch size to use for testing', type=int, required=True) parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False) parser.add_argument('-max-batches', help='Maximum number of batches for testing', type=int, required=False) parser.add_argument('-gpu-idx', help='GPU idx', type=int, required=False) parser.add_argument('-print-norm', help='Names of outputs to print norms comma separated', type=str, required=False) parser.add_argument('-list-available-outputs', help='Print names of outputs available', action="store_true", required=False) + parser.add_argument('-games', help='Games to train: ' + ', '.join(e.name for e in Game) + ' (GO by default)', type=parse_game, nargs="+", required=False) args = vars(parser.parse_args()) @@ -53,6 +56,7 @@ def main(args): pos_len = args["pos_len"] pos_len_x = args["pos_len_x"] or pos_len pos_len_y = args["pos_len_y"] or pos_len + games = args["games"] or [Game.GO] batch_size = args["batch_size"] use_swa = args["use_swa"] max_batches = args["max_batches"] @@ -111,7 +115,7 @@ def main(args): model_config = json.load(f) logging.info(str(model_config)) - model = Model(model_config,pos_len_x,pos_len_y) + model = Model(model_config,pos_len_x,pos_len_y,games) model.initialize() model.to(device) else: diff --git a/python/train.py b/python/train.py index 7a6850399..d612ffedf 100755 --- a/python/train.py +++ b/python/train.py @@ -31,7 +31,7 @@ from torch.cuda.amp import GradScaler, autocast from katago.train import modelconfigs -from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder +from katago.train.model_pytorch import Model, ExtraOutputs, MetadataEncoder, parse_game, Game from katago.train.metrics_pytorch import Metrics from katago.utils.push_back_generator import PushBackGenerator from katago.train import load_model @@ -68,6 +68,7 @@ required_args.add_argument('-pos-len', help='Spatial edge length of expected training data, e.g. 19 for 19x19 Go', type=int, required=True) optional_args.add_argument('-pos-len-x', help='Spatial width of expected training data. If undefined, `-pos-len` is used', type=int, required=False) optional_args.add_argument('-pos-len-y', help='Spatial height of expected training data. If undefined, `-pos-len` is used', type=int, required=False) + optional_args.add_argument('-games', help='Games to train: ' + ", ".join(e.name for e in Game) + ' (GO by default)', type=parse_game, nargs="+", required=False) optional_args.add_argument('-samples-per-epoch', help='Number of data samples to consider as one epoch', type=int, required=False) optional_args.add_argument('-model-kind', help='String name for what model config to use', required=False) optional_args.add_argument('-lr-scale', help='LR multiplier on the hardcoded schedule', type=float, required=False) @@ -157,6 +158,7 @@ def main(rank: int, world_size: int, args, multi_gpu_device_ids, readpipes, writ pos_len = args["pos_len"] pos_len_x = args["pos_len_x"] or pos_len pos_len_y = args["pos_len_y"] or pos_len + games = args["games"] or [Game.GO] batch_size = args["batch_size"] samples_per_epoch = args["samples_per_epoch"] model_kind = args["model_kind"] @@ -314,6 +316,9 @@ def save(ddp_model, swa_model, optimizer, metrics_obj, running_metrics, train_st if rank == 0: state_dict = {} state_dict["model"] = ddp_model.state_dict() + state_dict["pos_len_x"] = getattr(ddp_model, "pos_len_x", 19) + state_dict["pos_len_y"] = getattr(ddp_model, "pos_len_y", 19) + state_dict["games"] = [g.name for g in getattr(ddp_model, "games", [Game.GO])] state_dict["optimizer"] = optimizer.state_dict() state_dict["metrics"] = metrics_obj.state_dict() state_dict["running_metrics"] = running_metrics @@ -447,7 +452,7 @@ def load(): assert model_kind is not None, "Model kind is none or unspecified but the model is being created fresh" model_config = modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len_x,pos_len_y) + raw_model = Model(model_config,pos_len_x,pos_len_y,games) raw_model.initialize() raw_model.to(device) @@ -482,7 +487,7 @@ def load(): state_dict = torch.load(path_to_load_from, map_location=device) model_config = state_dict["config"] if "config" in state_dict else modelconfigs.config_of_name[model_kind] logging.info(str(model_config)) - raw_model = Model(model_config,pos_len_x,pos_len_y) + raw_model = Model(model_config,pos_len_x,pos_len_y,games) raw_model.initialize() train_state = {} From dbe65c1fa2448d8eed4ab892cc38750f51e0640e Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 16 Nov 2025 12:55:05 +0100 Subject: [PATCH 39/42] Refactor `getString` code in config_parser --- cpp/core/config_parser.cpp | 73 +++++++++++++++++++++----------------- cpp/core/config_parser.h | 13 ++++--- cpp/program/setup.cpp | 2 +- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/cpp/core/config_parser.cpp b/cpp/core/config_parser.cpp index 8cd36f60d..82fe34c9e 100644 --- a/cpp/core/config_parser.cpp +++ b/cpp/core/config_parser.cpp @@ -3,6 +3,7 @@ #include "../core/fileutils.h" #include +#include #include using namespace std; @@ -516,56 +517,62 @@ std::string ConfigParser::firstFoundOrEmpty(const std::vector& poss if(contains(key)) return key; } - return string(); + return {}; } +string ConfigParser::getString(const string& key, const set& possibles) { + const auto str = tryGetString(key); -string ConfigParser::getString(const string& key) { - auto iter = keyValues.find(key); - if(iter == keyValues.end()) + if (str == std::nullopt) throw IOError("Could not find key '" + key + "' in config file " + fileName); - { - std::lock_guard lock(usedKeysMutex); - usedKeys.insert(key); - } + auto value = str.value(); - return iter->second; -} + if (!possibles.empty()) { + if(possibles.find(value) == possibles.end()) + throw IOError("Key '" + key + "' must be one of (" + Global::concat(possibles,"|") + ") in config file " + fileName); + } -string ConfigParser::getString(const string& key, const set& possibles) { - string value = getString(key); - if(possibles.find(value) == possibles.end()) - throw IOError("Key '" + key + "' must be one of (" + Global::concat(possibles,"|") + ") in config file " + fileName); return value; } -vector ConfigParser::getStrings(const string& key) { - return Global::split(getString(key),','); -} +vector ConfigParser::getStrings(const string& key, const set& possibles, const bool nonEmptyTrim) { + vector values = Global::split(getString(key),','); -vector ConfigParser::getStringsNonEmptyTrim(const string& key) { - vector raw = Global::split(getString(key),','); - vector trimmed; - for(size_t i = 0; i trimmedStrings; + for(const auto& s : values) { + if (string trimmed = Global::trim(s); !trimmed.empty()) { + trimmedStrings.push_back(trimmed); + } + } + values = trimmedStrings; } - return trimmed; -} -vector ConfigParser::getStrings(const string& key, const set& possibles) { - vector values = getStrings(key); - for(size_t i = 0; i ConfigParser::tryGetString(const string& key) { + const auto iter = keyValues.find(key); + if(iter == keyValues.end()) { + return std::nullopt; + } + + { + std::lock_guard lock(usedKeysMutex); + usedKeys.insert(key); + } + + return iter->second; +} + bool ConfigParser::getBoolOrDefault(const std::string& key, const bool defaultValue) { if (contains(key)) { return getBool(key); diff --git a/cpp/core/config_parser.h b/cpp/core/config_parser.h index 5d381f433..fb9b4d118 100644 --- a/cpp/core/config_parser.h +++ b/cpp/core/config_parser.h @@ -2,10 +2,11 @@ #define CORE_CONFIG_PARSER_H_ #include +#include +#include "../core/commontypes.h" #include "../core/global.h" #include "../core/logger.h" -#include "../core/commontypes.h" /* Parses simple configs like: @@ -57,23 +58,21 @@ class ConfigParser { std::string firstFoundOrFail(const std::vector& possibleKeys) const; std::string firstFoundOrEmpty(const std::vector& possibleKeys) const; - std::string getString(const std::string& key); + std::string getString(const std::string& key, const std::set& possibles = {}); + std::vector getStrings(const std::string& key, const std::set& possibles = {}, bool nonEmptyTrim = false); + std::optional tryGetString(const std::string& key); + bool getBoolOrDefault(const std::string& key, bool defaultValue); bool getBool(const std::string& key); enabled_t getEnabled(const std::string& key); - std::string getString(const std::string& key, const std::set& possibles); int getInt(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); int64_t getInt64(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); uint64_t getUInt64(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); float getFloat(const std::string& key, float min = std::numeric_limits::min(), float max = std::numeric_limits::max()); double getDouble(const std::string& key, double min = std::numeric_limits::min(), double max = std::numeric_limits::max()); - std::vector getStrings(const std::string& key); - std::vector getStringsNonEmptyTrim(const std::string& key); std::vector getBools(const std::string& key); - - std::vector getStrings(const std::string& key, const std::set& possibles); std::vector getInts(const std::string& key, int min = std::numeric_limits::min(), int max = std::numeric_limits::max()); std::vector getInt64s(const std::string& key, int64_t min = std::numeric_limits::min(), int64_t max = std::numeric_limits::max()); std::vector getUInt64s(const std::string& key, uint64_t min = std::numeric_limits::min(), uint64_t max = std::numeric_limits::max()); diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 035a1a08e..9b9feae68 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -1034,7 +1034,7 @@ std::vector> Setup::loadAvoidSgfPatternBonusT double lambda = contains("PatternLambda") ? cfg.getDouble(find("PatternLambda"),0.0,1.0) : 1.0; int minTurnNumber = contains("PatternMinTurnNumber") ? cfg.getInt(find("PatternMinTurnNumber"),0,1000000) : 0; size_t maxFiles = contains("PatternMaxFiles") ? (size_t)cfg.getInt(find("PatternMaxFiles"),1,1000000) : 1000000; - vector allowedPlayerNames = contains("PatternAllowedNames") ? cfg.getStringsNonEmptyTrim(find("PatternAllowedNames")) : vector(); + vector allowedPlayerNames = contains("PatternAllowedNames") ? cfg.getStrings(find("PatternAllowedNames"), {}, true) : vector(); vector sgfDirs = cfg.getStrings(find("PatternDirs")); if(patternBonusTable == nullptr) patternBonusTable = std::make_unique(); From 86258e56551b3533a3f514f8312d2ffe4b716102 Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 16 Nov 2025 15:25:02 +0100 Subject: [PATCH 40/42] Fix compilation errors on Linux --- cpp/core/commontypes.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/core/commontypes.h b/cpp/core/commontypes.h index 8f1d57656..51441726f 100644 --- a/cpp/core/commontypes.h +++ b/cpp/core/commontypes.h @@ -1,6 +1,8 @@ #ifndef COMMONTYPES_H #define COMMONTYPES_H +#include + struct enabled_t { enum value { False, True, Auto }; value x; @@ -11,7 +13,7 @@ struct enabled_t { constexpr bool operator==(enabled_t a) const { return x == a.x; } constexpr bool operator!=(enabled_t a) const { return x != a.x; } - std::string toString() { + [[nodiscard]] std::string toString() const { return x == True ? "true" : x == False ? "false" : "auto"; } From 190f367bb493853ba590632beb4c829a93b15c2b Mon Sep 17 00:00:00 2001 From: Ivan Kochurkin Date: Sun, 16 Nov 2025 16:21:51 +0100 Subject: [PATCH 41/42] Fix compilation warnings from Linux and macOS compilers --- cpp/command/gtp.cpp | 4 +-- cpp/dataio/sgf.cpp | 2 +- cpp/dataio/trainingwrite.cpp | 12 ++++----- cpp/dataio/trainingwrite.h | 2 +- cpp/game/board.cpp | 8 +++--- cpp/game/board.h | 7 +++-- cpp/game/boardhistory.cpp | 10 +++---- cpp/game/boardhistory.h | 2 +- cpp/game/dotsboardhistory.cpp | 6 ++--- cpp/game/dotsfield.cpp | 17 ++++++------ cpp/game/graphhash.cpp | 2 +- cpp/game/rules.cpp | 39 ++++++++++++++------------- cpp/game/rules.h | 46 +++++++++++++++++--------------- cpp/neuralnet/modelversion.cpp | 4 +-- cpp/neuralnet/nneval.cpp | 6 ++--- cpp/neuralnet/nneval.h | 2 +- cpp/neuralnet/nninputs.cpp | 4 +-- cpp/neuralnet/nninputsdots.cpp | 2 +- cpp/search/search.cpp | 6 ++--- cpp/search/search.h | 2 +- cpp/tests/testdotsbasic.cpp | 2 +- cpp/tests/testdotsextra.cpp | 10 +++---- cpp/tests/testdotsstartposes.cpp | 10 +++---- cpp/tests/testdotsstress.cpp | 18 ++++++------- cpp/tests/testdotsutils.h | 4 +-- cpp/tests/testnninputs.cpp | 23 ++++++++++------ cpp/tests/testtrainingwrite.cpp | 2 +- 27 files changed, 132 insertions(+), 120 deletions(-) diff --git a/cpp/command/gtp.cpp b/cpp/command/gtp.cpp index 99e56632c..87ea9d19b 100644 --- a/cpp/command/gtp.cpp +++ b/cpp/command/gtp.cpp @@ -1910,7 +1910,7 @@ static GTPEngine::AnalyzeArgs parseAnalyzeCommand( return args; } -optional parseMovesSequence(const vector& pieces, const Board& board, bool passIsAllowed, vector& movesToPlay) { +static optional parseMovesSequence(const vector& pieces, const Board& board, bool passIsAllowed, vector& movesToPlay) { optional response = std::nullopt; auto renderPieces = [pieces](const int& index) { @@ -1948,7 +1948,7 @@ optional parseMovesSequence(const vector& pieces, const Boa return response; } -string printMoves(const vector& moves, const Board& board) { +static string printMoves(const vector& moves, const Board& board) { std::ostringstream builder; for (const auto move : moves) { builder << PlayerIO::playerToStringShort(move.pla, board.isDots()) << " " << Location::toString(move.loc, board) << " "; diff --git a/cpp/dataio/sgf.cpp b/cpp/dataio/sgf.cpp index 7216d4696..433519565 100644 --- a/cpp/dataio/sgf.cpp +++ b/cpp/dataio/sgf.cpp @@ -1636,8 +1636,8 @@ std::vector> Sgf::loadSgfOrSgfsLogAndIgnoreErrors(const str CompactSgf::CompactSgf(const Sgf& sgf) :fileName(sgf.fileName), - rootNode(), isDots(), + rootNode(), placements(), moves(), xSize(), diff --git a/cpp/dataio/trainingwrite.cpp b/cpp/dataio/trainingwrite.cpp index 5f26c9411..a9b0db48a 100644 --- a/cpp/dataio/trainingwrite.cpp +++ b/cpp/dataio/trainingwrite.cpp @@ -953,8 +953,8 @@ void TrainingWriteBuffers::writeToTextOstream(ostream& out) { //------------------------------------------------------------------------------------- -TrainingDataWriter::TrainingDataWriter(const string& outputDir, ostream* debugOut, - const int inputsVersion, +TrainingDataWriter::TrainingDataWriter(const string& newOutputDir, ostream* newDebugOut, + const int newInputsVersion, const int maxRowsPerFile, const double firstFileMinRandProp, const int dataXLen, @@ -962,16 +962,16 @@ TrainingDataWriter::TrainingDataWriter(const string& outputDir, ostream* debugOu const string& randSeed, const int onlyWriteEvery, const bool dotsGame) - :outputDir(outputDir),inputsVersion(inputsVersion),rand(randSeed),writeBuffers(nullptr),debugOut(debugOut),debugOnlyWriteEvery(onlyWriteEvery),rowCount(0) + :outputDir(newOutputDir),inputsVersion(newInputsVersion),rand(randSeed),writeBuffers(nullptr),debugOut(newDebugOut),debugOnlyWriteEvery(onlyWriteEvery),rowCount(0) { //Note that this inputsVersion is for data writing, it might be different than the inputsVersion used // to feed into a model during selfplay - const int numBinaryChannels = NNInputs::getNumberOfSpatialFeatures(inputsVersion, dotsGame); - const int numGlobalChannels = NNInputs::getNumberOfGlobalFeatures(inputsVersion, dotsGame); + const int numBinaryChannels = NNInputs::getNumberOfSpatialFeatures(newInputsVersion, dotsGame); + const int numGlobalChannels = NNInputs::getNumberOfGlobalFeatures(newInputsVersion, dotsGame); constexpr bool hasMetadataInput = false; writeBuffers = new TrainingWriteBuffers( - inputsVersion, + newInputsVersion, maxRowsPerFile, numBinaryChannels, numGlobalChannels, diff --git a/cpp/dataio/trainingwrite.h b/cpp/dataio/trainingwrite.h index ffd0b32b5..fcf754cbe 100644 --- a/cpp/dataio/trainingwrite.h +++ b/cpp/dataio/trainingwrite.h @@ -311,7 +311,7 @@ struct TrainingWriteBuffers { class TrainingDataWriter { public: - TrainingDataWriter(const std::string& outputDir, std::ostream* debugOut, int inputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1, bool dotsGame = false); + TrainingDataWriter(const std::string& newOutputDir, std::ostream* newDebugOut, int newInputsVersion, int maxRowsPerFile, double firstFileMinRandProp, int dataXLen, int dataYLen, const std::string& randSeed, int onlyWriteEvery = 1, bool dotsGame = false); ~TrainingDataWriter(); void writeGame(const FinishedGameData& data); diff --git a/cpp/game/board.cpp b/cpp/game/board.cpp index 55ba016b7..05b2bb3b4 100644 --- a/cpp/game/board.cpp +++ b/cpp/game/board.cpp @@ -119,12 +119,12 @@ Board::Base::Base(Player newPla, Board::Board() : Board(Rules::DEFAULT_GO) {} -Board::Board(const Rules& rules) { - init(rules.isDots ? DEFAULT_LEN_X_DOTS : DEFAULT_LEN_X, rules.isDots ? DEFAULT_LEN_Y_DOTS : DEFAULT_LEN_Y, rules); +Board::Board(const Rules& newRules) { + init(newRules.isDots ? DEFAULT_LEN_X_DOTS : DEFAULT_LEN_X, newRules.isDots ? DEFAULT_LEN_Y_DOTS : DEFAULT_LEN_Y, newRules); } -Board::Board(const int x, const int y, const Rules& rules) { - init(x, y, rules); +Board::Board(const int x, const int y, const Rules& newRules) { + init(x, y, newRules); } Board::Board(const Board& other) { diff --git a/cpp/game/board.h b/cpp/game/board.h index 311b0735e..fccdfc162 100644 --- a/cpp/game/board.h +++ b/cpp/game/board.h @@ -37,7 +37,7 @@ typedef int8_t State; static constexpr int PLAYER_BITS_COUNT = 2; static constexpr State ACTIVE_MASK = (1 << PLAYER_BITS_COUNT) - 1; -static Color getOpp(Color c) +static inline Color getOpp(Color c) {return c ^ 3;} Color getActiveColor(State state); @@ -58,7 +58,6 @@ namespace PlayerIO { std::string playerToString(Player p, bool isDots); bool tryParsePlayer(const std::string& s, Player& pla); Player parsePlayer(const std::string& s); - char stateToChar(State s, bool isDots); } namespace Location @@ -236,8 +235,8 @@ struct Board //Constructors--------------------------------- Board(); //Create Board of size (DEFAULT_LEN,DEFAULT_LEN) - explicit Board(const Rules& rules); - Board(int x, int y, const Rules& rules); // Create Board of size (x,y) with the specified Rules + explicit Board(const Rules& newRules); + Board(int x, int y, const Rules& newRules); // Create Board of size (x,y) with the specified Rules Board(const Board& other); Board& operator=(const Board&) = default; diff --git a/cpp/game/boardhistory.cpp b/cpp/game/boardhistory.cpp index fd4ea6dc8..6a638f071 100644 --- a/cpp/game/boardhistory.cpp +++ b/cpp/game/boardhistory.cpp @@ -25,13 +25,13 @@ static Hash128 getKoHashAfterMoveNonEncore(const Rules& rules, Hash128 posHashAf BoardHistory::BoardHistory() : BoardHistory(Rules::DEFAULT_GO) {} -BoardHistory::BoardHistory(const Rules& rules) - : rules(rules), +BoardHistory::BoardHistory(const Rules& newRules) + : rules(newRules), moveHistory(), preventEncoreHistory(), koHashHistory(), firstTurnIdxWithKoHistory(0), - initialBoard(rules), + initialBoard(newRules), initialPla(P_BLACK), initialEncorePhase(0), initialTurnNumber(0), @@ -61,9 +61,9 @@ BoardHistory::BoardHistory(const Rules& rules) isResignation(false), isPassAliveFinished(false) { for(int i = 0; i < NUM_RECENT_BOARDS; i++) { - recentBoards.emplace_back(rules); + recentBoards.emplace_back(newRules); } - if(!rules.isDots) { + if(!newRules.isDots) { wasEverOccupiedOrPlayed.resize(Board::MAX_ARR_SIZE, false); superKoBanned.resize(Board::MAX_ARR_SIZE, false); koRecapBlocked.resize(Board::MAX_ARR_SIZE, false); diff --git a/cpp/game/boardhistory.h b/cpp/game/boardhistory.h index 103bb94da..cf1ba4241 100644 --- a/cpp/game/boardhistory.h +++ b/cpp/game/boardhistory.h @@ -104,7 +104,7 @@ struct BoardHistory { bool isPassAliveFinished; BoardHistory(); - explicit BoardHistory(const Rules& rules); + explicit BoardHistory(const Rules& newRules); ~BoardHistory(); BoardHistory(const Board& board); diff --git a/cpp/game/dotsboardhistory.cpp b/cpp/game/dotsboardhistory.cpp index aed74dec0..3736e7f8e 100644 --- a/cpp/game/dotsboardhistory.cpp +++ b/cpp/game/dotsboardhistory.cpp @@ -19,9 +19,9 @@ bool BoardHistory::winOrEffectiveDrawByGrounding(const Board& board, const Playe assert(rules.isDots); const float whiteScore = whiteScoreIfGroundingAlive(board); - return considerDraw && Global::isZero(whiteScore) || - pla == P_WHITE && whiteScore > Global::FLOAT_EPS || - pla == P_BLACK && whiteScore < -Global::FLOAT_EPS; + return (considerDraw && Global::isZero(whiteScore)) || + (pla == P_WHITE && whiteScore > Global::FLOAT_EPS) || + (pla == P_BLACK && whiteScore < -Global::FLOAT_EPS); } float BoardHistory::whiteScoreIfGroundingAlive(const Board& board) const { diff --git a/cpp/game/dotsfield.cpp b/cpp/game/dotsfield.cpp index 7062de431..544fa3b5b 100644 --- a/cpp/game/dotsfield.cpp +++ b/cpp/game/dotsfield.cpp @@ -98,23 +98,24 @@ Color getPlacedDotColor(const State s) { return static_cast(s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK); } -bool isPlaced(const State s, const Player pla) { +static bool isPlaced(const State s, const Player pla) { return (s >> PLACED_PLAYER_SHIFT & ACTIVE_MASK) == pla; } -bool isActive(const State s, const Player pla) { +static bool isActive(const State s, const Player pla) { return (s & ACTIVE_MASK) == pla; } -State setTerritoryAndActivePlayer(const State s, const Player pla) { - return static_cast(TERRITORY_FLAG | (s & INVALIDATE_TERRITORY_MASK | pla)); +static State setTerritoryAndActivePlayer(const State s, const Player pla) { + const int invalidateTerritoryValue = s & INVALIDATE_TERRITORY_MASK; + return static_cast(TERRITORY_FLAG | invalidateTerritoryValue | pla); } Color getEmptyTerritoryColor(const State s) { return static_cast(s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK); } -bool isWithinEmptyTerritory(const State s, const Player pla) { +static bool isWithinEmptyTerritory(const State s, const Player pla) { return (s >> EMPTY_TERRITORY_SHIFT & ACTIVE_MASK) == pla; } @@ -134,7 +135,7 @@ bool isGrounded(const State state) { return (state & GROUNDED_FLAG) == GROUNDED_FLAG; } -bool isGroundedOrWall(const State state, const Player pla) { +static bool isGroundedOrWall(const State state, const Player pla) { // Use bit tricks for grounding detecting. // If the active player is C_WALL, then the result is also true. return (state & GROUNDED_FLAG) == GROUNDED_FLAG && (state & pla) == pla; @@ -167,8 +168,8 @@ void Board::clearVisited(const vector& locations) const { } int Board::calculateOwnershipAndWhiteScore(Color* result, const Color groundingPlayer) const { - int whiteCaptures = 0; - int blackCaptures = 0; + [[maybe_unused]] int whiteCaptures = 0; + [[maybe_unused]] int blackCaptures = 0; for (int y = 0; y < y_size; y++) { for (int x = 0; x < x_size; x++) { diff --git a/cpp/game/graphhash.cpp b/cpp/game/graphhash.cpp index 988cf433e..32add1ffd 100644 --- a/cpp/game/graphhash.cpp +++ b/cpp/game/graphhash.cpp @@ -49,7 +49,7 @@ Hash128 GraphHash::getGraphHashFromScratch(const BoardHistory& histOrig, Player const Move move = histOrig.moveHistory[i]; graphHash = getGraphHash(graphHash, hist, move.pla, repBound, drawEquivalentWinsForWhite); const bool preventEncoreHistory = hist.rules.isDots ? false : histOrig.preventEncoreHistory[i]; - const bool suc = hist.makeBoardMoveTolerant(board, move.loc, move.pla, preventEncoreHistory); + [[maybe_unused]] const bool suc = hist.makeBoardMoveTolerant(board, move.loc, move.pla, preventEncoreHistory); assert(suc); } assert( diff --git a/cpp/game/rules.cpp b/cpp/game/rules.cpp index dbd98a43a..9ae0b52de 100644 --- a/cpp/game/rules.cpp +++ b/cpp/game/rules.cpp @@ -15,8 +15,8 @@ const Rules Rules::DEFAULT_GO = Rules(false); Rules::Rules() : Rules(false) {} -Rules::Rules(const int startPos, const bool startPosIsRandom, const bool suicide, const bool dotsCaptureEmptyBases, const bool dotsFreeCapturedDots) : - Rules(true, startPos, startPosIsRandom, 0, 0, 0, suicide, false, 0, false, 0.0f, dotsCaptureEmptyBases, dotsFreeCapturedDots) {} +Rules::Rules(const int newStartPos, const bool newStartPosIsRandom, const bool newSuicide, const bool newDotsCaptureEmptyBases, const bool newDotsFreeCapturedDots) : + Rules(true, newStartPos, newStartPosIsRandom, 0, 0, 0, newSuicide, false, 0, false, 0.0f, newDotsCaptureEmptyBases, newDotsFreeCapturedDots) {} Rules::Rules( int kRule, @@ -48,9 +48,9 @@ Rules::Rules(const bool initIsDots) : Rules( } Rules::Rules( - bool isDots, + const bool newIsDots, int startPosRule, - bool startPosIsRandom, + const bool newStartPosIsRandom, int kRule, int sRule, int tRule, @@ -59,22 +59,25 @@ Rules::Rules( int whbRule, bool pOk, float km, - bool dotsCaptureEmptyBases, - bool dotsFreeCapturedDots + const bool newDotsCaptureEmptyBases, + const bool newDotsFreeCapturedDots ) - : isDots(isDots), + : isDots(newIsDots), + startPos(startPosRule), - startPosIsRandom(startPosIsRandom), - dotsCaptureEmptyBases(dotsCaptureEmptyBases), - dotsFreeCapturedDots(dotsFreeCapturedDots), + startPosIsRandom(newStartPosIsRandom), + komi(km), + multiStoneSuicideLegal(suic), + taxRule(tRule), + whiteHandicapBonusRule(whbRule), + koRule(kRule), scoringRule(sRule), - taxRule(tRule), - multiStoneSuicideLegal(suic), hasButton(button), - whiteHandicapBonusRule(whbRule), friendlyPassOk(pOk), - komi(km) + + dotsCaptureEmptyBases(newDotsCaptureEmptyBases), + dotsFreeCapturedDots(newDotsFreeCapturedDots) { initializeIfNeeded(); } @@ -727,15 +730,15 @@ string Rules::toStringNoSgfDefinedPropertiesMaybeNice() const { return toString(false); } -double nextRandomOffset(Rand& rand) { +static double nextRandomOffset(Rand& rand) { return rand.nextDouble(4, 7); } -int nextRandomOffsetX(Rand& rand, int x_size) { +static int nextRandomOffsetX(Rand& rand, int x_size) { return static_cast(round(nextRandomOffset(rand) / 39.0 * x_size)); } -int nextRandomOffsetY(Rand& rand, int y_size) { +static int nextRandomOffsetY(Rand& rand, int y_size) { return static_cast(round(nextRandomOffset(rand) / 32.0 * y_size)); } @@ -955,7 +958,7 @@ int Rules::recognizeStartPos( if (remainingMoves != nullptr) { for (const auto recognizedMove : startPosMoves) { - bool remainingMoveIsRemoved = false; + [[maybe_unused]] bool remainingMoveIsRemoved = false; for(auto it = remainingMoves->begin(); it != remainingMoves->end(); ++it) { if (movesEqual(*it, recognizedMove)) { remainingMoves->erase(it); diff --git a/cpp/game/rules.h b/cpp/game/rules.h index 6674cf34b..343f2b936 100644 --- a/cpp/game/rules.h +++ b/cpp/game/rules.h @@ -12,6 +12,8 @@ struct Rules { const static Rules DEFAULT_DOTS; const static Rules DEFAULT_GO; + bool isDots; + static constexpr int START_POS_EMPTY = 0; static constexpr int START_POS_CROSS = 1; static constexpr int START_POS_CROSS_2 = 2; @@ -21,15 +23,11 @@ struct Rules { // Enables random shuffling of start pos. Currently, it works only for CROSS_4 bool startPosIsRandom; - static const int KO_SIMPLE = 0; - static const int KO_POSITIONAL = 1; - static const int KO_SITUATIONAL = 2; - static const int KO_SPIGHT = 3; - int koRule; - - static const int SCORING_AREA = 0; - static const int SCORING_TERRITORY = 1; - int scoringRule; + float komi; + //Min and max acceptable komi in various places involving user input validation + static constexpr float MIN_USER_KOMI = -150.0f; + static constexpr float MAX_USER_KOMI = 150.0f; + bool multiStoneSuicideLegal; // Works as just suicide in Dots Game static const int TAX_NONE = 0; static const int TAX_SEKI = 1; @@ -41,25 +39,29 @@ struct Rules { static const int WHB_N_MINUS_ONE = 2; int whiteHandicapBonusRule; - float komi; - //Min and max acceptable komi in various places involving user input validation - static constexpr float MIN_USER_KOMI = -150.0f; - static constexpr float MAX_USER_KOMI = 150.0f; + static const int KO_SIMPLE = 0; + static const int KO_POSITIONAL = 1; + static const int KO_SITUATIONAL = 2; + static const int KO_SPIGHT = 3; + int koRule; - bool isDots; + static const int SCORING_AREA = 0; + static const int SCORING_TERRITORY = 1; + int scoringRule; - bool dotsCaptureEmptyBases; - bool dotsFreeCapturedDots; // TODO: Implement later - bool multiStoneSuicideLegal; // Works as just suicide in Dots Game bool hasButton; + //Mostly an informational value - doesn't affect the actual implemented rules, but GTP or Analysis may, at a //high level, use this info to adjust passing behavior - whether it's okay to pass without capturing dead stones. //Only relevant for area scoring. bool friendlyPassOk; + bool dotsCaptureEmptyBases; + bool dotsFreeCapturedDots; // TODO: Implement later + Rules(); // Constructor for Dots - Rules(int startPos, bool startPosIsRandom, bool suicide, bool dotsCaptureEmptyBases, bool dotsFreeCapturedDots); + Rules(int newStartPos, bool newStartPosIsRandom, bool newSuicide, bool newDotsCaptureEmptyBases, bool newDotsFreeCapturedDots); // Constructor for Go Rules( int koRule, @@ -154,9 +156,9 @@ struct Rules { private: // General constructor Rules( - bool isDots, + bool newIsDots, int startPosRule, - bool startPosIsRandom, + bool newStartPosIsRandom, int kRule, int sRule, int tRule, @@ -165,8 +167,8 @@ struct Rules { int whbRule, bool pOk, float km, - bool dotsCaptureEmptyBases, - bool dotsFreeCapturedDots + bool newDotsCaptureEmptyBases, + bool newDotsFreeCapturedDots ); static inline std::map startPosIdToName; diff --git a/cpp/neuralnet/modelversion.cpp b/cpp/neuralnet/modelversion.cpp index 05e54f4a6..4770da969 100644 --- a/cpp/neuralnet/modelversion.cpp +++ b/cpp/neuralnet/modelversion.cpp @@ -91,7 +91,7 @@ int NNModelVersion::getNumSpatialFeatures(const int modelVersion, const bool dot break; default: if (modelVersion <= latestModelVersionImplemented) { - return dotsGame ? NNInputs::NUM_FEATURES_SPATIAL_V7_DOTS : NNInputs::NUM_FEATURES_SPATIAL_V7; + return NNInputs::NUM_FEATURES_SPATIAL_V7; // Use NUM_FEATURES_SPATIAL_V7_DOTS if it's value is changed } break; } @@ -125,7 +125,7 @@ int NNModelVersion::getNumGlobalFeatures(const int modelVersion, const bool dots break; default: if (modelVersion <= latestModelVersionImplemented) { - return dotsGame ? NNInputs::NUM_FEATURES_GLOBAL_V7_DOTS : NNInputs::NUM_FEATURES_GLOBAL_V7; + return NNInputs::NUM_FEATURES_GLOBAL_V7; // Use NUM_FEATURES_GLOBAL_V7_DOTS if it's value is changed } break; } diff --git a/cpp/neuralnet/nneval.cpp b/cpp/neuralnet/nneval.cpp index b5bf2ccad..6240f917c 100644 --- a/cpp/neuralnet/nneval.cpp +++ b/cpp/neuralnet/nneval.cpp @@ -67,7 +67,7 @@ NNEvaluator::NNEvaluator( const string& rSeed, bool doRandomize, int defaultSymmetry, - bool dotsGame + const bool newDotsGame ) :modelName(mName), modelFileName(mFileName), @@ -108,7 +108,7 @@ NNEvaluator::NNEvaluator( currentDefaultSymmetry(defaultSymmetry), currentBatchSize(maxBatchSz), queryQueue(), - dotsGame(dotsGame) + dotsGame(newDotsGame) { if(nnXLen > NNPos::MAX_BOARD_LEN_X) throw StringError("Maximum supported nnEval board size x is " + Global::intToString(NNPos::MAX_BOARD_LEN_X)); @@ -151,7 +151,7 @@ NNEvaluator::NNEvaluator( internalModelName = "random"; modelVersion = NNModelVersion::defaultModelVersion; } - inputsVersion = NNModelVersion::getInputsVersion(modelVersion, dotsGame); + inputsVersion = NNModelVersion::getInputsVersion(modelVersion, newDotsGame); //Reserve a decent amount above the batch size so that allocation is unlikely. queryQueue.reserve(maxBatchSize * 4 * gpuIdxByServerThread.size()); diff --git a/cpp/neuralnet/nneval.h b/cpp/neuralnet/nneval.h index a13969c63..81f298960 100644 --- a/cpp/neuralnet/nneval.h +++ b/cpp/neuralnet/nneval.h @@ -101,7 +101,7 @@ class NNEvaluator { const std::string& randSeed, bool doRandomize, int defaultSymmetry, - bool dotsGame + bool newDotsGame ); ~NNEvaluator(); diff --git a/cpp/neuralnet/nninputs.cpp b/cpp/neuralnet/nninputs.cpp index 0e5db110b..0313b5b8b 100644 --- a/cpp/neuralnet/nninputs.cpp +++ b/cpp/neuralnet/nninputs.cpp @@ -1067,7 +1067,7 @@ int NNInputs::getNumberOfSpatialFeatures(const int version, const bool isDots) { case 4: if (!isDots) return NUM_FEATURES_SPATIAL_V4; break; case 5: if (!isDots) return NUM_FEATURES_SPATIAL_V5; break; case 6: if (!isDots) return NUM_FEATURES_SPATIAL_V6; break; - case 7: return isDots ? NUM_FEATURES_SPATIAL_V7_DOTS : NUM_FEATURES_SPATIAL_V7; + case 7: return NUM_FEATURES_SPATIAL_V7; // Use NUM_FEATURES_SPATIAL_V7_DOTS if it's value is changed default: break; } throw std::range_error("Invalid input version: " + to_string(version) + (isDots ? " (Dots game)" : "")); @@ -1079,7 +1079,7 @@ int NNInputs::getNumberOfGlobalFeatures(const int version, const bool isDots) { case 4: if (!isDots) return NUM_FEATURES_GLOBAL_V4; break; case 5: if (!isDots) return NUM_FEATURES_GLOBAL_V5; break; case 6: if (!isDots) return NUM_FEATURES_GLOBAL_V6; break; - case 7: return isDots ? NUM_FEATURES_GLOBAL_V7_DOTS : NUM_FEATURES_GLOBAL_V7; + case 7: return NUM_FEATURES_GLOBAL_V7; // Use NUM_FEATURES_GLOBAL_V7_DOTS if it's value is changed default: break; } throw std::range_error("Invalid input version: " + to_string(version) + (isDots ? " (Dots game)" : "")); diff --git a/cpp/neuralnet/nninputsdots.cpp b/cpp/neuralnet/nninputsdots.cpp index b08e173e1..7b1d56b27 100644 --- a/cpp/neuralnet/nninputsdots.cpp +++ b/cpp/neuralnet/nninputsdots.cpp @@ -35,7 +35,7 @@ void NNInputs::fillRowV7Dots( vector captures; vector bases; board.calculateOneMoveCaptureAndBasePositionsForDots(captures, bases); - int deadDotsCount = 0; + [[maybe_unused]] int deadDotsCount = 0; auto setSpatial = [&](const int pos, const DotsSpatialFeature spatialFeature) { setRowBin(rowBin, pos, static_cast(spatialFeature), 1.0f, posStride, featureStride); diff --git a/cpp/search/search.cpp b/cpp/search/search.cpp index a85147b95..d5bcfef77 100644 --- a/cpp/search/search.cpp +++ b/cpp/search/search.cpp @@ -65,7 +65,7 @@ SearchThread::~SearchThread() { static const double VALUE_WEIGHT_DEGREES_OF_FREEDOM = 3.0; -Search::Search(const SearchParams ¶ms, NNEvaluator* nnEval, Logger* lg, const string& randSeed, NNEvaluator* humanEval, const Rules& rules) +Search::Search(const SearchParams ¶ms, NNEvaluator* nnEval, Logger* lg, const string& newRandSeed, NNEvaluator* humanEval, const Rules& rules) : rootPla(P_BLACK), rootBoard(rules), rootHistory(rules), @@ -84,14 +84,14 @@ Search::Search(const SearchParams ¶ms, NNEvaluator* nnEval, Logger* lg, cons plaThatSearchIsFor(C_EMPTY), plaThatSearchIsForLastSearch(C_EMPTY), lastSearchNumPlayouts(0), effectiveSearchTimeCarriedOver(0.0), - randSeed(randSeed), + randSeed(newRandSeed), rootKoHashTable(NULL), valueWeightDistribution(NULL), normToTApproxZ(0), patternBonusTable(NULL), externalPatternBonusTable(nullptr), evalCache(nullptr), - nonSearchRand(randSeed + string("$nonSearchRand")), + nonSearchRand(newRandSeed + string("$nonSearchRand")), logger(lg), nnEvaluator(nnEval), humanEvaluator(humanEval), diff --git a/cpp/search/search.h b/cpp/search/search.h index f59511c34..431dfadfc 100644 --- a/cpp/search/search.h +++ b/cpp/search/search.h @@ -185,7 +185,7 @@ struct Search { const SearchParams ¶ms, NNEvaluator* nnEval, Logger* logger, - const std::string& randSeed, + const std::string& newRandSeed, NNEvaluator* humanEval = nullptr, const Rules& rules = Rules::DEFAULT_GO); ~Search(); diff --git a/cpp/tests/testdotsbasic.cpp b/cpp/tests/testdotsbasic.cpp index fd4e6ebee..50553c4dd 100644 --- a/cpp/tests/testdotsbasic.cpp +++ b/cpp/tests/testdotsbasic.cpp @@ -7,7 +7,7 @@ using namespace std; using namespace TestCommon; -void checkDotsField(const string& description, const string& input, +static void checkDotsField(const string& description, const string& input, const std::function& check, const bool suicide = Rules::DEFAULT_DOTS.multiStoneSuicideLegal, const bool captureEmptyBases = Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, diff --git a/cpp/tests/testdotsextra.cpp b/cpp/tests/testdotsextra.cpp index aa6cf63ad..2be862836 100644 --- a/cpp/tests/testdotsextra.cpp +++ b/cpp/tests/testdotsextra.cpp @@ -7,7 +7,7 @@ using namespace std; using namespace TestCommon; -void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardInput, const vector& extraMoves, const int symmetry) { +static void checkSymmetry(const Board& initBoard, const string& expectedSymmetryBoardInput, const vector& extraMoves, const int symmetry) { const Board transformedBoard = SymmetryHelpers::getSymBoard(initBoard, symmetry); Board expectedBoard = parseDotsFieldDefault(expectedSymmetryBoardInput); for (const XYMove& extraMove : extraMoves) { @@ -131,7 +131,7 @@ SymmetryHelpers::SYMMETRY_TRANSPOSE_FLIP_Y_X); testAssert(board.isEqualForTesting(unrotatedBoard)); } -string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { +static string getOwnership(const string& boardData, const Color groundingPlayer, const int expectedWhiteScore, const vector& extraMoves) { const Board board = parseDotsFieldDefault(boardData, extraMoves); Color result[Board::MAX_ARR_SIZE]; @@ -151,7 +151,7 @@ string getOwnership(const string& boardData, const Color groundingPlayer, const return oss.str(); } -void expect( +static void expect( const char* name, const Color groundingPlayer, const std::string& actualField, @@ -250,7 +250,7 @@ R"( )", 1, {XYMove(0, 2, P_WHITE)}); } -std::pair getCapturingAndBases( +static std::pair getCapturingAndBases( const string& boardData, const bool suicide, const bool captureEmptyBases, @@ -299,7 +299,7 @@ std::pair getCapturingAndBases( return {capturesStringStream.str(), basesStringStream.str()}; } -void checkCapturingAndBase( +static void checkCapturingAndBase( const string& title, const string& boardData, const string& expectedCaptures, diff --git a/cpp/tests/testdotsstartposes.cpp b/cpp/tests/testdotsstartposes.cpp index c44c02214..84c0553c4 100644 --- a/cpp/tests/testdotsstartposes.cpp +++ b/cpp/tests/testdotsstartposes.cpp @@ -10,7 +10,7 @@ using namespace std; using namespace std::chrono; using namespace TestCommon; -void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startPosIsRandom, const Board& board) { +static void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startPosIsRandom, const Board& board) { std::ostringstream sgfStringStream; const BoardHistory boardHistory(board, P_BLACK, board.rules, 0); WriteSgf::writeSgf(sgfStringStream, "black", "white", boardHistory, {}); @@ -23,7 +23,7 @@ void writeToSgfAndCheckStartPosFromSgfProp(const int startPos, const bool startP testAssert(startPosIsRandom == newRules.startPosIsRandom); } -void checkStartPos(const string& description, const int startPos, const bool startPosIsRandom, const int x_size, const int y_size, const string& expectedBoard = "", const vector& extraMoves = {}) { +static void checkStartPos(const string& description, const int startPos, const bool startPosIsRandom, const int x_size, const int y_size, const string& expectedBoard = "", const vector& extraMoves = {}) { cout << " " << description << " (" << to_string(x_size) << "," << to_string(y_size) << ")"; auto board = Board(x_size, y_size, Rules(startPos, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots)); @@ -40,7 +40,7 @@ void checkStartPos(const string& description, const int startPos, const bool sta writeToSgfAndCheckStartPosFromSgfProp(startPos, startPosIsRandom, board); } -void checkRecognition(const vector& xyMoves, const int x_size, const int y_size, +static void checkRecognition(const vector& xyMoves, const int x_size, const int y_size, const int expectedStartPos, const vector& expectedStartMoves, const bool expectedRandomized, @@ -71,7 +71,7 @@ void checkRecognition(const vector& xyMoves, const int x_size, const int } } -void checkStartPosRecognition(const string& description, const int expectedStartPos, const bool startPosIsRandom, const string& inputBoard) { +static void checkStartPosRecognition(const string& description, const int expectedStartPos, const bool startPosIsRandom, const string& inputBoard) { const Board board = parseDotsField(inputBoard, startPosIsRandom, Rules::DEFAULT_DOTS.multiStoneSuicideLegal, Rules::DEFAULT_DOTS.dotsCaptureEmptyBases, Rules::DEFAULT_DOTS.dotsFreeCapturedDots, {}); cout << " " << description << " (" << to_string(board.x_size) << "," << to_string(board.y_size) << ")"; @@ -79,7 +79,7 @@ void checkStartPosRecognition(const string& description, const int expectedStart writeToSgfAndCheckStartPosFromSgfProp(expectedStartPos, startPosIsRandom, board); } -void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { +static void checkGenerationAndRecognition(const int startPos, const int startPosIsRandom) { const auto generatedMoves = Rules::generateStartPos(startPos, startPosIsRandom ? &DOTS_RANDOM : nullptr, 39, 32); vector actualStartPosMoves; bool actualRandomized; diff --git a/cpp/tests/testdotsstress.cpp b/cpp/tests/testdotsstress.cpp index 1847171cd..c6da9a6df 100644 --- a/cpp/tests/testdotsstress.cpp +++ b/cpp/tests/testdotsstress.cpp @@ -7,7 +7,7 @@ using namespace std; using namespace std::chrono; using namespace TestCommon; -string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { +static string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { Board boardCopy(initialBoard); BoardHistory boardHistory(boardCopy, P_BLACK, boardCopy.rules, 0); for (const Board::MoveRecord& moveRecord : moveRecords) { @@ -22,7 +22,7 @@ string moveRecordsToSgf(const Board& initialBoard, const vector& moveRecords) { +static void validateStatesAndCaptures(const Board& board, const vector& moveRecords) { int expectedNumBlackCaptures = 0; int expectedNumWhiteCaptures = 0; int expectedPlacedDotsCount = -board.rules.getNumOfStartPosStones(); @@ -131,7 +131,7 @@ void validateStatesAndCaptures(const Board& board, const vector((groundingStartCoef + static_cast(rand.nextDouble()) * (groundingEndCoef - groundingStartCoef)) * static_cast(numLegalMoves)); Player pla = P_BLACK; int currentGameMovesCount = 0; for(short randomMove : randomMoves) { diff --git a/cpp/tests/testdotsutils.h b/cpp/tests/testdotsutils.h index 9662a0327..169e1b3b5 100644 --- a/cpp/tests/testdotsutils.h +++ b/cpp/tests/testdotsutils.h @@ -11,13 +11,13 @@ struct XYMove { int y; Player player; - XYMove(const int x, const int y, const Player player) : x(x), y(y), player(player) {} + XYMove(const int newX, const int newY, const Player newPlayer) : x(newX), y(newY), player(newPlayer) {} [[nodiscard]] std::string toString() const { return "(" + to_string(x) + "," + to_string(y) + "," + PlayerIO::colorToChar(player) + ")"; } - Move toMove(int x_size) const; + [[nodiscard]] Move toMove(int x_size) const; }; struct BoardWithMoveRecords { diff --git a/cpp/tests/testnninputs.cpp b/cpp/tests/testnninputs.cpp index a1c7e7c79..a77e99246 100644 --- a/cpp/tests/testnninputs.cpp +++ b/cpp/tests/testnninputs.cpp @@ -91,9 +91,16 @@ static double finalScoreIfGameEndedNow(const BoardHistory& baseHist, const Board //================================================================================================================== //================================================================================================================== -Hash128 fillRowAndGetHash( - int version, - Board& board, const BoardHistory& hist, Player nextPla, MiscNNInputParams nnInputParams, int nnXLen, int nnYLen, bool inputsUseNHWC,float* rowBin, float* rowGlobal +static Hash128 fillRowAndGetHash( + const int version, + const Board& board, const BoardHistory& hist, + const Player nextPla, + const MiscNNInputParams& nnInputParams, + const int nnXLen, + const int nnYLen, + const bool inputsUseNHWC, + float* rowBin, + float* rowGlobal ) { const Hash128 hash = NNInputs::getHash(board,hist,nextPla,nnInputParams); NNInputs::fillRowVN(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); @@ -526,7 +533,7 @@ xxx..xx bool inputsUseNHWC = true; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); out << rowGlobal[c] << " "; } out << endl; @@ -591,7 +598,7 @@ xxx..xx bool inputsUseNHWC = true; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); out << "Pass Hist Channels: "; for(int c = 0; c<5; c++) out << rowGlobal[c] << " "; @@ -660,7 +667,7 @@ xxx..xx bool inputsUseNHWC = true; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,18); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,19); } @@ -816,7 +823,7 @@ o.xoo.x auto run = [&](bool inputsUseNHWC) { MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - const Hash128 hash = fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,board,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,9); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,10); printNNInputHWAndBoard(out,version,board,hist,nnXLen,nnYLen,inputsUseNHWC,rowBin,11); @@ -883,7 +890,7 @@ o.xoo.x Board b = board; MiscNNInputParams nnInputParams; nnInputParams.drawEquivalentWinsForWhite = drawEquivalentWinsForWhite; - const Hash128 hash = fillRowAndGetHash(version,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); + fillRowAndGetHash(version,b,hist,nextPla,nnInputParams,nnXLen,nnYLen,inputsUseNHWC,rowBin,rowGlobal); for(int c = 0; c Date: Sat, 13 Dec 2025 22:47:22 +0100 Subject: [PATCH 42/42] ~ [CI] Pack necessary .dll in resulting Windows artifact --- .github/workflows/build.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 115539b81..d251795a1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,6 +4,7 @@ on: push: branches: [ master ] pull_request: + branches: [ master ] workflow_dispatch: jobs: @@ -149,4 +150,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: katago-windows-opencl - path: cpp/Release/katago.exe + path: cpp/Release/