diff --git a/PWGJE/Core/MlResponseHfTagging.h b/PWGJE/Core/MlResponseHfTagging.h index 3992b2108c3..be16b0cd450 100644 --- a/PWGJE/Core/MlResponseHfTagging.h +++ b/PWGJE/Core/MlResponseHfTagging.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -361,18 +362,20 @@ class GNNBjetAllocator : public TensorAllocator std::vector> edgesList; + std::function tfFunc; + // Jet feature normalization template T jetFeatureTransform(T feat, int idx) const { - return std::tanh((feat - tfJetMean[idx]) / tfJetStdev[idx]); + return tfFunc((feat - tfJetMean[idx]) / tfJetStdev[idx]); } // Track feature normalization template T trkFeatureTransform(T feat, int idx) const { - return std::tanh((feat - tfTrkMean[idx]) / tfTrkStdev[idx]); + return tfFunc((feat - tfTrkMean[idx]) / tfTrkStdev[idx]); } // Edge input of GNN (fully-connected graph) @@ -419,10 +422,17 @@ class GNNBjetAllocator : public TensorAllocator } public: - GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {} - GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector& tfJetMean, std::vector& tfJetStdev, std::vector& tfTrkMean, std::vector& tfTrkStdev, int64_t maxNNodes = 40) - : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev) + GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40), tfFunc([](float x) { return x; }) {} + GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector& tfJetMean, std::vector& tfJetStdev, std::vector& tfTrkMean, std::vector& tfTrkStdev, int64_t maxNNodes = 40, std::string tfFuncType = "linear") + : TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev), tfFunc([](float x) { return x; }) { + if (tfFuncType == "asinh") { + tfFunc = [](float x) { return std::asinh(x); }; + } else if (tfFuncType == "tanh") { + tfFunc = [](float x) { return std::tanh(x); }; + } else { + tfFunc = [](float x) { return x; }; + } setEdgesList(); } ~GNNBjetAllocator() = default; @@ -439,6 +449,8 @@ class GNNBjetAllocator : public TensorAllocator tfJetStdev = other.tfJetStdev; tfTrkMean = other.tfTrkMean; tfTrkStdev = other.tfTrkStdev; + tfFunc = other.tfFunc; + edgesList.clear(); setEdgesList(); return *this; } diff --git a/PWGJE/TableProducer/jetTaggerHF.cxx b/PWGJE/TableProducer/jetTaggerHF.cxx index e7a5932ea7a..90412b53207 100644 --- a/PWGJE/TableProducer/jetTaggerHF.cxx +++ b/PWGJE/TableProducer/jetTaggerHF.cxx @@ -139,6 +139,7 @@ struct JetTaggerHFTask { Configurable> transformFeatureTrkStdev{"transformFeatureTrkStdev", std::vector{-999}, "Stdev values for each GNN input feature (track)"}; + Configurable tfFuncTypeGNN{"tfFuncTypeGNN", "linear", "Transformation function type for GNN"}; // axis spec ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""}; @@ -525,7 +526,7 @@ struct JetTaggerHFTask { } if (doprocessAlgorithmGNN) { - tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst); + tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst, tfFuncTypeGNN.value); registry.add("h2_count_db", "#it{D}_{b} underflow/overflow;Jet flavour;#it{D}_{b} range", {HistType::kTH2F, {{4, 0., 4.}, {3, 0., 3.}}}); auto h2CountDb = registry.get(HIST("h2_count_db")); diff --git a/PWGJE/Tasks/bjetTaggingGnn.cxx b/PWGJE/Tasks/bjetTaggingGnn.cxx index c10e43afe4d..03ef54ab2e8 100644 --- a/PWGJE/Tasks/bjetTaggingGnn.cxx +++ b/PWGJE/Tasks/bjetTaggingGnn.cxx @@ -809,80 +809,6 @@ struct BjetTaggingGnn { registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_inelgt0"), analysisJet.pt(), mcpjetpT, isTrueINELgt0 && (hasAll(evtselCode, EvtSelFlag::INELgt0rec)) ? weightEvt : 0.0); } } - - // switch (evtselCode) { - // case static_cast(EvtSel::INELgt0rec: - // registry.fill(HIST("h_jetpT_inelgt0rec"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_inelgt0"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_inelgt0"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_inelgt0"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // } - // case static_cast(EvtSel::Sel8Zvtx: - // registry.fill(HIST("h_jetpT_sel8_zvtx"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_sel8"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_sel8_zvtx"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_sel8"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // } - // case static_cast(EvtSel::SelMCZvtx: - // registry.fill(HIST("h_jetpT_selmc_zvtx"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_selmc"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_selmc_zvtx"), analysisJet.pt(), weightEvt); - // if (isMatched) { - // registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_selmc"), analysisJet.pt(), mcpjetpT, weightEvt); - // } - // } - // case static_cast(EvtSel::TVXZvtx: - // registry.fill(HIST("h_jetpT_tvx_zvtx"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_tvx_zvtx"), analysisJet.pt(), weightEvt); - // } - // case static_cast(EvtSel::CollZvtx: - // registry.fill(HIST("h_jetpT_coll_zvtx"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_coll_zvtx"), analysisJet.pt(), weightEvt); - // } - // default: - // switch (evtselCode) { - // case static_cast(EvtSel::Sel8: - // case static_cast(EvtSel::Sel8Zvtx: - // registry.fill(HIST("h_jetpT_sel8"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_sel8"), analysisJet.pt(), weightEvt); - // } - // case static_cast(EvtSel::SelMC: - // case static_cast(EvtSel::SelMCZvtx: - // registry.fill(HIST("h_jetpT_selmc"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_selmc"), analysisJet.pt(), weightEvt); - // } - // case static_cast(EvtSel::TVX: - // case static_cast(EvtSel::TVXZvtx: - // registry.fill(HIST("h_jetpT_tvx"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_tvx"), analysisJet.pt(), weightEvt); - // } - // case static_cast(EvtSel::Coll: - // case static_cast(EvtSel::CollZvtx: - // default: - // registry.fill(HIST("h_jetpT_coll"), analysisJet.pt(), weightEvt); - // if (isBjet) { - // registry.fill(HIST("h_jetpT_b_coll"), analysisJet.pt(), weightEvt); - // } - // } - // } } } PROCESS_SWITCH(BjetTaggingGnn, processMCDJetsSel, "jet information in MC (event selection)", false);