Skip to content
Snippets Groups Projects
Select Git revision
  • c2805a9208f68b43eda36900f8f5d85131adc2a8
  • master default protected
2 results

fd_kokkos.cpp

Blame
  • fd_kokkos.cpp 7.02 KiB
    #include <iostream>
    #include <fstream>
    #include <sstream>
    #include <vector>
    #include <cmath>
    #include <chrono>
    #include <type_traits>
    
    #include <silo.h>
    #include <Kokkos_Core.hpp>
    
    #include <pmmintrin.h>
    #include <xmmintrin.h>
    
    // Width, in cells, of the halo border added on every side of the grid.
    // The allocated grid is (rows + 2*halo) x (cols + 2*halo), and the solver's
    // stencil reads up to WAVE_8_HALO_SIZE neighbours in each direction.
    #define WAVE_8_HALO_SIZE 4
    
    using namespace Kokkos;
    using namespace std::chrono;
    
    /**
     * Write a 2D view to a Silo file as a collinear quad mesh named "mesh"
     * carrying one node-centered variable named "solution".
     *
     * Mesh coordinates are normalised to [0,1] along each axis. The view is
     * read element by element from the calling (host) code.
     *
     * @tparam T  element type; only float and double are accepted.
     * @param kv  view to dump, indexed as kv(i,j).
     * @param fn  path of the Silo file to create (an existing file is clobbered).
     * @return 0 on success, -1 if the file cannot be created or a Silo write
     *         call reports an error.
     */
    template<typename T>
    int
    visit_dump(const Kokkos::View<T**>& kv, const std::string& fn)
    {
        static_assert(std::is_same<T,double>::value or std::is_same<T,float>::value,
                      "Only double or float");

        DBfile *db = DBCreate(fn.c_str(), DB_CLOBBER, DB_LOCAL, "Kokkos test", DB_HDF5);
        if (!db)
        {
            std::cout << "Cannot write simulation output" << std::endl;
            return -1;
        }

        const auto size_i = kv.extent(0);
        const auto size_j = kv.extent(1);

        // The static_assert above guarantees T is one of these two types.
        const int db_type = std::is_same<T,float>::value ? DB_FLOAT : DB_DOUBLE;

        // Evenly spaced coordinates in [0,1]; a single-point axis sits at 0
        // instead of dividing 0 by 0.
        auto unit_axis = [](size_t npts) {
            std::vector<T> axis(npts);
            for (size_t p = 0; p < npts; p++)
                axis[p] = (npts > 1) ? T(p)/(npts-1) : T(0);
            return axis;
        };

        // The flattened data below stores the second view index contiguously,
        // so that index is given to Silo as the first mesh axis.
        // NOTE(review): assumes Silo's default ordering has the first axis
        // varying fastest -- confirm against DBOPT_MAJORORDER. For square views
        // the output is identical to the previous behaviour either way.
        std::vector<T> axis0 = unit_axis(size_j);
        std::vector<T> axis1 = unit_axis(size_i);

        // dims[k] must equal the number of entries in coords[k].
        int dims[] = { int(size_j), int(size_i) };
        int ndims = 2;
        T *coords[] = { axis0.data(), axis1.data() };

        int status = 0;

        if (DBPutQuadmesh(db, "mesh", nullptr, coords, dims, ndims,
                          db_type, DB_COLLINEAR, nullptr) < 0)
            status = -1;

        std::vector<T> data(size_i * size_j);

        for (size_t i = 0; i < size_i; i++)
            for (size_t j = 0; j < size_j; j++)
                data[i*size_j + j] = kv(i,j);

        if (DBPutQuadvar1(db, "solution", "mesh", data.data(), dims, ndims,
                          nullptr, 0, db_type, DB_NODECENT, nullptr) < 0)
            status = -1;

        DBClose(db);

        if (status != 0)
            std::cout << "Cannot write simulation output" << std::endl;

        return status;
    }
    
    /**
     * Holds the three time levels of the wave-equation grid (previous, current,
     * next) together with the physical and run parameters.
     *
     * Each grid is allocated with WAVE_8_HALO_SIZE extra cells on every side of
     * the rows x cols interior; grows/gcols are the resulting allocated extents.
     */
    template<typename T>
    struct wave_equation_context_kokkos
    {
        View<T**>   g_prev;
        View<T**>   g_curr;
        View<T**>   g_next;

        T       velocity;
        T       damping;
        T       dt;
        int     maxiter;

        size_t  rows, cols;
        size_t  grows, gcols;

        /**
         * Store the parameters, allocate the three grids and fill them with
         * the initial condition via init().
         */
        wave_equation_context_kokkos(size_t prows, size_t pcols,
                                     T vel, T damp, T pdt, int pmaxiter)
            : velocity(vel), damping(damp), dt(pdt), maxiter(pmaxiter),
              rows(prows), cols(pcols),
              grows(prows+2*WAVE_8_HALO_SIZE), gcols(pcols+2*WAVE_8_HALO_SIZE)
        {
            g_prev = View<T**>("g_prev", grows, gcols);
            g_curr = View<T**>("g_curr", grows, gcols);
            g_next = View<T**>("g_next", grows, gcols);

            init();
        }

        /** True when (i,j), given in allocated-grid indices, lies in the halo. */
        bool is_halo(size_t i, size_t j)
        {
            const bool before_interior =
                (i < WAVE_8_HALO_SIZE) || (j < WAVE_8_HALO_SIZE);
            const bool past_interior =
                (i >= rows+WAVE_8_HALO_SIZE) || (j >= cols+WAVE_8_HALO_SIZE);
            return before_interior || past_interior;
        }

        /**
         * Reset all three grids: halo cells and the whole "next" grid are zero;
         * interior cells of "prev" get a negative Gaussian pulse and "curr" gets
         * 2*dt times that pulse.
         */
        void init(void)
        {
            const auto hx = 1./(cols-1);
            const auto hy = 1./(rows-1);

            for (size_t r = 0; r < grows; r++)
            {
                for (size_t q = 0; q < gcols; q++)
                {
                    // The "next" level starts at zero everywhere.
                    g_next(r,q) = 0.0;

                    if ( is_halo(r,q) )
                    {
                        g_prev(r,q) = 0.0;
                        g_curr(r,q) = 0.0;
                        continue;
                    }

                    // NOTE(review): r and q are allocated-grid indices, so the
                    // coordinates include the halo offset -- confirm intended.
                    T y = hy*r - 0.3;
                    T x = hx*q - 0.1;
                    const T pulse = -std::exp(-2400*(x*x + y*y));

                    g_prev(r,q) = pulse;
                    g_curr(r,q) = 2*dt*pulse;
                }
            }
        }
    };
    
    
    /**
     * Advance the wave equation for wec.maxiter time steps and report the
     * average wall-clock time per step.
     *
     * Each step applies a 9-point stencil per direction to g_curr, combines it
     * with g_prev and the damping terms, writes g_next, and then rotates the
     * three grids. Cells on the edge of the rows x cols interior are forced to
     * zero.
     *
     * @param wec  simulation context; its three grids are updated and rotated
     *             in place.
     * @return average time per iteration in milliseconds (0 if maxiter <= 0).
     */
    template<typename T>
    double solve_kokkos(wave_equation_context_kokkos<T>& wec)
    {
        int maxrow  = wec.rows;
        int maxcol  = wec.cols;
        T   dt      = wec.dt;
        T   c       = wec.velocity;
        T   a       = wec.damping;

        // NOTE(review): this file does not include <cassert> directly; assert is
        // presumably made available through another header.
        assert(maxcol > 1);
        assert(maxrow > 1);

        // specifying tiling explicitly worsens perf
        // The upper bound is exclusive, so {maxrow, maxcol} covers every
        // interior cell including the last row and column.
        MDRangePolicy<Rank<2>> range({0,0}, {maxrow, maxcol});

        T kx2 = c*c * dt*dt * (maxcol-1)*(maxcol-1);
        T ky2 = c*c * dt*dt * (maxrow-1)*(maxrow-1);
        T one_minus_adt = (1.0 - a*dt);
        T two_minus_adt = (2.0 - a*dt);

        double iter_time = 0.0;
        for (int ts = 0; ts < wec.maxiter; ts++)
        {
            // Capture only the three views in the kernel rather than the whole
            // context; they are re-read every step because of the rotation below.
            auto g_prev = wec.g_prev;
            auto g_curr = wec.g_curr;
            auto g_next = wec.g_next;

            auto t_begin = high_resolution_clock::now();

            parallel_for(range, KOKKOS_LAMBDA(int i, int j)
            {
                static const T w0 = -205.0/72.0;
                static const T w1 =    8.0/5.0;
                static const T w2 =   -1.0/5.0;
                static const T w3 =    8.0/315.0;
                static const T w4 =   -1.0/560.0;
                static const T w[9] = { w4, w3, w2, w1, w0, w1, w2, w3, w4 };

    #ifdef DISALLOW_DENORMALS
                _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
                _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
    #endif

                // Decide on the boundary with the interior indices, before they
                // are shifted into allocated-grid (halo-inclusive) indices.
                const bool on_boundary = (i == 0) or
                                         (j == 0) or
                                         (i == maxrow-1) or
                                         (j == maxcol-1);

                const int gi = i + WAVE_8_HALO_SIZE;
                const int gj = j + WAVE_8_HALO_SIZE;

                T lapl = 0.0;
                for (int k = -WAVE_8_HALO_SIZE; k <= WAVE_8_HALO_SIZE; k++)
                    lapl += kx2 * w[k+WAVE_8_HALO_SIZE] * g_curr(gi,gj+k);

                for (int k = -WAVE_8_HALO_SIZE; k <= WAVE_8_HALO_SIZE; k++)
                    lapl += ky2 * w[k+WAVE_8_HALO_SIZE] * g_curr(gi+k,gj);

                T val = lapl - one_minus_adt * g_prev(gi, gj) + two_minus_adt * g_curr(gi, gj);

                if (on_boundary)
                    val = 0;

                g_next(gi, gj) = val;
            });

            // Wait for the kernel to finish so the timer measures the work and
            // not just its launch on asynchronous backends.
            Kokkos::fence();

            auto t_end = high_resolution_clock::now();

            // Rotate: prev <- curr, curr <- next, next <- old prev.
            std::swap(wec.g_prev, wec.g_curr);
            std::swap(wec.g_curr, wec.g_next);

            std::chrono::duration<double, std::milli> ms = t_end - t_begin;
            iter_time += ms.count();
    #if 0
            if ( (ts % 100) == 0 )
            {
                std::stringstream ss;
                ss << "wave_kokkos_" << ts << ".silo";
                visit_dump(wec.g_curr, ss.str());
            }
    #endif
        }

        // Avoid dividing by zero when no iterations were requested.
        double avg_iter_time = (wec.maxiter > 0) ? iter_time/wec.maxiter : 0.0;
        std::cout << "Average iteration time: " << avg_iter_time << std::endl;
        return avg_iter_time;
    }
    
    /**
     * Benchmark driver: prints the build configuration, initialises Kokkos and
     * runs the solver on square grids of side 128, 256, 512 and 1024.
     */
    int main(int argc, char *argv[])
    {
    #ifdef SINGLE_PRECISION
        using T = float;
        std::cout << "Precision: single" << std::endl;
    #else
        using T = double;
        std::cout << "Precision: double" << std::endl;
    #endif

    #ifdef DISALLOW_DENORMALS
        _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
        _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
        std::cout << "Denormals: FTZ and DAZ" << std::endl;
    #endif
        // Clear the mask bit for invalid-operation floating-point exceptions.
        _MM_SET_EXCEPTION_MASK(_MM_GET_EXCEPTION_MASK() & ~_MM_MASK_INVALID);

        Kokkos::initialize( argc, argv );

        for (size_t sz = 128; sz <= 1024; sz *= 2)
        {
            // The constructor already calls init(), so no separate call is
            // needed. The context is scoped to the loop body, so it is
            // destroyed before Kokkos::finalize() runs.
            wave_equation_context_kokkos<T> wec(sz, sz, 1, 0.1, 0.0001, 5000);
            solve_kokkos(wec);
        }

        Kokkos::finalize();

        return 0;
    }