refactor: untangle lnd's macaroon encryption with AESCipher class (#3152)
This commit is contained in:
parent
7bea591879
commit
3b350858c7
7 changed files with 207 additions and 124 deletions
|
|
@ -51,7 +51,7 @@ You can also use an AES-encrypted macaroon (more info) instead by using
|
|||
|
||||
- `LND_GRPC_MACAROON_ENCRYPTED`: eNcRyPtEdMaCaRoOn
|
||||
|
||||
To encrypt your macaroon, run `poetry run python lnbits/wallets/macaroon/macaroon.py`.
|
||||
To encrypt your macaroon, run `poetry run lnbits-cli encrypt macaroon`.
|
||||
|
||||
### LNbits
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import importlib
|
|||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
from getpass import getpass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
|
@ -40,7 +41,9 @@ from lnbits.core.views.extension_api import (
|
|||
api_uninstall_extension,
|
||||
)
|
||||
from lnbits.settings import settings
|
||||
from lnbits.utils.crypto import AESCipher
|
||||
from lnbits.wallets.base import Wallet
|
||||
from lnbits.wallets.macaroon import load_macaroon
|
||||
|
||||
|
||||
def coro(f):
|
||||
|
|
@ -79,6 +82,20 @@ def extensions():
|
|||
"""
|
||||
|
||||
|
||||
@lnbits_cli.group()
|
||||
def encrypt():
|
||||
"""
|
||||
Encryption commands
|
||||
"""
|
||||
|
||||
|
||||
@lnbits_cli.group()
|
||||
def decrypt():
|
||||
"""
|
||||
Decryption commands
|
||||
"""
|
||||
|
||||
|
||||
def get_super_user() -> Optional[str]:
|
||||
"""Get the superuser"""
|
||||
superuser_file = Path(settings.lnbits_data_folder, ".super_user")
|
||||
|
|
@ -479,6 +496,56 @@ async def extensions_uninstall(
|
|||
return False, str(ex)
|
||||
|
||||
|
||||
@encrypt.command("macaroon")
|
||||
def encrypt_macaroon():
|
||||
"""Encrypts a macaroon (LND wallets)"""
|
||||
_macaroon = getpass("Enter macaroon: ")
|
||||
try:
|
||||
macaroon = load_macaroon(_macaroon)
|
||||
except Exception as ex:
|
||||
click.echo(f"Error loading macaroon: {ex}")
|
||||
return
|
||||
key = getpass("Enter encryption key: ")
|
||||
aes = AESCipher(key.encode())
|
||||
try:
|
||||
encrypted_macaroon = aes.encrypt(bytes.fromhex(macaroon))
|
||||
except Exception as ex:
|
||||
click.echo(f"Error encrypting macaroon: {ex}")
|
||||
return
|
||||
click.echo("Encrypted macaroon: ")
|
||||
click.echo(encrypted_macaroon)
|
||||
|
||||
|
||||
@encrypt.command("aes")
|
||||
@click.option("-p", "--payload", required=True, help="Payload to encrypt.")
|
||||
def encrypt_aes(payload: str):
|
||||
"""AES encrypts a payload"""
|
||||
key = getpass("Enter encryption key: ")
|
||||
aes = AESCipher(key.encode())
|
||||
try:
|
||||
encrypted = aes.encrypt(payload.encode())
|
||||
except Exception as ex:
|
||||
click.echo(f"Error encrypting payload: {ex}")
|
||||
return
|
||||
click.echo("Encrypted payload: ")
|
||||
click.echo(encrypted)
|
||||
|
||||
|
||||
@decrypt.command("aes")
|
||||
@click.option("-p", "--payload", required=True, help="Payload to decrypt.")
|
||||
def decrypt_aes(payload: str):
|
||||
"""AES decrypts a payload"""
|
||||
key = getpass("Enter encryption key: ")
|
||||
aes = AESCipher(key.encode())
|
||||
try:
|
||||
decrypted = aes.decrypt(payload)
|
||||
except Exception as ex:
|
||||
click.echo(f"Error decrypting payload: {ex}")
|
||||
return
|
||||
click.echo("Decrypted payload: ")
|
||||
click.echo(decrypted)
|
||||
|
||||
|
||||
def main():
|
||||
"""main function"""
|
||||
lnbits_cli()
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
import base64
|
||||
import getpass
|
||||
from base64 import b64decode, b64encode, urlsafe_b64decode, urlsafe_b64encode
|
||||
from hashlib import md5, pbkdf2_hmac, sha256
|
||||
from typing import Union
|
||||
|
||||
from Cryptodome import Random
|
||||
from Cryptodome.Cipher import AES
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def random_secret_and_hash() -> tuple[str, str]:
|
||||
secret = Random.new().read(32)
|
||||
def random_secret_and_hash(length: int = 32) -> tuple[str, str]:
|
||||
secret = Random.new().read(length)
|
||||
return secret.hex(), sha256(secret).hexdigest()
|
||||
|
||||
|
||||
|
|
@ -30,73 +28,84 @@ def verify_preimage(preimage: str, payment_hash: str) -> bool:
|
|||
|
||||
|
||||
class AESCipher:
|
||||
"""This class is compatible with crypto-js/aes.js
|
||||
"""
|
||||
AES-256-CBC encryption/decryption with salt and base64 encoding.
|
||||
:param key: The key to use for en-/decryption. It can be bytes, a hex or a string.
|
||||
|
||||
This class is compatible with crypto-js/aes.js
|
||||
Encrypt and decrypt in Javascript using:
|
||||
import AES from "crypto-js/aes.js";
|
||||
import Utf8 from "crypto-js/enc-utf8.js";
|
||||
AES.encrypt(decrypted, password).toString()
|
||||
AES.decrypt(encrypted, password).toString(Utf8);
|
||||
|
||||
import AES from "crypto-js/aes.js";
|
||||
import Utf8 from "crypto-js/enc-utf8.js";
|
||||
AES.encrypt(decrypted, password).toString()
|
||||
AES.decrypt(encrypted, password).toString(Utf8);
|
||||
"""
|
||||
|
||||
def __init__(self, key=None, description=""):
|
||||
self.key = key
|
||||
self.description = description + " "
|
||||
def __init__(self, key: Union[bytes, str], block_size: int = 16):
|
||||
self.block_size = block_size
|
||||
if isinstance(key, bytes):
|
||||
self.key = key
|
||||
return
|
||||
try:
|
||||
self.key = bytes.fromhex(key)
|
||||
except ValueError:
|
||||
pass
|
||||
self.key = key.encode()
|
||||
|
||||
def pad(self, data):
|
||||
length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
|
||||
def pad(self, data: bytes) -> bytes:
|
||||
length = self.block_size - (len(data) % self.block_size)
|
||||
return data + (chr(length) * length).encode()
|
||||
|
||||
def unpad(self, data):
|
||||
return data[: -(data[-1] if isinstance(data[-1], int) else ord(data[-1]))]
|
||||
def unpad(self, data: bytes) -> bytes:
|
||||
_last = data[-1]
|
||||
if isinstance(_last, int):
|
||||
return data[:-_last]
|
||||
return data[: -ord(_last)]
|
||||
|
||||
@property
|
||||
def passphrase(self):
|
||||
passphrase = self.key if self.key is not None else None
|
||||
if passphrase is None:
|
||||
passphrase = getpass.getpass(f"Enter {self.description}password:")
|
||||
return passphrase
|
||||
|
||||
def bytes_to_key(self, data, salt, output=48):
|
||||
def derive_iv_and_key(
|
||||
self, salt: bytes, output_len: int = 32 + 16
|
||||
) -> tuple[bytes, bytes]:
|
||||
# extended from https://gist.github.com/gsakkis/4546068
|
||||
assert len(salt) == 8, len(salt)
|
||||
data += salt
|
||||
assert len(salt) == 8, "Salt must be 8 bytes"
|
||||
data = self.key + salt
|
||||
key = md5(data).digest()
|
||||
final_key = key
|
||||
while len(final_key) < output:
|
||||
while len(final_key) < output_len:
|
||||
key = md5(key + data).digest()
|
||||
final_key += key
|
||||
return final_key[:output]
|
||||
iv_key = final_key[:output_len]
|
||||
return iv_key[32:], iv_key[:32]
|
||||
|
||||
def decrypt(self, encrypted: str, urlsafe: bool = False) -> str:
|
||||
"""Decrypts a string using AES-256-CBC."""
|
||||
passphrase = self.passphrase
|
||||
|
||||
"""Decrypts a salted base64 encoded string using AES-256-CBC."""
|
||||
if urlsafe:
|
||||
encrypted_bytes = base64.urlsafe_b64decode(encrypted)
|
||||
decoded = urlsafe_b64decode(encrypted)
|
||||
else:
|
||||
encrypted_bytes = base64.b64decode(encrypted)
|
||||
decoded = b64decode(encrypted)
|
||||
|
||||
assert encrypted_bytes[0:8] == b"Salted__"
|
||||
salt = encrypted_bytes[8:16]
|
||||
key_iv = self.bytes_to_key(passphrase.encode(), salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
if decoded[0:8] != b"Salted__":
|
||||
raise ValueError("Invalid salt.")
|
||||
|
||||
salt = decoded[8:16]
|
||||
encrypted_bytes = decoded[16:]
|
||||
|
||||
iv, key = self.derive_iv_and_key(salt, 32 + 16)
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
|
||||
try:
|
||||
return self.unpad(aes.decrypt(encrypted_bytes[16:])).decode()
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("Wrong passphrase") from exc
|
||||
decrypted_bytes = aes.decrypt(encrypted_bytes)
|
||||
return self.unpad(decrypted_bytes).decode()
|
||||
except Exception as exc:
|
||||
raise ValueError("Decryption error") from exc
|
||||
|
||||
def encrypt(self, message: bytes, urlsafe: bool = False) -> str:
|
||||
passphrase = self.passphrase
|
||||
"""
|
||||
Encrypts a string using AES-256-CBC and returns a salted base64 encoded string.
|
||||
"""
|
||||
salt = Random.new().read(8)
|
||||
key_iv = self.bytes_to_key(passphrase.encode(), salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
iv, key = self.derive_iv_and_key(salt, 32 + 16)
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
encoded = b"Salted__" + salt + aes.encrypt(self.pad(message))
|
||||
return (
|
||||
base64.urlsafe_b64encode(encoded) if urlsafe else base64.b64encode(encoded)
|
||||
).decode()
|
||||
msg = self.pad(message)
|
||||
encrypted = aes.encrypt(msg)
|
||||
salted = b"Salted__" + salt + encrypted
|
||||
encoded = urlsafe_b64encode(salted) if urlsafe else b64encode(salted)
|
||||
return encoded.decode()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import lnbits.wallets.lnd_grpc_files.lightning_pb2_grpc as lnrpc
|
|||
import lnbits.wallets.lnd_grpc_files.router_pb2 as router
|
||||
import lnbits.wallets.lnd_grpc_files.router_pb2_grpc as routerrpc
|
||||
from lnbits.settings import settings
|
||||
from lnbits.utils.crypto import AESCipher, random_secret_and_hash
|
||||
from lnbits.utils.crypto import random_secret_and_hash
|
||||
|
||||
from .base import (
|
||||
InvoiceResponse,
|
||||
|
|
@ -72,6 +72,11 @@ class LndWallet(Wallet):
|
|||
"cannot initialize LndWallet: missing lnd_grpc_cert or lnd_cert"
|
||||
)
|
||||
|
||||
self.endpoint = self.normalize_endpoint(
|
||||
settings.lnd_grpc_endpoint, add_proto=False
|
||||
)
|
||||
self.port = int(settings.lnd_grpc_port)
|
||||
|
||||
macaroon = (
|
||||
settings.lnd_grpc_macaroon
|
||||
or settings.lnd_grpc_admin_macaroon
|
||||
|
|
@ -80,23 +85,11 @@ class LndWallet(Wallet):
|
|||
or settings.lnd_invoice_macaroon
|
||||
)
|
||||
encrypted_macaroon = settings.lnd_grpc_macaroon_encrypted
|
||||
if encrypted_macaroon:
|
||||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||
encrypted_macaroon
|
||||
)
|
||||
if not macaroon:
|
||||
raise ValueError(
|
||||
"cannot initialize LndWallet: "
|
||||
"missing lnd_grpc_macaroon or lnd_grpc_admin_macaroon or "
|
||||
"lnd_admin_macaroon or lnd_grpc_invoice_macaroon or "
|
||||
"lnd_invoice_macaroon or lnd_grpc_macaroon_encrypted"
|
||||
)
|
||||
try:
|
||||
self.macaroon = load_macaroon(macaroon, encrypted_macaroon)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"cannot load macaroon for LndWallet: {exc!s}") from exc
|
||||
|
||||
self.endpoint = self.normalize_endpoint(
|
||||
settings.lnd_grpc_endpoint, add_proto=False
|
||||
)
|
||||
self.port = int(settings.lnd_grpc_port)
|
||||
self.macaroon = load_macaroon(macaroon)
|
||||
cert = open(cert_path, "rb").read()
|
||||
creds = grpc.ssl_channel_credentials(cert)
|
||||
auth_creds = grpc.metadata_call_credentials(self.metadata_callback)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from loguru import logger
|
|||
|
||||
from lnbits.nodes.lndrest import LndRestNode
|
||||
from lnbits.settings import settings
|
||||
from lnbits.utils.crypto import AESCipher, random_secret_and_hash
|
||||
from lnbits.utils.crypto import random_secret_and_hash
|
||||
|
||||
from .base import (
|
||||
InvoiceResponse,
|
||||
|
|
@ -35,26 +35,6 @@ class LndRestWallet(Wallet):
|
|||
"cannot initialize LndRestWallet: missing lnd_rest_endpoint"
|
||||
)
|
||||
|
||||
macaroon = (
|
||||
settings.lnd_rest_macaroon
|
||||
or settings.lnd_admin_macaroon
|
||||
or settings.lnd_rest_admin_macaroon
|
||||
or settings.lnd_invoice_macaroon
|
||||
or settings.lnd_rest_invoice_macaroon
|
||||
)
|
||||
encrypted_macaroon = settings.lnd_rest_macaroon_encrypted
|
||||
if encrypted_macaroon:
|
||||
macaroon = AESCipher(description="macaroon decryption").decrypt(
|
||||
encrypted_macaroon
|
||||
)
|
||||
if not macaroon:
|
||||
raise ValueError(
|
||||
"cannot initialize LndRestWallet: "
|
||||
"missing lnd_rest_macaroon or lnd_admin_macaroon or "
|
||||
"lnd_rest_admin_macaroon or lnd_invoice_macaroon or "
|
||||
"lnd_rest_invoice_macaroon or lnd_rest_macaroon_encrypted"
|
||||
)
|
||||
|
||||
if not settings.lnd_rest_cert:
|
||||
logger.warning(
|
||||
"No certificate for LndRestWallet provided! "
|
||||
|
|
@ -68,7 +48,21 @@ class LndRestWallet(Wallet):
|
|||
# even on startup
|
||||
cert = settings.lnd_rest_cert or True
|
||||
|
||||
macaroon = load_macaroon(macaroon)
|
||||
macaroon = (
|
||||
settings.lnd_rest_macaroon
|
||||
or settings.lnd_admin_macaroon
|
||||
or settings.lnd_rest_admin_macaroon
|
||||
or settings.lnd_invoice_macaroon
|
||||
or settings.lnd_rest_invoice_macaroon
|
||||
)
|
||||
encrypted_macaroon = settings.lnd_rest_macaroon_encrypted
|
||||
try:
|
||||
macaroon = load_macaroon(macaroon, encrypted_macaroon)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"cannot load macaroon for LndRestWallet: {exc!s}"
|
||||
) from exc
|
||||
|
||||
headers = {
|
||||
"Grpc-Metadata-macaroon": macaroon,
|
||||
"User-Agent": settings.user_agent,
|
||||
|
|
|
|||
|
|
@ -1,45 +1,45 @@
|
|||
import base64
|
||||
|
||||
from loguru import logger
|
||||
from getpass import getpass
|
||||
from typing import Optional
|
||||
|
||||
from lnbits.utils.crypto import AESCipher
|
||||
|
||||
|
||||
def load_macaroon(macaroon: str) -> str:
|
||||
"""Returns hex version of a macaroon encoded in base64 or the file path.
|
||||
def load_macaroon(
|
||||
macaroon: Optional[str] = None,
|
||||
encrypted_macaroon: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Returns hex version of a macaroon encoded in base64 or the file path."""
|
||||
|
||||
:param macaroon: Macaroon encoded in base64 or file path.
|
||||
:type macaroon: str
|
||||
:return: Hex version of macaroon.
|
||||
:rtype: str
|
||||
"""
|
||||
if macaroon is None and encrypted_macaroon is None:
|
||||
raise ValueError("Either macaroon or encrypted_macaroon must be provided.")
|
||||
|
||||
if encrypted_macaroon:
|
||||
# if the macaroon is encrypted, decrypt it and return the hex version
|
||||
key = getpass("Enter the macaroon decryption key: ")
|
||||
aes = AESCipher(key.encode())
|
||||
return aes.decrypt(encrypted_macaroon)
|
||||
|
||||
assert macaroon, "macaroon must be set here"
|
||||
|
||||
# if the macaroon is a file path, load it and return hex version
|
||||
if macaroon.split(".")[-1] == "macaroon":
|
||||
with open(macaroon, "rb") as f:
|
||||
macaroon_bytes = f.read()
|
||||
return macaroon_bytes.hex()
|
||||
else:
|
||||
# if macaroon is a provided string
|
||||
# check if it is hex, if so, return
|
||||
try:
|
||||
bytes.fromhex(macaroon)
|
||||
return macaroon
|
||||
except ValueError:
|
||||
pass
|
||||
# convert the bas64 macaroon to hex
|
||||
try:
|
||||
macaroon = base64.b64decode(macaroon).hex()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if macaroon is a provided string check if it is hex, if so, return
|
||||
try:
|
||||
bytes.fromhex(macaroon)
|
||||
return macaroon
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# convert the base64 macaroon to hex
|
||||
try:
|
||||
macaroon = base64.b64decode(macaroon).hex()
|
||||
return macaroon
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return macaroon
|
||||
|
||||
|
||||
# todo: move to its own (crypto.py) file
|
||||
# if this file is executed directly, ask for a macaroon and encrypt it
|
||||
if __name__ == "__main__":
|
||||
macaroon = input("Enter macaroon: ")
|
||||
macaroon = load_macaroon(macaroon)
|
||||
macaroon = AESCipher(description="encryption").encrypt(macaroon.encode())
|
||||
logger.info("Encrypted macaroon:")
|
||||
logger.info(macaroon)
|
||||
|
|
|
|||
20
tests/unit/test_crypto_aes.py
Normal file
20
tests/unit/test_crypto_aes.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
|
||||
from lnbits.utils.crypto import AESCipher
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
"key",
|
||||
[
|
||||
"normal_string",
|
||||
b"normal_bytes",
|
||||
b"hex_string".hex(),
|
||||
],
|
||||
)
|
||||
async def test_aes_encrypt_decrypt(key):
|
||||
aes = AESCipher(key)
|
||||
original_text = "Hello, World!"
|
||||
encrypted_text = aes.encrypt(original_text.encode())
|
||||
decrypted_text = aes.decrypt(encrypted_text)
|
||||
assert original_text == decrypted_text
|
||||
Loading…
Add table
Add a link
Reference in a new issue