/* -*- c-file-style: "GNU" -*- */
/*
 * Copyright (C) Telecom SudParis
 * See COPYING in top-level directory.
 */

#ifndef _REENTRANT
#define _REENTRANT
#endif

#include <eztrace-core/eztrace_config.h>
#include <eztrace-core/eztrace_htable.h>
#include <eztrace-lib/eztrace.h>
#include <eztrace-instrumentation/pptrace.h>
#include <eztrace-lib/eztrace_module.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <atomic>
#include <iostream>
#include <map>
#include <mutex>
#include <stack>
#include <string>

#include "impl/Kokkos_Profiling_Interface.hpp"
#include "ezt_kokkos.hpp"

extern "C" {
  /* Name this eztrace module "kokkos" and expand DECLARE_CURRENT_MODULE with
   * C linkage. NOTE(review): the exact declarations produced by the macro are
   * defined by eztrace headers — confirm there before relying on them. */
#define CURRENT_MODULE kokkos
  DECLARE_CURRENT_MODULE;
}

/* Set to 1 as the last step of init_kokkos and reset to 0 by finalize_kokkos.
 * Only written in this file; NOTE(review): it may be read by eztrace macros
 * through the module name — confirm before changing its type. */
static volatile int _kokkos_initialized = 0;

/* OTF2 ids shared by all Kokkos callbacks. -1 means "not registered yet":
 * ezt_init_kokkos and _init_kokkos_otf2 only (re)register ids that are still
 * negative, so a failed registration is attempted again on a later call.
 *   - kokkos_parallel_*_id, kokkos_region_id : OTF2 function (region) ids
 *   - kokkos_region_name_id, kokkos_dev_id   : OTF2 attribute ids */
static int kokkos_parallel_for_id=-1;
static int kokkos_parallel_scan_id=-1;
static int kokkos_parallel_reduce_id=-1;
static int kokkos_region_id=-1;
static int kokkos_region_name_id=-1;
static int kokkos_dev_id=-1;

namespace KokkosTools {
  namespace EZTraceConnector {

    std::map<string, int> functions;
    thread_local std::stack<int> current_functions;

    void ezt_init_kokkos() {
      if(kokkos_parallel_for_id < 0)
	kokkos_parallel_for_id = ezt_otf2_register_function("Kokkos:parallel_for");
      if(kokkos_parallel_scan_id < 0)
	kokkos_parallel_scan_id = ezt_otf2_register_function("Kokkos:parallel_scan");
      if(kokkos_parallel_reduce_id < 0)
	kokkos_parallel_reduce_id = ezt_otf2_register_function("Kokkos:parallel_reduce");
      if(kokkos_region_id < 0)
	kokkos_region_id = ezt_otf2_register_function("Kokkos:region");
      if(kokkos_region_name_id < 0)
	kokkos_region_name_id = ezt_otf2_register_attribute("Kokkos:region", OTF2_TYPE_STRING);
      if(kokkos_dev_id < 0)
	kokkos_dev_id = ezt_otf2_register_attribute("Kokkos:device", OTF2_TYPE_INT32);
    }
    void kokkosp_init_library(const int loadSeq, const uint64_t interfaceVer,
			      const uint32_t, void *) {
      // Nothing to do for now
      ezt_init_kokkos();
    }

    void kokkosp_finalize_library() {
      // Nothing to do
    }

    void kokkosp_begin_parallel_for(const char *name, const uint32_t dev_id,
				    uint64_t *ptr) {
      if(kokkos_parallel_for_id < 0) ezt_init_kokkos();
      OTF2_AttributeList* attribute_list = OTF2_AttributeList_New();
      OTF2_AttributeList_AddAttribute_string(attribute_list, kokkos_region_name_id, name);
      OTF2_AttributeList_AddAttribute_uint32(attribute_list, kokkos_dev_id, dev_id);
      EZT_OTF2_EvtWriter_Enter(evt_writer, attribute_list, ezt_get_timestamp(),
			       kokkos_parallel_for_id);
    }

