// -*- C++ -*-
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/FastJets.hh"
#include "Rivet/Projections/DISKinematics.hh"
#include "Rivet/Projections/ChargedFinalState.hh"
#include "Rivet/Projections/Sphericity.hh"
#include "Rivet/Projections/Spherocity.hh"
#include "Rivet/Projections/Thrust.hh"
#include "Rivet/Projections/FastJets.hh"
#include "Rivet/Projections/ParisiTensor.hh"
#include "Rivet/Projections/Hemispheres.hh"

namespace Rivet {


  /// @brief Monte Carlo validation observables for jet photoporudction
  class MC_PHOTOPRODUCTION : public Analysis {
  public:

    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(MC_PHOTOPRODUCTION);


    /// @name Analysis methods
    /// @{

    /// Book histograms and initialise projections before the run
    void init() {

      // set CFS cuts from input options
      _fideta = getOption<double>("ABSETAMAX", 4.0);
      _fidpt = getOption<double>("PTMIN", 6.);

      // set clustering radius from input option
      const double R = getOption<double>("R", 1.0);

      // set clustering algorithm from input option
      JetAlg clusterAlgo;
      const string algoopt = getOption("ALGO", "ANTIKT");
      if ( algoopt == "KT" ) {
        clusterAlgo = JetAlg::KT;
      } else if ( algoopt == "CA" ) {
        clusterAlgo = JetAlg::CA;
      } else if ( algoopt == "ANTIKT" ) {
        clusterAlgo = JetAlg::ANTIKT;
      } else {
        MSG_WARNING("Unknown jet clustering algorithm option " + algoopt + ". Defaulting to anti-kT");
        clusterAlgo = JetAlg::ANTIKT;
      }

      // Initialise and register projections
      const FinalState fs;
      FastJets jetfs(fs, clusterAlgo, R);
      declare(jetfs, "jets");

      // For the multiplicity
      const ChargedFinalState cfs;
      declare(cfs, "CFS");

      declare(Sphericity(fs, 1.), "Sphericity");
      declare(Spherocity(fs), "Spherocity");
      declare(ParisiTensor(fs), "Parisi");
      const Thrust thrust(fs);
      declare(thrust, "Thrust");
      declare(Hemispheres(thrust), "Hemispheres");

      declare(DISKinematics(), "Kinematics");

      // Book histograms
      // First: Standard photoproduction observables
      book(_h["et"], "ET", 34, 6., 40.);
      book(_h["eta"], "eta", 40, -2., 4.);
      book(_h["x"], "x_y", 40, 0., 1.);
      book(_h["mjj"], "m_jj", 40, 12., 140.);
      book(_h["cosTheta"], "cosThetaStar", 40, 0., 1.);
      // Second: jet-shape observables
      book(_h["thrust"], "thrust", 35, 0.65, 1.);
      book(_h["minor"], "minor", 35, 0., 0.7);
      book(_h["C"], "Cparameter", 50, 0., 0.5);
      book(_h["D"], "Dparameter", 50, 0., 0.06);
      book(_h["rhoH"], "rhoH", 30, 0., 0.5);
      book(_h["sphericity"], "sphericity", logspace(50, 3e-4, 0.3));
      book(_h["transvSphericity"], "transvSphericity", 50, 0., 0.5);
      book(_h["aplanarity"], "aplanarity", logspace(50, 1e-5, 0.01));
      //book(_h["spherocity"], "spherocity", 50, 0., 1.);
      book(_h["totalB"], "totalB", 30, 0., 0.15);
      book(_h["wideB"], "wideB", 30, 0., 0.15);
      book(_h["narrowB"], "narrowB", 30, 0., 0.05);
      book(_h["y12"], "y12", logspace(50, 1e-6, 1.));
      book(_h["y23"], "y23", logspace(50, 1e-6, 1.));
      book(_h["y34"], "y34", logspace(50, 1e-6, 1.));
      std::vector<double> mult_bins;
      for (size_t i = 0; i < 25; ++i) { mult_bins.push_back(2.0*i-0.5); }
      book(_h["mult"], "multiplicity", mult_bins);
      book(_h["mult_fid"], "multiplicity_fiducial", mult_bins);
      // Third: heavy-quark jet observables
      double ptmin = 6., ptmax = 40.;
      double etamin = -2., etamax = 4.;
      double massmin = 12., massmax = 100.;
      int nPTbins = 15, nEtaBins = 20, nMassBins = 20;
      for (const string& suff : vector<string>{"c", "b", "tot"}) {
        book(_h["pt_"+suff], "pt_"+suff, nPTbins, ptmin, ptmax);
        book(_h["eta_"+suff], "eta_"+suff, nEtaBins, etamin, etamax);
        book(_h["mjj_"+suff], "dijet_mass_"+suff, nMassBins, massmin, massmax);
        book(_h["xgamma_"+suff], "xgamma_"+suff, 15, 0., 1.);
        book(_e["pt_"+suff], "ratio_pt_"+suff, nPTbins, ptmin, ptmax);
        book(_e["eta_"+suff], "ratio_eta_"+suff, nEtaBins, etamin, etamax);
        book(_e["mjj_"+suff], "ratio_dijet_mass_"+suff, nMassBins, massmin, massmax);
        if (suff == "tot"s) {
          book(_h["pt_all"], "_all_pt", nPTbins, ptmin, ptmax);
          book(_h["eta_all"], "_all_eta", nEtaBins, etamin, etamax);
          book(_h["mjj_all"], "_all_dijet_mass", nMassBins, massmin, massmax);
        }
      }
    }


