// ====================================================================== // \title Os/test/ut/queue/RulesHeaders.cpp // \brief definitions for concurrent running rules // ====================================================================== #include #include #include #include #include #include "Os/Condition.hpp" #include "Os/Mutex.hpp" #include "Os/Task.hpp" #ifndef OS_TEST_CONCURRENT_RULE #define OS_TEST_CONCURRENT_RULE // Forward declaration of the aggregated concurrent rule template class AggregatedConcurrentRule; template class ConcurrentRule : public STest::Rule { friend class AggregatedConcurrentRule; public: ConcurrentRule(const char* const name, AggregatedConcurrentRule& runner) : STest::Rule(name), m_runner(runner) { this->m_runner.add(*this); this->m_condition_value = false; } virtual ~ConcurrentRule() { this->m_runner.remove(*this); } //! Launch this rule asynchronously void action_async(State& state) { this->is_asynchronous = true; this->m_state = &state; Os::Task::Arguments arguments(Os::TaskString(this->getName()), this->run, this); m_task.start(arguments); } //! Static function to "apply" this rule using a task static void run(void* rule_pointer) { ASSERT_NE(rule_pointer, nullptr) << "Passed null input to task"; ConcurrentRule* self = reinterpret_cast*>(rule_pointer); ASSERT_TRUE(self->is_asynchronous) << "Rule " << self->getName() << " not run in aggregated concurrent rule"; ASSERT_TRUE(self->precondition(*self->m_state)) << "precondition failed applying rule " << self->getName(); // Ensure this rule owns the global state Os::ScopeLock lock(self->getLock()); self->action(*self->m_state); } protected: //! \brief wait until notified to take the next step void wait_for_next_step() { ASSERT_TRUE(this->is_asynchronous) << "Rule " << this->getName() << " not run in aggregated concurrent rule"; while (not this->m_condition_value) { this->m_condition.wait(this->getLock()); } this->m_condition_value = false; } //! \brief get lock Os::Mutex& getLock() { return this->m_runner.getLock(); } //! \brief notify another rule by nae, void notify_other(std::string other) { this->m_runner.notify(other); } //! \brief get the condition variable bool getCondition() { return this->m_condition_value; } protected: void join() { this->m_task.join(); } //! \brief notify this rule to take the next step void step() { this->m_condition_value = true; this->m_condition.notify(); } //! Reference to the runner AggregatedConcurrentRule& m_runner; //! Local store of state for concurrent action State* m_state; //! Task to run this rule in asynchronous context Os::Task m_task; //! Condition variable used to set up this rule Os::ConditionVariable m_condition; //! Condition guarded by condition variable std::atomic m_condition_value; //! A check to ensure this rule was launched asynchronously when using the blocking wait function bool is_asynchronous = false; }; template class PseudoRule : public ConcurrentRule { public: PseudoRule(const char* name, AggregatedConcurrentRule& runner) : ConcurrentRule(name, runner) {} bool precondition(const State& state) override { return true; } void action(State& state) override {} void wait_for_next_step() { this->ConcurrentRule::wait_for_next_step(); } }; template class ConcurrentWrapperRule : public ConcurrentRule { public: ConcurrentWrapperRule(AggregatedConcurrentRule& runner, STest::Rule& wrapped, std::string notify, const char* name = nullptr) : ConcurrentRule((name == nullptr) ? wrapped.getName() : name, runner), m_wrapped(wrapped), m_notify(notify) {} bool precondition(const State& state) override { return m_wrapped.precondition(state); } void action(State& state) override { this->wait_for_next_step(); // Wait until told to go this->m_wrapped.apply(state); // Go this->notify_other(this->m_notify); // Notify when done } private: STest::Rule& m_wrapped; std::string m_notify; }; template class AggregatedConcurrentRule : public STest::Rule { public: //! Creation AggregatedConcurrentRule() : STest::Rule("aggregated-rule") {} ~AggregatedConcurrentRule() = default; //! Add a rule (on construction of rule) void add(ConcurrentRule& rule) { Os::ScopeLock lock(this->m_lock); m_rules.push_back(&rule); m_rule_map[rule.getName()] = &rule; } //! Remove a rule (on destruction of rule) void remove(ConcurrentRule& rule) { Os::ScopeLock lock(this->m_lock); m_rules.remove(&rule); // Remove rule from m_rules_map for (auto it = m_rule_map.begin(); it != m_rule_map.end(); it++) { if (it->second == &rule) { m_rule_map.erase(it); break; } } } // Notify a rule by name void notify(std::string& name) { // Notify all matching rules for (auto pair : m_rule_map) { if (std::string(pair.second->getName()) == name) { pair.second->step(); return; } } ASSERT_TRUE(false) << "Failed to find rule to notify"; } Os::Mutex& getLock() { return this->m_lock; } //! Aggregate rule may only run if bool precondition(const State& state) override { for (ConcurrentRule* rule : m_rules) { if (not rule->precondition(state)) { return false; } } return true; }; //! Action is to run all rules then wait for all rules to finish void action(State& state) override { // Launch asynchronous rules { Os::ScopeLock lock(this->m_lock); for (ConcurrentRule* rule : m_rules) { rule->action_async(state); } } } void join() { // Wait for asynchronous rules to finish for (ConcurrentRule* rule : m_rules) { rule->join(); } } private: Os::Mutex m_lock; std::list*> m_rules; std::map*> m_rule_map; }; #endif // OS_TEST_CONCURRENT_RULE