    void kokkosp_end_parallel_for(const uint64_t ) {
      EZT_OTF2_EvtWriter_Leave(evt_writer, NULL, ezt_get_timestamp(),
			       kokkos_parallel_for_id);
    }

    void kokkosp_begin_parallel_scan(const char *name, const uint32_t dev_id,
				     uint64_t *) {
      if(kokkos_parallel_scan_id < 0) ezt_init_kokkos();
      OTF2_AttributeList* attribute_list = OTF2_AttributeList_New();
      OTF2_AttributeList_AddAttribute_string(attribute_list, kokkos_region_name_id, name);
      OTF2_AttributeList_AddAttribute_uint32(attribute_list, kokkos_dev_id, dev_id);
      EZT_OTF2_EvtWriter_Enter(evt_writer, attribute_list, ezt_get_timestamp(),
			       kokkos_parallel_scan_id);
    }

    void kokkosp_end_parallel_scan(const uint64_t) {
      EZT_OTF2_EvtWriter_Leave(evt_writer, NULL, ezt_get_timestamp(),
			       kokkos_parallel_scan_id);
    }

    void kokkosp_begin_parallel_reduce(const char *name, const uint32_t dev_id,
				       uint64_t *) {
      if(kokkos_parallel_reduce_id < 0) ezt_init_kokkos();
      OTF2_AttributeList* attribute_list = OTF2_AttributeList_New();
      OTF2_AttributeList_AddAttribute_string(attribute_list, kokkos_region_name_id, name);
      OTF2_AttributeList_AddAttribute_uint32(attribute_list, kokkos_dev_id, dev_id);
      EZT_OTF2_EvtWriter_Enter(evt_writer, attribute_list, ezt_get_timestamp(),
			       kokkos_parallel_reduce_id);
    }

    void kokkosp_end_parallel_reduce(const uint64_t) {
      EZT_OTF2_EvtWriter_Leave(evt_writer, NULL, ezt_get_timestamp(),
			       kokkos_parallel_reduce_id);
    }

    void kokkosp_push_profile_region(const char *name) {
      int function_id = -1;
      auto f = functions.find(name);
      if(f == functions.end()) {
	function_id = ezt_otf2_register_function(name);
	functions[name]=function_id;
      } else {
	function_id = f->second;
      }
      current_functions.push(function_id);
      
      EZT_OTF2_EvtWriter_Enter(evt_writer, NULL, ezt_get_timestamp(),
			       function_id);
    }

    void kokkosp_pop_profile_region() {
      int function_id = current_functions.top();
      current_functions.pop();
      
      EZT_OTF2_EvtWriter_Leave(evt_writer, NULL, ezt_get_timestamp(),
			       function_id);
    }

  };
};