    /// Perform the per-event analysis
    void analyze(const Event& event) {
      const DISKinematics& kin = apply<DISKinematics>(event, "Kinematics");
      const FinalState& cfs = apply<FinalState>(event, "CFS");

      if ( kin.failed() ) vetoEvent;
      const int orientation = kin.orientation();
      // Q2 cut
      if (kin.Q2() > 1*GeV2) vetoEvent;
      // Jet selection
      const Jets jets = apply<FastJets>(event, "jets").jets(Cuts::abseta< _fideta && Cuts::Et > _fidpt*GeV, cmpMomByEt);
      if (jets.size() < 2) vetoEvent;
      if (jets[0].Et() < 8*GeV) vetoEvent;

      // ---------- Inclusive jet observables, eta and E_T ----------
      for(const Jet& jet : jets) {
        _h["et"]->fill(jet.Et()/GeV);
        const double eta = orientation*jet.eta();
        _h["eta"]->fill(eta);
      }
      double xyobs=0;
      if (jets.size()>1) {
        const Jet& j1 = jets[0];
        const Jet& j2 = jets[1];
        // Jet eta and cos(theta*) computation
        const double eta1 = orientation*j1.eta(), eta2 = orientation*j2.eta();
        const double costhetastar = tanh((eta1 - eta2)/2);
        // Computation of x_y^obs
        xyobs = (j1.Et() * exp(-eta1) + j2.Et() * exp(-eta2)) / (2*kin.y()*kin.beamLepton().E());
        // Calculate the invariant mass
        const double mjj = (j1.mom()+j2.mom()).mass();

        _h["x"]->fill(xyobs);
        _h["mjj"]->fill(mjj/GeV);
        _h["cosTheta"]->fill(abs(costhetastar));
      }

      // ---------- Event Shapes ----------
      // see arxiv:2301.01086 and arxiv:2007.12600 for definitions
      vector<Vector3> momenta;
      for (const Jet& jet : jets) {
        Vector3 mom = jet.p3();
        mom.setZ(0.);
        momenta.push_back(mom);
      }
      // If only 2 particles, we need to use a ghost so that Thrust.calc() doesn't return 1.
      if (momenta.size() == 2) {
        momenta.push_back(Vector3(1e-10*MeV, 0., 0.));
      }
      // Transverse Thrust
      Thrust thrust = apply<Thrust>(event, "Thrust");
      thrust.calc(momenta);
      _h["thrust"]->fill(thrust.thrust());
      _h["minor"]->fill(thrust.thrustMajor());

      // Sphericity
      // as defined by ATLAS in arxiv:2007.12600; please note that the transverse sphericity in arxiv:2301.01086
      // is defined without the additional factor 2
      Sphericity sphericity = apply<Sphericity>(event, "Sphericity");
      _h["sphericity"]->fill(sphericity.sphericity());
      _h["aplanarity"]->fill(sphericity.aplanarity());
      //Linearized sphericity calculation (2D)
      momenta.clear();
      for (const Jet& jet : jets) {
        Vector3 mom = jet.p3();
        mom.setZ(0.);
        momenta.push_back(mom);
      }
      double a11 = 0.0; double a22 = 0.0;
      double a12 = 0.0; double modSum2 = 0.0;

      for (const Vector3& mom : momenta) {
        modSum2 += mom.mod();
        a11 += mom.x()*mom.x()/mom.mod();
        a22 += mom.y()*mom.y()/mom.mod();
        a12 += mom.x()*mom.y()/mom.mod();
      }

      double trc2 = (a11+a22)/modSum2;
      double det2 = (a11*a22-a12*a12)/pow(modSum2,2);

      double eigen21 = (trc2+sqrt(pow(trc2,2)-4*det2))/2;
      double eigen22 = (trc2-sqrt(pow(trc2,2)-4*det2))/2;
      double transSphericity = 2*eigen22/(eigen21+eigen22);
      _h["transvSphericity"]->fill(transSphericity);

      /*// Spherocity
      // @todo check whether this is correctly defined for hadron-like collisions
      // @todo check the why it always gives 0
      const Spherocity& spherocity = apply<Spherocity>(event, "Spherocity");
      _h["spherocity"]->fill(spherocity.spherocity());
      */

      // C parameter
      // as defined by ATLAS in arxiv:2007.12600
      const ParisiTensor& parisi = apply<ParisiTensor>(event, "Parisi");
      _h["C"]->fill(parisi.C());
      _h["D"]->fill(parisi.D());

      // Hemispheres
      // @todo check whether this is correctly defined for hadron-like collisions
      const Hemispheres& hemi = apply<Hemispheres>(event, "Hemispheres");
      _h["rhoH"]->fill(hemi.scaledMhigh());
      _h["wideB"]->fill(hemi.Bmax());
      _h["narrowB"]->fill(hemi.Bmin());
      _h["totalB"]->fill(hemi.Bsum());

      // Jets
      const FastJets& durjet = apply<FastJets>(event, "jets");
      const double y12 = durjet.clusterSeq()->exclusive_ymerge_max(1);
      _h["y12"]->fill(y12);
      const double y23 = durjet.clusterSeq()->exclusive_ymerge_max(2);
      _h["y23"]->fill(y23);
      const double y34 = durjet.clusterSeq()->exclusive_ymerge_max(3);
      _h["y34"]->fill(y34);

      // charged particles
      _h["mult"]->fill(cfs.particles().size());
      _h["mult_fid"]->fill(cfs.particles(Cuts::abseta<4. && Cuts::pT > 230*MeV).size());

      // ---------- Heavy Quarks ----------
      // @todo check the number of tags, to make sure it's == 1?
      if (jets[0].bTagged() || jets[0].cTagged()) {
        string tag(jets[0].bTagged()?"b":"c");
        _h["pt_tot"]->fill(jets[0].pt()/GeV);
        _h["eta_tot"]->fill(orientation*jets[0].eta());
        _h["pt_"+tag]->fill(jets[0].pt()/GeV);
        _h["eta_"+tag]->fill(orientation*jets[0].eta());
        if (jets.size()>1 && (jets[1].bTagged() || jets[1].cTagged())) {
          string sub_tag(jets[1].bTagged()?"b":"c");
          if (tag == sub_tag) {
            const double heavy_mass = (jets[0].mom()+jets[1].mom()).mass();
            _h["mjj_tot"]->fill(heavy_mass/GeV);
            _h["mjj_"+tag]->fill(heavy_mass/GeV);
            _h["xgamma_"+tag]->fill(xyobs);
            _h["xgamma_tot"]->fill(xyobs);
          }
        }
      }
      _h["pt_all"]->fill(jets[0].pt()/GeV);
      _h["eta_all"]->fill(orientation*jets[0].eta());
      if (jets.size()>1) _h["mjj_all"]->fill((jets[0].mom()+jets[1].mom()).mass()/GeV);
    }


    /// Normalise histograms etc., after the run
    void finalize() {
      scale(_h, crossSection()/picobarn/sumOfWeights());
      for (const string& var : vector<string>{"pt", "eta", "mjj"}) {
        for (const string& suff : vector<string>{"c", "b", "tot"}) {
          divide(_h[var+"_"+suff], _h[var+"_all"], _e[var+"_"+suff]);
        }
      }
    }

    /// @}


    /// @name Histograms
    /// @{
    map<string,Histo1DPtr> _h;
    map<string,Estimate1DPtr> _e;
    /// @}

    double _fideta, _fidpt;


  };


  RIVET_DECLARE_PLUGIN(MC_PHOTOPRODUCTION);

}
