Skip to content
Snippets Groups Projects
Commit f3b527f2 authored by Matteo Cicuttin's avatar Matteo Cicuttin
Browse files

Barrier.

parent 37fc983f
No related branches found
No related tags found
No related merge requests found
...@@ -126,7 +126,7 @@ endif() ...@@ -126,7 +126,7 @@ endif()
set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_DEBUG "-g")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native -g -DNDEBUG") set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native -fopenmp -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_RELEASEASSERT "-O3 -march=native -g -fpermissive") set(CMAKE_CXX_FLAGS_RELEASEASSERT "-O3 -march=native -g -fpermissive")
macro(setup_fd_catalog_target FD_TGT_NAME SINGLE_PRECISION) macro(setup_fd_catalog_target FD_TGT_NAME SINGLE_PRECISION)
......
...@@ -56,9 +56,9 @@ int main(int argc, char **argv) ...@@ -56,9 +56,9 @@ int main(int argc, char **argv)
time = solve_sequential(wec); time = solve_sequential(wec);
ofs << time << " "; ofs << time << " ";
wec.init(); //wec.init();
time = solve_sequential_blocked(wec); //time = solve_sequential_blocked(wec);
ofs << time << " "; //ofs << time << " ";
#ifdef HAVE_CUDA #ifdef HAVE_CUDA
wec.init(); wec.init();
......
...@@ -378,6 +378,7 @@ wave_2D_kernel(const fd_grid<T>& g_prev, const fd_grid<T>& g_curr, ...@@ -378,6 +378,7 @@ wave_2D_kernel(const fd_grid<T>& g_prev, const fd_grid<T>& g_curr,
T one_minus_adt = (1.0 - a*dt); T one_minus_adt = (1.0 - a*dt);
T two_minus_adt = (2.0 - a*dt); T two_minus_adt = (2.0 - a*dt);
//#pragma omp parallel for
for (size_t i = from; i < maxrow; i+=to) for (size_t i = from; i < maxrow; i+=to)
{ {
#pragma clang loop vectorize(enable) #pragma clang loop vectorize(enable)
...@@ -527,6 +528,88 @@ public: ...@@ -527,6 +528,88 @@ public:
}; };
/// Reusable rendezvous point for a fixed number of participants.
///
/// Every participant calls Wait(). Callers block until the number of calls
/// made in the current round reaches the threshold; the call that completes
/// the round releases all the blocked callers and re-arms the barrier, so the
/// same object can be used for any number of consecutive rounds.
class Barrier {
public:
    /// @param iCount Number of Wait() calls required to release one round.
    ///        A value of 0 is treated as 1: with an unsigned counter, a zero
    ///        threshold would wrap around on the first decrement in Wait()
    ///        and leave the caller blocked indefinitely.
    explicit Barrier(std::size_t iCount) :
        mThreshold(iCount == 0 ? 1 : iCount),
        mCount(iCount == 0 ? 1 : iCount),
        mGeneration(0) {
    }

    /// Blocks until mThreshold calls to Wait() have been made in this round.
    /// The call that brings the remaining count to zero does not block: it
    /// starts a new generation, resets the counter and wakes all waiters.
    void Wait() {
        std::unique_lock<std::mutex> lLock{mMutex};
        // Remember which round this call belongs to; the wait predicate
        // compares against it, so a wakeup only counts once the round has
        // actually been completed (mGeneration changed).
        const auto lGen = mGeneration;
        if (--mCount == 0) {
            ++mGeneration;
            mCount = mThreshold;   // re-arm for the next round
            mCond.notify_all();
        } else {
            mCond.wait(lLock, [this, lGen] { return lGen != mGeneration; });
        }
    }

private:
    std::mutex mMutex;              ///< Protects mCount and mGeneration.
    std::condition_variable mCond;  ///< Signalled when a round completes.
    const std::size_t mThreshold;   ///< Calls per round (always >= 1).
    std::size_t mCount;             ///< Calls still missing in this round.
    std::size_t mGeneration;        ///< Incremented once per completed round.
};
/*
class Barrier
{
private:
std::mutex m_mutex;
std::condition_variable m_cv;
size_t m_count;
const size_t m_initial;
enum State : unsigned char {
Up, Down
};
State m_state;
public:
explicit Barrier(std::size_t count) : m_count{ count }, m_initial{ count }, m_state{ State::Down } { }
/// Blocks until all N threads reach here
void Wait()
{
std::unique_lock<std::mutex> lock{ m_mutex };
if (m_state == State::Down)
{
// Counting down the number of syncing threads
if (--m_count == 0) {
m_state = State::Up;
m_cv.notify_all();
}
else {
m_cv.wait(lock, [this] { return m_state == State::Up; });
}
}
else // (m_state == State::Up)
{
// Counting back up for Auto reset
if (++m_count == m_initial) {
m_state = State::Down;
m_cv.notify_all();
}
else {
m_cv.wait(lock, [this] { return m_state == State::Down; });
}
}
}
};
*/
#define USE_SPINLOCK #define USE_SPINLOCK
template<typename T> template<typename T>
double solve_multithread(wave_equation_context<T>& wec, size_t nths) double solve_multithread(wave_equation_context<T>& wec, size_t nths)
...@@ -539,6 +622,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -539,6 +622,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
params.velocity = wec.velocity; params.velocity = wec.velocity;
params.damping = wec.damping; params.damping = wec.damping;
Barrier pb(nths+1), cb(nths+1);
/* Multithreading stuff */ /* Multithreading stuff */
#ifdef USE_SPINLOCK #ifdef USE_SPINLOCK
...@@ -561,13 +645,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -561,13 +645,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
while (1) while (1)
{ {
#ifdef USE_SPINLOCK #ifdef USE_SPINLOCK
splock.lock(); pb.Wait();
int done = thread_done[thread_id];
splock.unlock();
if (done)
continue;
if (iteration_finished) if (iteration_finished)
return; return;
#else #else
...@@ -590,16 +668,15 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -590,16 +668,15 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
/* Work for this thread finished, notify producer */ /* Work for this thread finished, notify producer */
#ifdef USE_SPINLOCK #ifdef USE_SPINLOCK
splock.lock();
thread_done[thread_id] = 1;
times[thread_id] += ms.count(); times[thread_id] += ms.count();
splock.unlock(); cb.Wait();
#else #else
std::unique_lock<std::mutex> lck(cv_mtx); std::unique_lock<std::mutex> lck(cv_mtx);
prod_cv.notify_one(); prod_cv.notify_one();
thread_done[thread_id] = 1; thread_done[thread_id] = 1;
times[thread_id] += ms.count(); times[thread_id] += ms.count();
#endif /* USE_SPINLOCK */ #endif /* USE_SPINLOCK */
} }
}; };
...@@ -617,23 +694,8 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -617,23 +694,8 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
#ifdef USE_SPINLOCK #ifdef USE_SPINLOCK
splock.lock(); pb.Wait();
for (auto& td : thread_done) cb.Wait();
td = 0;
splock.unlock();
while(1)
{
int ttd = 0;
splock.lock();
for (auto& td : thread_done)
ttd += td;
splock.unlock();
if (ttd == nths)
break;
}
#else #else
std::unique_lock<std::mutex> lck(cv_mtx); std::unique_lock<std::mutex> lck(cv_mtx);
/* Mark data ready and start the threads */ /* Mark data ready and start the threads */
...@@ -645,6 +707,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -645,6 +707,7 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
prod_cv.wait(lck); prod_cv.wait(lck);
#endif /* USE_SPINLOCK */ #endif /* USE_SPINLOCK */
auto stop = std::chrono::high_resolution_clock::now(); auto stop = std::chrono::high_resolution_clock::now();
std::chrono::duration<double, std::milli> ms = stop - start; std::chrono::duration<double, std::milli> ms = stop - start;
time += ms.count(); time += ms.count();
...@@ -666,11 +729,9 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths) ...@@ -666,11 +729,9 @@ double solve_multithread(wave_equation_context<T>& wec, size_t nths)
/* Tell all the threads to stop */ /* Tell all the threads to stop */
#ifdef USE_SPINLOCK #ifdef USE_SPINLOCK
splock.lock();
for (size_t i = 0; i < nths; i++)
thread_done[i] = 0;
iteration_finished = true; iteration_finished = true;
splock.unlock(); pb.Wait();
#else #else
{ {
std::unique_lock<std::mutex> lck(cv_mtx); std::unique_lock<std::mutex> lck(cv_mtx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment