From beb73f2f07c518c1452ddcf970e70411129ac101 Mon Sep 17 00:00:00 2001 From: koparasy Date: Thu, 18 Dec 2025 11:37:56 -0800 Subject: [PATCH 1/3] Add policy concept --- src/AMSlib/wf/policy.hpp | 33 +++++++++ tests/AMSlib/wf/CMakeLists.txt | 3 + tests/AMSlib/wf/policy.cpp | 132 +++++++++++++++++++++++++++++++++ 3 files changed, 168 insertions(+) create mode 100644 src/AMSlib/wf/policy.hpp create mode 100644 tests/AMSlib/wf/policy.cpp diff --git a/src/AMSlib/wf/policy.hpp b/src/AMSlib/wf/policy.hpp new file mode 100644 index 00000000..5db49491 --- /dev/null +++ b/src/AMSlib/wf/policy.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "wf/pipeline.hpp" + +namespace ams +{ + +namespace ml +{ +class InferenceModel; +} + +class LayoutTransform; + +/// Policies are factories that construct Pipelines. +/// +/// A Policy encodes *what* should happen (control flow, fallback strategy), +/// while the Pipeline and Actions encode *how* it happens. +class Policy +{ +public: + virtual ~Policy() = default; + + /// Construct a pipeline for the given model and layout The, potentially + /// nullable, Model is a non-owning pointer. + /// + /// The returned Pipeline is ready to run. + virtual Pipeline makePipeline(const ml::InferenceModel* Model, + LayoutTransform& Layout) const = 0; + virtual const char* name() const noexcept = 0; +}; + +} // namespace ams diff --git a/tests/AMSlib/wf/CMakeLists.txt b/tests/AMSlib/wf/CMakeLists.txt index 2c825b3c..4e3f98fb 100644 --- a/tests/AMSlib/wf/CMakeLists.txt +++ b/tests/AMSlib/wf/CMakeLists.txt @@ -62,3 +62,6 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::ACTION action) BUILD_UNIT_TEST(pipeline pipeline.cpp) ADD_WORKFLOW_UNIT_TEST(WORKFLOW::PIPELINE pipeline) + +BUILD_UNIT_TEST(policy policy.cpp) +ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POLICY policy) diff --git a/tests/AMSlib/wf/policy.cpp b/tests/AMSlib/wf/policy.cpp new file mode 100644 index 00000000..2f40c64f --- /dev/null +++ b/tests/AMSlib/wf/policy.cpp @@ -0,0 +1,132 @@ +#include "wf/policy.hpp" + +#include +#include +#include +#include + +#include "ml/Model.hpp" +#include "wf/action.hpp" +#include "wf/eval_context.hpp" +#include "wf/layout_transform.hpp" +#include "wf/pipeline.hpp" + +namespace ams +{ + +namespace +{ + +class IncAction final : public Action +{ +public: + const char* name() const noexcept override { return "IncAction"; } + AMSStatus run(EvalContext& Ctx) override + { + Ctx.Threshold = Ctx.Threshold.value_or(0.0f) + 1.0f; + return {}; + } +}; + +class FailAction final : public Action +{ +public: + const char* name() const noexcept override { return "FailAction"; } + AMSStatus run(EvalContext&) override + { + return AMS_MAKE_ERROR(AMSErrorType::Generic, "FailAction triggered"); + } +}; + +class DummyLayout final : public LayoutTransform +{ +public: + const char* name() const noexcept override { return "DummyLayout"; } + + // Only needed if you included the real LayoutTransform interface. +#if __has_include("wf/layout_transform.hpp") + AMSExpected pack(const TensorBundle&, + const TensorBundle&, + at::Tensor&) override + { + return IndexMap{}; + } + AMSStatus unpack(const torch::jit::IValue&, + TensorBundle&, + TensorBundle&, + std::optional&) override + { + return {}; + } +#endif +}; + +class DirectLikePolicy final : public Policy +{ +public: + const char* name() const noexcept override { return "DirectLikePolicy"; } + + Pipeline makePipeline(const ml::InferenceModel* /*Model*/, + LayoutTransform& /*Layout*/) const override + { + Pipeline P; + P.add(std::make_unique()).add(std::make_unique()); + return P; + } +}; + +class FailingPolicy final : public Policy +{ +public: + const char* name() const noexcept override { return "FailingPolicy"; } + + Pipeline makePipeline(const ml::InferenceModel* /*Model*/, + LayoutTransform& /*Layout*/) const override + { + Pipeline P; + P.add(std::make_unique()) + .add(std::make_unique()) + .add(std::make_unique()); // must not run + return P; + } +}; + +} // namespace + +CATCH_TEST_CASE("Policy is an abstract factory for Pipelines", "[wf][policy]") +{ + CATCH_STATIC_REQUIRE(std::is_abstract_v); + CATCH_STATIC_REQUIRE(std::has_virtual_destructor_v); + + DummyLayout L; + ml::InferenceModel* Model = nullptr; + + DirectLikePolicy Pol; + CATCH_REQUIRE(std::string(Pol.name()) == "DirectLikePolicy"); + + EvalContext Ctx{}; + auto P = Pol.makePipeline(Model, L); + + auto St = P.run(Ctx); + CATCH_REQUIRE(St); + CATCH_REQUIRE(Ctx.Threshold == 2.0f); +} + +CATCH_TEST_CASE("Policy-built pipeline short-circuits on Action failure", + "[wf][policy]") +{ + DummyLayout L; + ml::InferenceModel* Model = nullptr; + + FailingPolicy Pol; + EvalContext Ctx{}; + + auto P = Pol.makePipeline(Model, L); + auto St = P.run(Ctx); + + CATCH_REQUIRE_FALSE(St); + CATCH_REQUIRE(St.error().getType() == AMSErrorType::Generic); + CATCH_REQUIRE(Ctx.Threshold == 1.0f); +} + +} // namespace ams From b0c1295564827a104a12170d8dcbcb88c8990c53 Mon Sep 17 00:00:00 2001 From: Konstantinos Parasyris Date: Thu, 18 Dec 2025 12:46:27 -0800 Subject: [PATCH 2/3] Update src/AMSlib/wf/policy.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/AMSlib/wf/policy.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AMSlib/wf/policy.hpp b/src/AMSlib/wf/policy.hpp index 5db49491..7d020404 100644 --- a/src/AMSlib/wf/policy.hpp +++ b/src/AMSlib/wf/policy.hpp @@ -21,7 +21,7 @@ class Policy public: virtual ~Policy() = default; - /// Construct a pipeline for the given model and layout The, potentially + /// Construct a pipeline for the given model and layout. The, potentially /// nullable, Model is a non-owning pointer. /// /// The returned Pipeline is ready to run. From 67c2a70f0b7dab5efc70f64f42bcda8347fd9fc6 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:06:26 -0800 Subject: [PATCH 3/3] Remove redundant __has_include check in policy test (#184) * Initial plan * Remove unnecessary __has_include preprocessor check Co-authored-by: koparasy <1258022+koparasy@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: koparasy <1258022+koparasy@users.noreply.github.com> --- tests/AMSlib/wf/policy.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/AMSlib/wf/policy.cpp b/tests/AMSlib/wf/policy.cpp index 2f40c64f..853cba8d 100644 --- a/tests/AMSlib/wf/policy.cpp +++ b/tests/AMSlib/wf/policy.cpp @@ -43,8 +43,6 @@ class DummyLayout final : public LayoutTransform public: const char* name() const noexcept override { return "DummyLayout"; } - // Only needed if you included the real LayoutTransform interface. -#if __has_include("wf/layout_transform.hpp") AMSExpected pack(const TensorBundle&, const TensorBundle&, at::Tensor&) override @@ -58,7 +56,6 @@ class DummyLayout final : public LayoutTransform { return {}; } -#endif }; class DirectLikePolicy final : public Policy