extern "C" {

  /* Short alias for the namespace implementing the Kokkos Tools callbacks;
   * used by the EXPOSE_* entry points at the end of this extern "C" block. */
  namespace impl = KokkosTools::EZTraceConnector;


  /* Register the OTF2 functions and attributes used by this module.
   *
   * ezt_otf2_register_function / ezt_otf2_register_attribute may fail and
   * return -1 (e.g. while eztrace itself is still being initialized). This
   * function may therefore have to be called several times until it
   * succeeds: ids that are already valid are never registered again, and the
   * `initialized` latch is only set once *every* id is valid.
   */
  static void _init_kokkos_otf2() {
    static std::atomic<bool> initialized{false};
    if(initialized)
      return;

    if(kokkos_parallel_for_id < 0)
      kokkos_parallel_for_id = ezt_otf2_register_function("Kokkos parallel for");

    if(kokkos_parallel_scan_id < 0)
      kokkos_parallel_scan_id = ezt_otf2_register_function("Kokkos parallel scan");

    if(kokkos_parallel_reduce_id < 0)
      kokkos_parallel_reduce_id = ezt_otf2_register_function("Kokkos parallel reduce");

    if(kokkos_region_id < 0)
      kokkos_region_id = ezt_otf2_register_function("Kokkos parallel region");

    if(kokkos_region_name_id < 0)
      kokkos_region_name_id = ezt_otf2_register_attribute("Name", OTF2_TYPE_STRING);

    if(kokkos_dev_id < 0)
      kokkos_dev_id = ezt_otf2_register_attribute("DevId", OTF2_TYPE_UINT32);

    /* Latch only when all six ids are valid, so any failed registration is
     * retried on the next call instead of being frozen at -1. */
    if(kokkos_parallel_for_id >= 0 &&
       kokkos_parallel_scan_id >= 0 &&
       kokkos_parallel_reduce_id >= 0 &&
       kokkos_region_id >= 0 &&
       kokkos_region_name_id >= 0 &&
       kokkos_dev_id >= 0) {
      initialized = true;
    }
  }

  /* Module initialization callback (passed to EZT_REGISTER_MODULE in
   * _kokkos_init). It expands INSTRUMENT_FUNCTIONS for this module, starts
   * eztrace when autostart is enabled, registers the OTF2 ids, and finally
   * flags the module as initialized. */
  static void init_kokkos() {
    INSTRUMENT_FUNCTIONS(kokkos);

    const bool autostart = eztrace_autostart_enabled();
    if (autostart) {
      eztrace_start();
    }

    _init_kokkos_otf2();
    _kokkos_initialized = 1;
  }

  /* Module finalization callback (passed to EZT_REGISTER_MODULE in
   * _kokkos_init): clears the initialized flag first, then stops eztrace. */
  static void finalize_kokkos() {
    _kokkos_initialized = 0;
    /* Disabled: ezt_kokkos_finalize_threads(); */
    eztrace_stop();
  }

  /* No function to intercept: this module relies on the Kokkos profiling
   * (Kokkos Tools) interface instead, whose callbacks are exposed at the end
   * of this file. Hence nothing is listed between these two macros. */
  PPTRACE_START_INTERCEPT_FUNCTIONS(kokkos)
  PPTRACE_END_INTERCEPT_FUNCTIONS(kokkos)

  /* Library constructor (see the __attribute__((constructor)) declaration
   * below): registers the "kokkos" module with eztrace. A function-local flag
   * makes any further call a no-op. */
  static void _kokkos_init(void) {
    static int already_registered = 0;
    if (!already_registered) {
      already_registered = 1;

      eztrace_log(dbg_lvl_debug, "eztrace_kokkos constructor starts\n");
      EZT_REGISTER_MODULE(kokkos, "Module for Kokkos",
                          init_kokkos, finalize_kokkos);
      eztrace_log(dbg_lvl_debug, "eztrace_kokkos constructor ends\n");
    }
  }


  /* Mark _kokkos_init as a constructor so that it runs when this shared
   * library is loaded, before main(). */
  static void _kokkos_init(void) __attribute__((constructor));

  /* Kokkos Tools entry points. Each EXPOSE_* macro is given one callback
   * implemented in KokkosTools::EZTraceConnector, reached through the `impl`
   * alias already declared near the top of this extern "C" block (the
   * duplicate alias that used to sit here was redundant).
   * NOTE(review): the macros presumably define the C-linkage kokkosp_*
   * symbols that Kokkos looks up when it loads the tool — confirm in
   * ezt_kokkos.hpp. */
  EXPOSE_INIT(impl::kokkosp_init_library)
  EXPOSE_FINALIZE(impl::kokkosp_finalize_library)
  EXPOSE_BEGIN_PARALLEL_FOR(impl::kokkosp_begin_parallel_for)
  EXPOSE_END_PARALLEL_FOR(impl::kokkosp_end_parallel_for)
  EXPOSE_BEGIN_PARALLEL_SCAN(impl::kokkosp_begin_parallel_scan)
  EXPOSE_END_PARALLEL_SCAN(impl::kokkosp_end_parallel_scan)
  EXPOSE_BEGIN_PARALLEL_REDUCE(impl::kokkosp_begin_parallel_reduce)
  EXPOSE_END_PARALLEL_REDUCE(impl::kokkosp_end_parallel_reduce)
  EXPOSE_PUSH_REGION(impl::kokkosp_push_profile_region)
  EXPOSE_POP_REGION(impl::kokkosp_pop_profile_region)

}  // extern "C"

