diff --git a/kokkos-testing/fd_catalog/fd_wave_cpu.hpp b/kokkos-testing/fd_catalog/fd_wave_cpu.hpp
index e82cebff0ea61d70072b74f672a99159c4325636..88a5f7651457b54a7e88015538b8ab3c5eeb2cfd 100644
--- a/kokkos-testing/fd_catalog/fd_wave_cpu.hpp
+++ b/kokkos-testing/fd_catalog/fd_wave_cpu.hpp
@@ -510,18 +510,24 @@ double solve_sequential_blocked(wave_equation_context<T>& wec)
     return solve_sequential_aux<T,true>(wec);
 }
 
-class SpinLock {
-    std::atomic_flag locked = ATOMIC_FLAG_INIT ;
+class spin_lock
+{
+    std::atomic_flag flag = ATOMIC_FLAG_INIT;
+
 public:
-    void lock() {
-        while (locked.test_and_set(std::memory_order_acquire)) { ; }
+    void lock()
+    {
+        while ( flag.test_and_set(std::memory_order_acquire) )
+            ;
     }
-    void unlock() {
-        locked.clear(std::memory_order_release);
+    void unlock()
+    {
+        flag.clear(std::memory_order_release);
     }
 };
 
+#define USE_SPINLOCK
 
 template<typename T>
 double solve_multithread(wave_equation_context<T>& wec, size_t nths)
 {
@@ -535,10 +541,15 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
 
     /* Multithreading stuff */
+#ifdef USE_SPINLOCK
+    spin_lock                   splock;
+    #define GUARDED_BLOCK std::lock_guard<spin_lock> lg(splock);
+#else
     std::mutex                  cv_mtx;
     std::condition_variable     prod_cv;
     std::condition_variable     cons_cv;
-    std::vector<bool>           thread_done(nths);
+#endif /* USE_SPINLOCK */
+    std::vector<int>            thread_done(nths);
     std::vector<double>         times(nths);
     bool                        iteration_finished = false;
 
@@ -549,6 +560,17 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
 #endif
         while (1)
         {
+#ifdef USE_SPINLOCK
+            {
+                GUARDED_BLOCK;
+                if (thread_done[thread_id])
+                    continue;
+
+                if (iteration_finished)
+                    return;
+            }
+
+#else
             /* Wait for the producer to notify that there's something to do */
             {
                 std::unique_lock<std::mutex> lck(cv_mtx);
@@ -558,6 +580,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
                 if (iteration_finished)
                     return;
             }
+#endif /* USE_SPINLOCK */
 
             /* Do the timestep */
             auto start = std::chrono::high_resolution_clock::now();
@@ -566,13 +589,23 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
             std::chrono::duration<double, std::milli> ms = stop - start;
 
             /* Work for this thread finished, notify producer */
+#ifdef USE_SPINLOCK
+            {
+                GUARDED_BLOCK;
+                thread_done[thread_id] = true;
+                times[thread_id] += ms.count();
+            }
+#else
             std::unique_lock<std::mutex> lck(cv_mtx);
             prod_cv.notify_one();
             thread_done[thread_id] = true;
             times[thread_id] += ms.count();
+#endif /* USE_SPINLOCK */
         }
     };
 
+    for (auto& td : thread_done)
+        td = 1;
 
     std::vector<std::thread> threads(nths);
     for (size_t i = 0; i < nths; i++)
@@ -584,14 +617,33 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
     {
         auto start = std::chrono::high_resolution_clock::now();
 
+#ifdef USE_SPINLOCK
+        {
+            GUARDED_BLOCK;
+            for (auto& td : thread_done)
+                td = 0;
+        }
+
+        while(1)
+        {
+            int ttd = 0;
+            GUARDED_BLOCK;
+            for (auto& td : thread_done)
+                ttd += td;
+
+            if (ttd == nths)
+                break;
+        }
+#else
         std::unique_lock<std::mutex> lck(cv_mtx);
         /* Mark data ready and start the threads */
-        for (size_t i = 0; i < nths; i++)
-            thread_done[i] = false;
+        for (auto& td : thread_done)
+            td = 0;
         cons_cv.notify_all();
 
-        while ( !std::all_of(thread_done.begin(), thread_done.end(), [](bool x) -> bool { return x; } ) )
+        while ( !std::all_of(thread_done.begin(), thread_done.end(), [](int x) -> bool { return x == 1; } ) )
             prod_cv.wait(lck);
+#endif /* USE_SPINLOCK */
 
         auto stop = std::chrono::high_resolution_clock::now();
         std::chrono::duration<double, std::milli> ms = stop - start;
@@ -613,6 +665,14 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
     }
 
     /* Tell all the threads to stop */
+#ifdef USE_SPINLOCK
+    {
+        GUARDED_BLOCK;
+        for (size_t i = 0; i < nths; i++)
+            thread_done[i] = false;
+        iteration_finished = true;
+    }
+#else
     {
         std::unique_lock<std::mutex> lck(cv_mtx);
        for (size_t i = 0; i < nths; i++)
@@ -620,6 +680,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
         iteration_finished = true;
         cons_cv.notify_all();
     }
+#endif
 
     /* Wait for all the threads to finish */
     for (auto& th : threads)
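
Note (not part of the patch): GUARDED_BLOCK expands to a std::lock_guard<spin_lock>, which only requires spin_lock to satisfy BasicLockable (lock()/unlock()), the same interface std::mutex provides. Below is a minimal self-contained sketch of that pattern; the thread and iteration counts are arbitrary values chosen for illustration, not taken from the patch.

```cpp
#include <atomic>
#include <mutex>
#include <thread>
#include <vector>
#include <iostream>

// Same test-and-set spinlock as introduced in the patch.
class spin_lock
{
    std::atomic_flag flag = ATOMIC_FLAG_INIT;

public:
    void lock()
    {
        // Busy-wait until the flag is acquired.
        while ( flag.test_and_set(std::memory_order_acquire) )
            ;
    }

    void unlock()
    {
        flag.clear(std::memory_order_release);
    }
};

int main()
{
    spin_lock splock;
    long counter = 0;

    std::vector<std::thread> threads;
    for (int t = 0; t < 4; t++)
    {
        threads.emplace_back([&]() {
            for (int i = 0; i < 100000; i++)
            {
                // Same pattern as GUARDED_BLOCK: RAII guard over the spinlock.
                std::lock_guard<spin_lock> lg(splock);
                ++counter;
            }
        });
    }

    for (auto& th : threads)
        th.join();

    std::cout << counter << std::endl;   // expected: 400000
    return 0;
}
```

The acquire on test_and_set and the release on clear are what make writes done inside the critical section visible to the next thread that acquires the lock, which is why the shared counters (and, in the patch, thread_done and times) stay consistent without a mutex.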