///  \author Are Raklev
///  \date 2018 June
///  Based on the search presented in 1806.04030.
///  Only the high mass analysis is implemented here.
///  This analysis has overlapping exclusion and discovery signal regions;
///  the results for the discovery regions are collected in a separate derived class.
///  *********************************************

#include <vector>
#include <cmath>
#include <memory>
#include <iomanip>

#include "gambit/ColliderBit/analyses/Analysis.hpp"
#include "gambit/ColliderBit/ATLASEfficiencies.hpp"

using namespace std;

// TODO: See if adding muons to jets gives some improvement.

namespace Gambit {
  namespace ColliderBit {

    /// ColliderBit analysis for the ATLAS search referenced in the file header (1806.04030).
    /// Holds one event counter per signal region plus cut-flow bookkeeping.
    /// NOTE(review): several lines of this class appear to be missing in this view
    /// (e.g. access specifiers and the end of the _counters initialiser) — confirm
    /// against the full file before relying on the layout shown here.
    class Analysis_ATLAS_13TeV_3b_36invfb : public Analysis {

      // Signal region map
      // Keyed by signal-region name; each EventCounter is constructed with the same name as its key.
      std::map<string, EventCounter> _counters = {
        // Exclusion regions, disjoint
        {"SR-3b-meff1-A", EventCounter("SR-3b-meff1-A")},
        {"SR-3b-meff2-A", EventCounter("SR-3b-meff2-A")},
        {"SR-3b-meff3-A", EventCounter("SR-3b-meff3-A")},
        {"SR-4b-meff1-A", EventCounter("SR-4b-meff1-A")},
        {"SR-4b-meff1-B", EventCounter("SR-4b-meff1-B")},
        {"SR-4b-meff2-A", EventCounter("SR-4b-meff2-A")},
        {"SR-4b-meff2-B", EventCounter("SR-4b-meff2-B")},
        // Discovery regions, SR-4b-meff1-A and SR-4b-meff2-A are subsets
        {"SR-4b-meff1-A-disc", EventCounter("SR-4b-meff1-A-disc")},
        // NOTE(review): the initialiser list is not closed in the lines visible here — confirm.


      // Cut-flows
      // NCUTS has no in-class initialiser; it must be set before it is used as a loop bound.
      size_t NCUTS;
      // Per-cut event counts (integer), indexed by cut number.
      vector<int> cutFlowVector;
      // Human-readable label for each cut, indexed like cutFlowVector.
      vector<string> cutFlowVector_str;
      // Presumably the reference cut-flow numbers from ATLAS for comparison — TODO confirm.
      vector<double> cutFlowVectorATLAS;


      // Required detector sim
      static constexpr const char* detector = "ATLAS";

      /// Strict-weak-ordering comparator for std::sort: returns true when jet1 has
      /// a larger transverse momentum than jet2, i.e. sorts jets by descending pT.
      static bool sortByPT(const HEPUtils::Jet* jet1, const HEPUtils::Jet* jet2)
      {
        const auto ptFirst = jet1->pT();
        const auto ptSecond = jet2->pT();
        return ptSecond < ptFirst;
      }

      /// Constructor.
      /// NOTE(review): the body is incomplete in this view. The loop below reads NCUTS,
      /// but no assignment to NCUTS is visible before it, and NCUTS has no in-class
      /// initialiser — confirm that it is set first in the full file.
      Analysis_ATLAS_13TeV_3b_36invfb() {



        // Iterates over the cut indices 0..NCUTS-1; the loop body is not visible here.
        for(size_t i=0;i<NCUTS;i++){

      // The following section copied from Analysis_ATLAS_1LEPStop_20invfb.cpp
      /// Jet-lepton overlap removal: flags a jet as overlapping when it lies within
      /// DeltaRMax of any lepton in lepvec.
      /// @param jetvec     Jets to be checked (passed by non-const reference).
      /// @param lepvec     Leptons the jets are compared against.
      /// @param DeltaRMax  Maximum separation at which a jet counts as overlapping (inclusive, <=).
      /// NOTE(review): this function is truncated in this view — the computation of dR,
      /// the filling of Survivors, the write-back to jetvec and the closing braces are
      /// not visible. Confirm against the full file.
      void JetLeptonOverlapRemoval(vector<const HEPUtils::Jet*> &jetvec, vector<const HEPUtils::Particle*> &lepvec, double DeltaRMax) {
        //Routine to do jet-lepton check
        //Discards jets if they are within DeltaRMax of a lepton

        // Jets that pass the overlap check (never filled in the lines visible here).
        vector<const HEPUtils::Jet*> Survivors;

        for(unsigned int itjet = 0; itjet < jetvec.size(); itjet++) {
          bool overlap = false;
          for(unsigned int itlep = 0; itlep < lepvec.size(); itlep++) {
            // NOTE(review): dR is declared but never assigned in the visible lines;
            // as shown, the comparison below would read an uninitialised value.
            double dR;


            if(fabs(dR) <= DeltaRMax) overlap=true;
          // Skip (i.e. discard) any jet flagged as overlapping with a lepton.
          if(overlap) continue;


      /// Lepton-jet overlap removal: flags a lepton as overlapping when it lies within a
      /// pT-dependent cone of any jet in jetvec.
      /// @param lepvec  Leptons to be checked (passed by non-const reference).
      /// @param jetvec  Jets the leptons are compared against.
      /// NOTE(review): this function is truncated in this view — the declaration of lepmom,
      /// the computation of dR, the filling of Survivors, the write-back to lepvec and the
      /// closing braces are not visible. Confirm against the full file.
      void LeptonJetOverlapRemoval(vector<const HEPUtils::Particle*> &lepvec, vector<const HEPUtils::Jet*> &jetvec) {
        //Routine to do lepton-jet check
        //Discards leptons if they are within dR of a jet as defined in analysis paper

        // Leptons that pass the overlap check (never filled in the lines visible here).
        vector<const HEPUtils::Particle*> Survivors;

        for(unsigned int itlep = 0; itlep < lepvec.size(); itlep++) {
          bool overlap = false;
          for(unsigned int itjet= 0; itjet < jetvec.size(); itjet++) {
            // NOTE(review): dR is declared but never assigned in the visible lines.
            double dR;
            // Cone size shrinks with lepton pT: 0.04 + 10/pT, capped at 0.4.
            // NOTE(review): lepmom is not declared in the visible lines; presumably the
            // four-momentum of lepvec[itlep] — confirm.
            double DeltaRMax = std::min(0.4, 0.04 + 10 / lepmom.pT());

            if(fabs(dR) <= DeltaRMax) overlap=true;
          // Skip (i.e. discard) any lepton flagged as overlapping with a jet.
          if(overlap) continue;


      // Calculate transverse mass
      /// Transverse mass of the (missing momentum + jet) system, treating both as massless:
      ///   mT^2 = (pT_miss + pT_jet)^2 - (px_miss + px_jet)^2 - (py_miss + py_jet)^2
      /// @param pmiss  Missing-momentum four-vector.
      /// @param jet    Jet four-vector.
      /// @return mT in the same units as the input momenta; 0 if the radicand is not positive.
      double mTrans(HEPUtils::P4 pmiss, HEPUtils::P4 jet) {
        const double sumPT = pmiss.pT() + jet.pT();
        const double sumPx = pmiss.px() + jet.px();
        const double sumPy = pmiss.py() + jet.py();
        const double mT2 = sumPT*sumPT - sumPx*sumPx - sumPy*sumPy;
        // For (nearly) collinear inputs rounding can make mT2 slightly negative, which
        // would turn sqrt() into NaN and silently drop the jet from any min() search.
        return (mT2 > 0.) ? sqrt(mT2) : 0.;
      }

      /// Per-event analysis: builds baseline leptons and jets, assigns b-tags, constructs the
      /// kinematic variables (phi4min, meff, Higgs-candidate masses, Rbbmax, mTmin) and
      /// increments the matching signal-region counters.
      /// @param event  The event to analyse.
      /// NOTE(review): this function is heavily truncated in this view — many loop bodies,
      /// closing braces and parts of expressions are missing (each case is flagged below).
      /// The code tokens are left exactly as found; confirm against the full file.
      void run(const HEPUtils::Event* event) {

        // Get the missing energy in the event
        double met = event->met();
        HEPUtils::P4 metVec = event->missingmom();

        // Now define vectors of baseline objects, including:
        // - retrieval of electron, muon and jets from the event
        // - application of basic pT and eta cuts

        // Electrons
        // Baseline selection: pT > 5 and |eta| < 2.47.
        // NOTE(review): the statement controlled by this if (and the loop's closing brace) is not visible.
        vector<const HEPUtils::Particle*> electrons;
        for (const HEPUtils::Particle* electron : event->electrons()) {
          if (electron->pT() > 5.
              && fabs(electron->eta()) < 2.47)

        // Apply electron efficiency
        // NOTE(review): no efficiency call is visible below this comment.

        // Muons
        // Baseline selection: pT > 5 and |eta| < 2.5.
        // NOTE(review): the statement controlled by this if (and the loop's closing brace) is not visible.
        vector<const HEPUtils::Particle*> muons;
        for (const HEPUtils::Particle* muon : event->muons()) {
          if (muon->pT() > 5.
              && fabs(muon->eta()) < 2.5)

        // Apply muon efficiency
        // NOTE(review): no efficiency call is visible below this comment.

        // Candidate jets: "antikt_R04" collection with pT > 20 and |eta| < 2.8.
        // NOTE(review): the statement controlled by this if is not visible.
        vector<const HEPUtils::Jet*> candJets;
        for (const HEPUtils::Jet* jet : event->jets("antikt_R04")) {
          if (jet->pT() > 20. && fabs(jet->eta()) < 2.8)

        // Jets
        vector<const HEPUtils::Jet*> bJets;
        vector<const HEPUtils::Jet*> nonbJets;

        // Find b-jets
        // Tagging probabilities: 0.77 for true b-jets, 1/6 for c-jets, 1/134 for all other jets.
        double btag = 0.77; double cmisstag = 1/6.; double misstag = 1./134.;
        for (const HEPUtils::Jet* jet : candJets) {
          // Tag
          if( jet->btag() && random_bool(btag) ) bJets.push_back(jet);
          // Mistag c-jet
          else if( jet->ctag() && random_bool(cmisstag) ) bJets.push_back(jet);
          // Mistag light jet
          // (this branch is also reached by b- and c-jets that failed their random draw above)
          else if( random_bool(misstag) ) bJets.push_back(jet);
          // Non b-jet
          else nonbJets.push_back(jet);

        // Overlap removal
        // NOTE(review): no overlap-removal calls are visible below this comment.

        // Find veto leptons with pT > 20 GeV
        vector<const HEPUtils::Particle*> vetoElectrons;
        for (const HEPUtils::Particle* electron : electrons) {
          if (electron->pT() > 20.) vetoElectrons.push_back(electron);
        vector<const HEPUtils::Particle*> vetoMuons;
        for (const HEPUtils::Particle* muon : muons) {
          if (muon->pT() > 20.) vetoMuons.push_back(muon);

        // Restrict jets to pT > 25 GeV after overlap removal
        vector<const HEPUtils::Jet*> bJets_survivors;
        for (const HEPUtils::Jet* jet : bJets) {
          if(jet->pT() > 25.) bJets_survivors.push_back(jet);
        vector<const HEPUtils::Jet*> nonbJets_survivors;
        for (const HEPUtils::Jet* jet : nonbJets) {
          if(jet->pT() > 25.) nonbJets_survivors.push_back(jet);
        // Combined jet list, sorted by descending pT.
        // NOTE(review): the body of the loop over bJets is not visible; note that it iterates
        // bJets rather than bJets_survivors — confirm which collection is intended.
        vector<const HEPUtils::Jet*> jet_survivors;
        jet_survivors = nonbJets_survivors;
        for (const HEPUtils::Jet* jet : bJets) {
        std::sort(jet_survivors.begin(), jet_survivors.end(), sortByPT);

        // Number of objects
        size_t nbJets = bJets_survivors.size();
        size_t nnonbJets = nonbJets_survivors.size();
        // Jet multiplicity is taken as the sum of the two pT > 25 survivor lists.
        size_t nJets = nbJets + nnonbJets;
        //size_t nJets = jet_survivors.size();
        size_t nMuons=vetoMuons.size();
        size_t nElectrons=vetoElectrons.size();
        size_t nLeptons = nElectrons+nMuons;

        // Loop over jets to find the minimum azimuthal angle with respect to the missing momentum
        // Considers at most the first four jets; 7 is a start value larger than any angle (> 2*pi).
        // NOTE(review): the object on which mom() is called is missing from the next statement.
        double phi4min = 7;
        for(int i = 0; i < min(4,(int)nJets); i++){
          double phi =>mom().deltaPhi(metVec);
          if(phi < phi4min) phi4min = phi;

        // Collect the four signal jets.
        // b-tagged survivors are taken first, then non-b survivors, until four jets are collected.
        vector<const HEPUtils::Jet*> signalJets;
        for(const HEPUtils::Jet* jet : bJets_survivors){
          if(signalJets.size() < 4) signalJets.push_back(jet);
        for(const HEPUtils::Jet* jet : nonbJets_survivors){
          if(signalJets.size() < 4) signalJets.push_back(jet);

        // Effective mass (using the four jets used in Higgses)
        // meff = met + sum of signal-jet pT.
        double meff = met;
        for(const HEPUtils::Jet* jet : signalJets){
          meff += jet->pT();

        // Find Higgs candidates
        // Defaults (masses 0, Rbbmax 10) fail every signal-region mass and Rbbmax window below,
        // so events with fewer than four signal jets are never selected.
        double mlead = 0;  double msubl = 0;
        double m1 = 0;  double m2 = 0;
        double Rbbmax = 10;
        if(signalJets.size() == 4){
          // Three candidate pairings; for each, DRn is the larger of the two pair separations.
          // The pairing with the smallest DRn is kept.
          // NOTE(review): the jet operands are missing from the R.. and m1/m2 expressions below,
          // and the final else branch and several closing braces are not visible.
          double R11 =>mom().deltaR_eta(>mom());
          double R12 =>mom().deltaR_eta(>mom());
          double DR1 = max(R11,R12);
          //cout << DR1 << " " << R11 << " " << R12 << endl;
          double R21 =>mom().deltaR_eta(>mom());
          double R22 =>mom().deltaR_eta(>mom());
          double DR2 = max(R21,R22);
          //cout << DR2 << " " << R21 << " " << R22 << endl;
          double R31 =>mom().deltaR_eta(>mom());
          double R32 =>mom().deltaR_eta(>mom());
          double DR3 = max(R31,R32);
          //cout << DR3 << " " << R31 << " " << R32 << endl;
          //cout << endl;
          if( DR1 < DR2 && DR1 < DR3 ){
            m1 = (>mom()>mom()).m();
            m2 = (>mom()>mom()).m();
            Rbbmax = DR1;
          else if( DR2 < DR1 && DR2 < DR3 ){
            m1 = (>mom()>mom()).m();
            m2 = (>mom()>mom()).m();
            Rbbmax = DR2;
            m1 = (>mom()>mom()).m();
            m2 = (>mom()>mom()).m();
            Rbbmax = DR3;
          // Order the two candidate masses: mlead is the heavier, msubl the lighter.
          mlead = max(m1,m2); msubl = min(m1,m2);
          //cout << mlead << " " << msubl << endl;

        // Transverse mass for leading b-jets
        // Minimum mT between the missing momentum and up to the first three b-jets.
        // The start value 10E6 (= 1e7) is left unchanged if there are no b-jets.
        // NOTE(review): the jet operand is missing from the mTrans call below.
        double mTmin = 10E6;
        for(int i = 0; i < min(3,(int)nbJets); i++){
          double mT = mTrans(metVec,>mom());
          if(mT < mTmin) mTmin = mT;
        //cout << "mTmin " << mTmin << endl;

        // Increment cutFlowVector elements
        // Cut flow strings
        // Apply cutflow
        // (the cut-flow filling below is commented out, so cutFlowVector is not incremented here)
//        for(size_t j=0;j<NCUTS;j++){
//          if(
//             (j==0) ||
//             (j==1 && met > 200.) ||
//             (j==2 && met > 200 && phi4min > 0.4) ||
//             (j==3 && met > 200 && phi4min > 0.4 && nLeptons == 0) ||
//             (j==4 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5)) ||
//             (j==5 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150.) ||
//             (j==6 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140.) ||
//             (j==7 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && mTmin > 130.) ||
//             (j==8 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5)  && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && mTmin > 130. && meff > 1100.) ||
//             (j==9 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5)  && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && mTmin > 130. && meff > 1100. && nbJets >= 3) ||
//             (j==10 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5)  && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && mTmin > 130. && meff > 1100. && nbJets >= 3 && Rbbmax > 0.4 && Rbbmax < 1.4) ||
//             (j==11 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && meff > 600.) ||
//             (j==12 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && meff > 600. && nbJets >= 4) ||
//             (j==13 && met > 200 && phi4min > 0.4 && nLeptons == 0 && (nJets == 4 || nJets == 5) && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && meff > 600. && nbJets >= 4 && Rbbmax > 0.4 && Rbbmax < 1.4)
//             ) cutFlowVector[j]++;
//        }

        // Now increment signal region variables
        // All regions share: met > 200, no veto leptons, phi4min > 0.4, 110 < mlead < 150, 90 < msubl < 140.
        // They differ in b-jet multiplicity, jet multiplicity, mTmin threshold, Rbbmax window and meff bin.
        // NOTE(review): the start of each counter lookup (the text before the region-name string)
        // is missing in the statements below — presumably an access into _counters; confirm.
        // First exclusion regions
        // SR-3b: exactly 3 b-jets with mTmin > 150 for meff1/meff2; >= 3 b-jets with mTmin > 130 for meff3.
        if(nbJets == 3 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && mTmin > 150. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff > 600. && meff < 850.)"SR-3b-meff1-A").add_event(event);
        if(nbJets == 3 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && mTmin > 150. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff > 850. && meff < 1100.)"SR-3b-meff2-A").add_event(event);
        if(nbJets >= 3 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && mTmin > 130. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff > 1100.)"SR-3b-meff3-A").add_event(event);
        // SR-4b: >= 4 b-jets, no mTmin cut; "-A" uses 0.4 < Rbbmax < 1.4, "-B" uses 1.4 < Rbbmax < 2.4.
        // The meff2 regions accept up to 6 jets, the meff1 regions up to 5.
        if(nbJets >= 4 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && meff > 600. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff < 850.)"SR-4b-meff1-A").add_event(event);
        if(nbJets >= 4 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && meff > 600. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 1.4 && Rbbmax < 2.4 && meff < 850.)"SR-4b-meff1-B").add_event(event);
        if(nbJets >= 4 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 6 && meff > 850. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff < 1100.)"SR-4b-meff2-A").add_event(event);
        if(nbJets >= 4 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 6 && meff > 850. && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 1.4 && Rbbmax < 2.4 && meff < 1100.)"SR-4b-meff2-B").add_event(event);
        // Discovery regions
        // Same as SR-4b-meff1-A but without the upper meff bound (meff > 600 only).
        if(nbJets >= 4 && met > 200 && nLeptons == 0 && phi4min > 0.4 && nJets >= 4 && nJets <= 5 && mlead > 110. && mlead < 150. && msubl > 90. && msubl < 140. && Rbbmax > 0.4 && Rbbmax < 1.4 && meff > 600.)"SR-4b-meff1-A-disc").add_event(event);


      } // End of run

      /// Combine the variables of another copy of this analysis (typically on another thread) into this one.
      /// Merges the signal-region counters and cut-flow bookkeeping of another instance
      /// of this analysis into this one.
      /// @param other  The analysis to merge in; it is down-cast to this class.
      /// NOTE(review): this function is truncated in this view — the opening brace, the
      /// right-hand side of the counter sum and the closing braces are not visible.
      void combine(const Analysis* other)
        // NOTE(review): the result of the dynamic_cast is not checked for nullptr in the
        // visible lines before it is dereferenced.
        const Analysis_ATLAS_13TeV_3b_36invfb* specificOther
          = dynamic_cast<const Analysis_ATLAS_13TeV_3b_36invfb*>(other);

        // Add the other instance's counters region by region.
        // NOTE(review): the member accessed on specificOther is missing from this statement.
        for (auto& pair : _counters) { pair.second += specificOther->; }

        // Adopt the other instance's cut count if it differs, then sum the cut flows
        // and copy over the cut labels.
        // NOTE(review): no resize of cutFlowVector / cutFlowVector_str is visible after
        // NCUTS changes — confirm the vectors are long enough for the indexing below.
        if (NCUTS != specificOther->NCUTS) NCUTS = specificOther->NCUTS;
        for (size_t j=0; j<NCUTS; j++) {
          cutFlowVector[j] += specificOther->cutFlowVector[j];
          cutFlowVector_str[j] = specificOther->cutFlowVector_str[j];


      /// Registers one result entry per exclusion signal region via add_result().
      /// NOTE(review): the parentheses in the add_result lines are unbalanced in this view,
      /// so part of each SignalRegionData argument list appears to be missing. The numeric
      /// arguments are presumably the observed event count followed by
      /// {expected background, background uncertainty} — confirm against SignalRegionData.
      /// The closing brace of this function is also not visible.
      virtual void collect_results() {

        // Now fill a results object with the results for each SR
        // Only exclusion regions here
        add_result(SignalRegionData("SR-3b-meff1-A"), 4., {2.5, 1.0}));
        add_result(SignalRegionData("SR-3b-meff2-A"), 3., {2.0, 0.5}));
        add_result(SignalRegionData("SR-3b-meff3-A"), 0., {0.8, 0.5}));
        add_result(SignalRegionData("SR-4b-meff1-A"), 1., {0.43, 0.31}));
        add_result(SignalRegionData("SR-4b-meff1-B"), 2., {2.6, 0.9}));
        add_result(SignalRegionData("SR-4b-meff2-A"), 1., {0.43, 0.27}));
        add_result(SignalRegionData("SR-4b-meff2-B"), 0., {1.3, 0.6}));


      /// Resets the per-analysis state: every signal-region counter is reset and every
      /// cut-flow count is set to zero. The cut labels and NCUTS are left untouched.
      /// NOTE(review): the closing brace of this function is not visible in this view.
      void analysis_specific_reset() {
        // Clear signal regions
        for (auto& pair : _counters) { pair.second.reset(); }

        // Clear cut flow vector
        // (zeroes the existing entries in place; the vector keeps its current length)
        std::fill(cutFlowVector.begin(), cutFlowVector.end(), 0);



    // Class for collecting results for discovery regions as a derived class

    /// Derived analysis that inherits from the exclusion analysis and overrides
    /// collect_results() to register the discovery region "SR-4b-meff1-A-disc".
    /// NOTE(review): this class is truncated in this view — access specifiers, the
    /// constructor body and all closing braces are not visible, and the parentheses in the
    /// add_result line are unbalanced. Confirm against the full file.
    class Analysis_ATLAS_13TeV_3b_discoverySR_36invfb : public Analysis_ATLAS_13TeV_3b_36invfb {

      // Constructor (body not visible here).
      Analysis_ATLAS_13TeV_3b_discoverySR_36invfb() {

      // Registers the discovery signal region only.
      virtual void collect_results() {

        add_result(SignalRegionData("SR-4b-meff1-A-disc"), 2., {0.7, 0.5}));


    // Factory fn


