# coding: utf-8

def get_rsa_public_key():
    '''
        @name Get the RSA public key cached in the session
        @author hwliang
        @return str  the cached PEM text; when nothing usable is cached, a key
                     pair is generated first and the re-read value is returned
                     (None if generation did not store one)
    '''
    from BTPanel import session
    session_key = 'rsa_public_key'
    cached = session.get(session_key)
    if cached:
        return cached
    # Nothing usable cached yet: generate a pair, then read it back
    create_rsa_key()
    return session.get(session_key)


def get_rsa_private_key():
    '''
        @name Get the RSA private key cached in the session
        @author hwliang
        @return str  the cached PEM text; when nothing usable is cached, a key
                     pair is generated first and the re-read value is returned
                     (None if generation did not store one)
    '''
    from BTPanel import session
    session_key = 'rsa_private_key'
    cached = session.get(session_key)
    if cached:
        return cached
    # Nothing usable cached yet: generate a pair, then read it back
    create_rsa_key()
    return session.get(session_key)


def create_rsa_key():
    '''
        @name Create an RSA key pair and cache it in the session
        @author hwliang
        @return bool True if both keys are in the session afterwards, False on failure

        The public key is stored under 'rsa_public_key' as PEM text with all
        newlines removed; the private key is stored under 'rsa_private_key' as
        PEM text. A pair that is already in the session is reused unchanged.
    '''
    import os
    import shutil
    import tempfile

    try:
        from BTPanel import session
        pub_key = 'rsa_public_key'
        prv_key = 'rsa_private_key'
        if pub_key in session and prv_key in session:
            return True
        try:
            from Crypto.PublicKey import RSA
            key = RSA.generate(1024)
            private_key = key.exportKey("PEM")
            public_key = key.publickey().exportKey("PEM")
        except Exception:
            # pycryptodome is unavailable or broken: schedule a one-time
            # background reinstall (the marker file prevents retrying on every
            # call) and fall back to the openssl command-line tool.
            is_re_install = '{}/data/pycryptodome_re_install.pl'.format(get_panel_path())
            if not os.path.exists(is_re_install):
                # os.system() runs /bin/sh, where "&>" is not a redirection
                # (POSIX sh parses it as "&" followed by ">"); use the portable
                # form so the installer's output really goes to /dev/null.
                os.system("nohup btpip install pycryptodome -I > /dev/null 2>&1 &")
                writeFile(is_re_install, 'True')

            # Generate into a fresh directory that only the current user can
            # access (tempfile.mkdtemp) instead of fixed /tmp paths, so other
            # local users cannot read or pre-plant the private key and
            # concurrent calls cannot overwrite each other's files.
            tmp_dir = tempfile.mkdtemp(prefix='bt_rsa_')
            try:
                priv_pem = os.path.join(tmp_dir, 'private.pem')
                pub_pem = os.path.join(tmp_dir, 'public.pem')
                ExecShell("openssl genrsa -out {} 1024".format(priv_pem))
                ExecShell("openssl rsa -pubout -in {} -out {}".format(priv_pem, pub_pem))
                if not os.path.exists(priv_pem) or not os.path.exists(pub_pem):
                    return False

                private_key = readFile(priv_pem, 'rb')
                public_key = readFile(pub_pem, 'rb')
            finally:
                # Never leave key material on disk, whatever happened above.
                shutil.rmtree(tmp_dir, ignore_errors=True)

        # The public key is cached with its newlines stripped; the private key
        # is cached verbatim.
        session[pub_key] = public_key.decode('utf-8').replace("\n", "")
        session[prv_key] = private_key.decode('utf-8')

        return True
    except Exception:
        # Best effort: any failure (no session, no crypto backend, openssl
        # error, unreadable key file) is reported as False.
        return False


