test: unit tests for lndrpc (#2442)

This commit is contained in:
Vlad Stan 2024-04-19 14:21:21 +03:00 committed by GitHub
parent 4f118c5f98
commit 67fdb77339
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1082 additions and 112 deletions

View file

@ -116,7 +116,7 @@ class LndWallet(Wallet):
try: try:
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest()) resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
except Exception as exc: except Exception as exc:
return StatusResponse(str(exc), 0) return StatusResponse(f"Unable to connect, got: '{exc}'", 0)
return StatusResponse(None, resp.balance * 1000) return StatusResponse(None, resp.balance * 1000)
@ -147,6 +147,7 @@ class LndWallet(Wallet):
req = ln.Invoice(**data) req = ln.Invoice(**data)
resp = await self.rpc.AddInvoice(req) resp = await self.rpc.AddInvoice(req)
except Exception as exc: except Exception as exc:
logger.warning(exc)
error_message = str(exc) error_message = str(exc)
return InvoiceResponse(False, None, None, error_message) return InvoiceResponse(False, None, None, error_message)
@ -165,6 +166,7 @@ class LndWallet(Wallet):
try: try:
resp = await self.routerpc.SendPaymentV2(req).read() resp = await self.routerpc.SendPaymentV2(req).read()
except Exception as exc: except Exception as exc:
logger.warning(exc)
return PaymentResponse(False, None, None, None, str(exc)) return PaymentResponse(False, None, None, None, str(exc))
# PaymentStatus from https://github.com/lightningnetwork/lnd/blob/master/channeldb/payments.go#L178 # PaymentStatus from https://github.com/lightningnetwork/lnd/blob/master/channeldb/payments.go#L178
@ -176,12 +178,12 @@ class LndWallet(Wallet):
} }
failure_reasons = { failure_reasons = {
0: "No error given.", 0: "Payment failed: No error given.",
1: "Payment timed out.", 1: "Payment failed: Payment timed out.",
2: "No route to destination.", 2: "Payment failed: No route to destination.",
3: "Error.", 3: "Payment failed: Error.",
4: "Incorrect payment details.", 4: "Payment failed: Incorrect payment details.",
5: "Insufficient balance.", 5: "Payment failed: Insufficient balance.",
} }
fee_msat = None fee_msat = None
@ -204,19 +206,23 @@ class LndWallet(Wallet):
try: try:
r_hash = hex_to_bytes(checking_id) r_hash = hex_to_bytes(checking_id)
if len(r_hash) != 32: if len(r_hash) != 32:
raise ValueError
except ValueError:
# this may happen if we switch between backend wallets # this may happen if we switch between backend wallets
# that use different checking_id formats # that use different checking_id formats
return PaymentPendingStatus() raise ValueError
try:
resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash)) resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash))
except grpc.RpcError:
return PaymentPendingStatus() # todo: where is the FAILED status
if resp.settled: if resp.settled:
return PaymentSuccessStatus() return PaymentSuccessStatus()
return PaymentPendingStatus() return PaymentPendingStatus()
except grpc.RpcError as exc:
logger.warning(exc)
return PaymentPendingStatus()
except Exception as exc:
logger.warning(exc)
return PaymentPendingStatus()
async def get_payment_status(self, checking_id: str) -> PaymentStatus: async def get_payment_status(self, checking_id: str) -> PaymentStatus:
""" """
@ -231,10 +237,6 @@ class LndWallet(Wallet):
# that use different checking_id formats # that use different checking_id formats
return PaymentPendingStatus() return PaymentPendingStatus()
resp = self.routerpc.TrackPaymentV2(
router.TrackPaymentRequest(payment_hash=r_hash)
)
# # HTLCAttempt.HTLCStatus: # # HTLCAttempt.HTLCStatus:
# # https://github.com/lightningnetwork/lnd/blob/master/lnrpc/lightning.proto#L3641 # # https://github.com/lightningnetwork/lnd/blob/master/lnrpc/lightning.proto#L3641
# htlc_statuses = { # htlc_statuses = {
@ -250,6 +252,9 @@ class LndWallet(Wallet):
} }
try: try:
resp = self.routerpc.TrackPaymentV2(
router.TrackPaymentRequest(payment_hash=r_hash)
)
async for payment in resp: async for payment in resp:
if len(payment.htlcs) and statuses[payment.status]: if len(payment.htlcs) and statuses[payment.status]:
return PaymentSuccessStatus( return PaymentSuccessStatus(

File diff suppressed because it is too large Load diff

View file

@ -7,7 +7,6 @@ class FundingSourceConfig(BaseModel):
name: str name: str
skip: Optional[bool] skip: Optional[bool]
wallet_class: str wallet_class: str
client_field: Optional[str]
settings: dict settings: dict
@ -28,12 +27,16 @@ class TestMock(BaseModel):
class Mock(FunctionMock, TestMock): class Mock(FunctionMock, TestMock):
name: str
@staticmethod @staticmethod
def combine_mocks(fs_mock, test_mock): def combine_mocks(mock_name, fs_mock, test_mock):
_mock = fs_mock | test_mock _mock = fs_mock | test_mock
if "response" in _mock and "response" in fs_mock: if "response" in _mock and "response" in fs_mock:
_mock["response"] |= fs_mock["response"] _mock["response"] |= fs_mock["response"]
return Mock(**_mock) m = Mock(name=mock_name, **_mock)
return m
class FunctionMocks(BaseModel): class FunctionMocks(BaseModel):
@ -93,35 +96,58 @@ class WalletTest(BaseModel):
return [t] return [t]
def _tests_from_fs_mocks(self, fn, test, fs_name: str) -> List["WalletTest"]: def _tests_from_fs_mocks(self, fn, test, fs_name: str) -> List["WalletTest"]:
tests: List[WalletTest] = []
fs_mocks = fn["mocks"][fs_name] fs_mocks = fn["mocks"][fs_name]
test_mocks = test["mocks"][fs_name] test_mocks = test["mocks"][fs_name]
for mock_name in fs_mocks: mocks = self._build_mock_objects(list(fs_mocks), fs_mocks, test_mocks)
tests += self._tests_from_mocks(fs_mocks[mock_name], test_mocks[mock_name])
return tests
def _tests_from_mocks(self, fs_mock, test_mocks) -> List["WalletTest"]: return [self._tests_from_mock(m) for m in mocks]
tests: List[WalletTest] = []
for test_mock in test_mocks:
# different mocks that result in the same
# return value for the tested function
unique_test = self._test_from_mocks(fs_mock, test_mock)
tests.append(unique_test) def _build_mock_objects(self, mock_names, fs_mocks, test_mocks):
return tests mocks = []
def _test_from_mocks(self, fs_mock, test_mock) -> "WalletTest": for mock_name in mock_names:
mock = Mock.combine_mocks(fs_mock, test_mock) if mock_name not in test_mocks:
continue
for test_mock in test_mocks[mock_name]:
mock = {"fs_mock": fs_mocks[mock_name], "test_mock": test_mock}
if len(mock_names) == 1:
mocks.append({mock_name: mock})
else:
sub_mocks = self._build_mock_objects(
mock_names[1:], fs_mocks, test_mocks
)
for sub_mock in sub_mocks:
mocks.append({mock_name: mock} | sub_mock)
return mocks
return mocks
def _tests_from_mock(self, mock_obj) -> "WalletTest":
test_mocks: List[Mock] = [
Mock.combine_mocks(
mock_name,
mock_obj[mock_name]["fs_mock"],
mock_obj[mock_name]["test_mock"],
)
for mock_name in mock_obj
]
any_mock_skipped = len([m for m in test_mocks if m.skip])
extra_description = ";".join(
[m.description for m in test_mocks if m.description]
)
return WalletTest( return WalletTest(
**( **(
self.dict() self.dict()
| { | {
"description": f"""{self.description}:{mock.description or ""}""", "description": f"{self.description}:{extra_description}",
"mocks": [*self.mocks, mock], "mocks": test_mocks,
"skip": self.skip or mock.skip, "skip": self.skip or any_mock_skipped,
} }
) )
) )
@ -131,3 +157,12 @@ class DataObject:
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k in kwargs: for k in kwargs:
setattr(self, k, kwargs[k]) setattr(self, k, kwargs[k])
def __str__(self):
data = []
for k in self.__dict__:
value = getattr(self, k)
if isinstance(value, list):
value = [f"{k}={v}" for v in value]
data.append(f"{k}={value}")
return ";".join(data)

View file

@ -55,7 +55,7 @@ def _tests_for_funding_source(
def build_test_id(test: WalletTest): def build_test_id(test: WalletTest):
return f"{test.funding_source}.{test.function}({test.description})" return f"{test.funding_source.name}.{test.function}({test.description})"
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet: def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
@ -83,7 +83,13 @@ async def check_assertions(wallet, _test_data: WalletTest):
call_params = _test_data.call_params call_params = _test_data.call_params
if "expect" in test_data: if "expect" in test_data:
await _assert_data(wallet, tested_func, call_params, _test_data.expect) await _assert_data(
wallet,
tested_func,
call_params,
_test_data.expect,
_test_data.description,
)
# if len(_test_data.mocks) == 0: # if len(_test_data.mocks) == 0:
# # all calls should fail after this method is called # # all calls should fail after this method is called
# await wallet.cleanup() # await wallet.cleanup()
@ -96,14 +102,25 @@ async def check_assertions(wallet, _test_data: WalletTest):
raise AssertionError("Expected outcome not specified") raise AssertionError("Expected outcome not specified")
async def _assert_data(wallet, tested_func, call_params, expect): async def _assert_data(wallet, tested_func, call_params, expect, description):
resp = await getattr(wallet, tested_func)(**call_params) resp = await getattr(wallet, tested_func)(**call_params)
fn_prefix = "__eval__:"
for key in expect: for key in expect:
received = getattr(resp, key)
expected = expect[key] expected = expect[key]
assert ( if key.startswith(fn_prefix):
getattr(resp, key) == expect[key] key = key[len(fn_prefix) :]
), f"""Field "{key}". Received: "{received}". Expected: "{expected}".""" received = getattr(resp, key)
expected = expected.format(**{key: received, "description": description})
_assert = eval(expected)
else:
received = getattr(resp, key)
_assert = getattr(resp, key) == expect[key]
assert _assert, (
f""" Field "{key}"."""
f""" Received: "{received}"."""
f""" Expected: "{expected}"."""
)
async def _assert_error(wallet, tested_func, call_params, expect_error): async def _assert_error(wallet, tested_func, call_params, expect_error):

View file

@ -1,6 +1,6 @@
import importlib import importlib
from typing import Dict, List, Optional from typing import Dict, List, Optional
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import pytest import pytest
from pytest_mock.plugin import MockerFixture from pytest_mock.plugin import MockerFixture
@ -46,7 +46,12 @@ def _apply_rpc_mock(mocker: MockerFixture, mock: RpcMock):
value = mock.response[field_name] value = mock.response[field_name]
values = value if isinstance(value, list) else [value] values = value if isinstance(value, list) else [value]
return_value[field_name] = Mock(side_effect=[_mock_field(f) for f in values]) _mock_class = (
AsyncMock if values[0]["request_type"] == "async-function" else Mock
)
return_value[field_name] = _mock_class(
side_effect=[_mock_field(f) for f in values]
)
m = _data_mock(return_value) m = _data_mock(return_value)
assert mock.method, "Missing method for RPC mock." assert mock.method, "Missing method for RPC mock."
@ -59,7 +64,8 @@ def _check_calls(expected_calls):
for func_call in func_calls: for func_call in func_calls:
req = func_call["request_data"] req = func_call["request_data"]
args = req["args"] if "args" in req else {} args = req["args"] if "args" in req else {}
kwargs = req["kwargs"] if "kwargs" in req else {} kwargs = _eval_dict(req["kwargs"]) if "kwargs" in req else {}
if "klass" in req: if "klass" in req:
*rest, cls = req["klass"].split(".") *rest, cls = req["klass"].split(".")
req_module = importlib.import_module(".".join(rest)) req_module = importlib.import_module(".".join(rest))
@ -70,12 +76,9 @@ def _check_calls(expected_calls):
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet): def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
assert (
test_data.funding_source.client_field
), f"Missing client field for wallet {wallet}"
client_field = getattr(wallet, test_data.funding_source.client_field)
expected_calls: Dict[str, List] = {} expected_calls: Dict[str, List] = {}
for mock in test_data.mocks: for mock in test_data.mocks:
client_field = getattr(wallet, mock.name)
spy = _spy_mock(mocker, mock, client_field) spy = _spy_mock(mocker, mock, client_field)
expected_calls |= spy expected_calls |= spy
@ -83,6 +86,7 @@ def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet)
def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field): def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
expected_calls: Dict[str, List] = {} expected_calls: Dict[str, List] = {}
assert isinstance(mock.response, dict), "Expected data RPC response" assert isinstance(mock.response, dict), "Expected data RPC response"
for field_name in mock.response: for field_name in mock.response:
@ -95,37 +99,95 @@ def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
"request_data": f["request_data"], "request_data": f["request_data"],
} }
for f in values for f in values
if f["request_type"] == "function" and "request_data" in f if (
f["request_type"] == "function" or f["request_type"] == "async-function"
)
and "request_data" in f
] ]
return expected_calls return expected_calls
def _async_generator(data):
async def f1():
for d in data:
value = _eval_dict(d)
yield _dict_to_object(value)
return f1()
def _mock_field(field): def _mock_field(field):
response_type = field["response_type"] response_type = field["response_type"]
request_type = field["request_type"] request_type = field["request_type"]
response = field["response"] response = _eval_dict(field["response"])
if request_type == "data": if request_type == "data":
return _dict_to_object(response) return _dict_to_object(response)
if request_type == "function": if request_type == "function" or request_type == "async-function":
if response_type == "data": if response_type == "data":
return _dict_to_object(response) return _dict_to_object(response)
if response_type == "exception": if response_type == "exception":
return _raise(response) return _raise(response)
if response_type == "__aiter__":
# todo: support dict
return _async_generator(field["response"])
if response_type == "function" or response_type == "async-function":
return_value = {}
for field_name in field["response"]:
value = field["response"][field_name]
_mock_class = (
AsyncMock if value["request_type"] == "async-function" else Mock
)
return_value[field_name] = _mock_class(side_effect=[_mock_field(value)])
return _dict_to_object(return_value)
return response return response
def _eval_dict(data: Optional[dict]) -> Optional[dict]:
fn_prefix = "__eval__:"
if not data:
return data
# if isinstance(data, list):
# return [_eval_dict(i) for i in data]
if not isinstance(data, dict):
return data
d = {}
for k in data:
if k.startswith(fn_prefix):
field = k[len(fn_prefix) :]
d[field] = eval(data[k])
elif isinstance(data[k], dict):
d[k] = _eval_dict(data[k])
elif isinstance(data[k], list):
d[k] = [_eval_dict(i) for i in data[k]]
else:
d[k] = data[k]
return d
def _dict_to_object(data: Optional[dict]) -> Optional[DataObject]: def _dict_to_object(data: Optional[dict]) -> Optional[DataObject]:
if not data: if not data:
return None return None
# if isinstance(data, list):
# return [_dict_to_object(i) for i in data]
if not isinstance(data, dict):
return data
d = {**data} d = {**data}
for k in data: for k in data:
value = data[k] value = data[k]
if isinstance(value, dict): if isinstance(value, dict):
d[k] = _dict_to_object(value) d[k] = _dict_to_object(value)
elif isinstance(value, list):
d[k] = [_dict_to_object(v) for v in value]
return DataObject(**d) return DataObject(**d)
@ -134,7 +196,9 @@ def _data_mock(data: dict) -> Mock:
return Mock(return_value=_dict_to_object(data)) return Mock(return_value=_dict_to_object(data))
def _raise(error: dict): def _raise(error: Optional[dict]):
if not error:
return Exception()
data = error["data"] if "data" in error else None data = error["data"] if "data" in error else None
if "module" not in error or "class" not in error: if "module" not in error or "class" not in error:
return Exception(data) return Exception(data)