def rsa_encrypt(data):
    '''
        @name RSA-encrypt data with the session public key
        @param data str Text to encrypt
        @return str Base64-encoded ciphertext blocks joined with newline
                    characters; '' on any failure
    '''
    import base64
    import re

    try:
        from Crypto.PublicKey import RSA
        from Crypto.Cipher import PKCS1_v1_5 as Cipher_pkcs

        public_key = get_rsa_public_key()
        if isinstance(public_key, bytes):
            public_key = public_key.decode('utf-8')
        # create_rsa_key stores the public key with every newline removed, but
        # PEM parsers such as pycryptodome's require a line break after the
        # BEGIN boundary. Re-insert one after the BEGIN line and before the END
        # line; a key that already has them comes out unchanged.
        public_key = re.sub(r'(-----BEGIN [^-]+-----)\s*', r'\1\n', public_key)
        public_key = re.sub(r'\s*(-----END [^-]+-----)', r'\n\1', public_key)

        rsa_key = RSA.importKey(public_key)
        cipher_public = Cipher_pkcs.new(rsa_key)

        # PKCS#1 v1.5 encrypts at most k - 11 bytes per block, where k is the
        # modulus length in bytes (117 for a 1024-bit key). Derive it from the
        # key rather than hard-coding 117 so other key sizes work too.
        split_length = (rsa_key.n.bit_length() + 7) // 8 - 11

        # Encrypt the UTF-8 bytes block by block
        data = data.encode('utf-8')
        encrypted_arr = []
        for i in range(0, len(data), split_length):
            encrypted_data = cipher_public.encrypt(data[i:i + split_length])
            encrypted_arr.append(base64.b64encode(encrypted_data).decode())

        # One Base64 block per line
        return "\n".join(encrypted_arr)
    except Exception:
        return ''


def rsa_decrypt(data):
    '''
        @name RSA-decrypt data with the session private key
        @param data str Newline-separated Base64 ciphertext blocks (the format
                        rsa_encrypt produces)
        @return str Decrypted UTF-8 text; '' if any block fails to decode or
                    decrypt
    '''
    import base64

    try:
        from Crypto.PublicKey import RSA
        from Crypto.Cipher import PKCS1_v1_5 as Cipher_pkcs

        # Initialise the RSA decryption object
        cipher_private = Cipher_pkcs.new(RSA.importKey(get_rsa_private_key()))

        # Decrypt block by block; decode only after all bytes are joined, since a
        # multi-byte UTF-8 character may straddle two blocks.
        decrypted_parts = []
        for line in data.split("\n"):
            if not line:
                continue
            ciphertext = base64.b64decode(line)
            if not ciphertext:
                continue
            # decrypt() returns the sentinel (None here) instead of raising
            # when the padding check fails (wrong key / corrupted block).
            plaintext = cipher_private.decrypt(ciphertext, None)
            if plaintext is None:
                return ''
            decrypted_parts.append(plaintext)
        return b"".join(decrypted_parts).decode('utf-8')
    except Exception:
        return ''


def rsa_encrypt_for_private_key(data):
    '''
        @name RSA-encrypt data using the session private key object
        @author hwliang
        @param data str Text to encrypt
        @return str Base64-encoded ciphertext blocks joined with newline
                    characters; '' on any failure

        NOTE(review): PKCS1_v1_5 encryption appears to use the key's public
        exponent even when given a private key, so the output is presumably the
        same kind of ciphertext rsa_encrypt produces (recoverable with the
        private key, not with the public key). Confirm this is what callers
        expect.
    '''
    import base64

    try:
        from Crypto.PublicKey import RSA
        from Crypto.Cipher import PKCS1_v1_5 as Cipher_pkcs

        # Initialise the RSA cipher from the session private key
        rsa_key = RSA.importKey(get_rsa_private_key())
        cipher_private = Cipher_pkcs.new(rsa_key)

        # PKCS#1 v1.5 encrypts at most k - 11 bytes per block, where k is the
        # modulus length in bytes (117 for a 1024-bit key). Derive it from the
        # key rather than hard-coding 117 so other key sizes work too.
        split_length = (rsa_key.n.bit_length() + 7) // 8 - 11

        # Encrypt the UTF-8 bytes block by block
        data = data.encode('utf-8')
        encrypted_arr = []
        for i in range(0, len(data), split_length):
            encrypted_data = cipher_private.encrypt(data[i:i + split_length])
            encrypted_arr.append(base64.b64encode(encrypted_data).decode())

        # One Base64 block per line
        return "\n".join(encrypted_arr)
    except Exception:
        return ''